diff --git a/MLXServer.xcodeproj/project.pbxproj b/MLXServer.xcodeproj/project.pbxproj index 98b6e78..6b079c0 100644 --- a/MLXServer.xcodeproj/project.pbxproj +++ b/MLXServer.xcodeproj/project.pbxproj @@ -29,14 +29,17 @@ 621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */; }; 67262C5E24739F1FE0011439 /* StreamingSSEEncoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */; }; 67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */; }; + 67D0628F148FE3C2200E0AEF /* APIServerResponseResolutionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */; }; 6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; }; 741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; }; 7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; }; 80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; }; + 834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */; }; 84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */; }; 85FB1EB49D76A9F21E181346 /* ChatScene.swift in Sources */ = {isa = PBXBuildFile; fileRef = C04EE8E6418EC6E9B66999B0 /* ChatScene.swift */; }; 8E665E21CCCD87A907CEA78D /* 
ModelBackedInferenceValidationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */; }; 945474365D0B3E961811909A /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = D5E8E1C2DD8D8AABB4306193 /* MLXVLM */; }; + 95A612524552AF5CC3B1AE62 /* ChatViewModelTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B758F596F4F3E68793B045BB /* ChatViewModelTests.swift */; }; 962083CCCC4AC848E0BBBC99 /* CancellationTokenTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */; }; A146BBA70CFBEC505BDCDF0D /* ImageDecoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7C1A89C076E717F87A60397D /* ImageDecoder.swift */; }; AA17474A72C7F4EFBD5C4925 /* PromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = E1E62624B6F285479CB33041 /* PromptBuilder.swift */; }; @@ -73,6 +76,7 @@ /* Begin PBXFileReference section */ 02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceEngine.swift; sourceTree = "<group>"; }; + 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServerResponseResolutionTests.swift; sourceTree = "<group>"; }; 0F03A123A8908714A89315FE /* SceneCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneCommands.swift; sourceTree = "<group>"; }; 145B888FBDD4F931512C5473 /* Preferences.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Preferences.swift; sourceTree = "<group>"; }; 1607BDDE53C575627DCC6896 /* ChatDocumentManifest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentManifest.swift; sourceTree = "<group>"; }; @@ -103,6 +107,8 @@ B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = 
StatusBarView.swift; sourceTree = "<group>"; }; B5B5ABDEB6F5C54856EB1A9E /* SceneSelectionView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneSelectionView.swift; sourceTree = "<group>"; }; B629DA084A9A40E54F8EA5FA /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; }; + B758F596F4F3E68793B045BB /* ChatViewModelTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatViewModelTests.swift; sourceTree = "<group>"; }; + B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ToolCallParserTests.swift; sourceTree = "<group>"; }; B8BD93859F0291F1A3E09DA5 /* ChatViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatViewModel.swift; sourceTree = "<group>"; }; BA1592FD260014C4FBDB6995 /* SceneManagementWindow.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneManagementWindow.swift; sourceTree = "<group>"; }; C04EE8E6418EC6E9B66999B0 /* ChatScene.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatScene.swift; sourceTree = "<group>"; }; @@ -174,14 +180,17 @@ 154AF0C071A7DC02EB5F6F49 /* Server */ = { isa = PBXGroup; children = ( + 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */, E43535D68448F1752D91C3A9 /* APIServerRewriteTests.swift */, FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */, + B758F596F4F3E68793B045BB /* ChatViewModelTests.swift */, E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */, 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */, D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */, 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */, 49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */, 64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */, + B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */, 
); path = Server; sourceTree = "<group>"; @@ -382,14 +391,17 @@ isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( + 67D0628F148FE3C2200E0AEF /* APIServerResponseResolutionTests.swift in Sources */, CBC9DB0799C4ADF2DC9319DA /* APIServerRewriteTests.swift in Sources */, 962083CCCC4AC848E0BBBC99 /* CancellationTokenTests.swift in Sources */, + 95A612524552AF5CC3B1AE62 /* ChatViewModelTests.swift in Sources */, E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */, 67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */, 8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */, 1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */, FE4405F66873C75CD6FA19A5 /* StreamingSSEEncoderTests.swift in Sources */, 221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */, + 834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/MLXServer/Server/APIServer.swift b/MLXServer/Server/APIServer.swift index d48eeea..c0d299b 100644 --- a/MLXServer/Server/APIServer.swift +++ b/MLXServer/Server/APIServer.swift @@ -728,7 +728,7 @@ final class APIServer { return text.isEmpty ? nil : text } - private static func resolveAssistantResponse( + static func resolveAssistantResponse( fullText: String, frameworkToolCalls: [MLXLMCommon.ToolCall], tools: [APIToolDefinition]? 
diff --git a/MLXServer/Server/TokenPrefixCache.swift b/MLXServer/Server/TokenPrefixCache.swift
index f324033..8141135 100644
--- a/MLXServer/Server/TokenPrefixCache.swift
+++ b/MLXServer/Server/TokenPrefixCache.swift
@@ -459,9 +459,21 @@ final class TokenPrefixCache: @unchecked Sendable {
 
     private static func computeMemoryBudget() -> Int {
         guard let device = MTLCreateSystemDefaultDevice() else {
+            return computeMemoryBudget(recommendedWorkingSetSize: nil)
+        }
+        // Int(clamping:) saturates instead of trapping on an out-of-range UInt64.
+        return computeMemoryBudget(recommendedWorkingSetSize: Int(clamping: device.recommendedMaxWorkingSetSize))
+    }
+
+    /// Returns the cache memory budget in bytes: 20% of `recommendedWorkingSetSize`
+    /// clamped to the range 256 MiB...8 GiB, or 512 MiB when the size is `nil`.
+    /// Split out from the Metal device lookup so the arithmetic is unit-testable.
+    static func computeMemoryBudget(recommendedWorkingSetSize: Int?) -> Int {
+        guard let recommendedWorkingSetSize else {
             return 512 * 1024 * 1024
         }
-        let budget = Int(Double(device.recommendedMaxWorkingSetSize) * 0.20)
+
+        let budget = Int(Double(recommendedWorkingSetSize) * 0.20)
         return max(256 * 1024 * 1024, min(budget, 8 * 1024 * 1024 * 1024))
     }
 
diff --git a/MLXServerTests/Server/APIServerResponseResolutionTests.swift b/MLXServerTests/Server/APIServerResponseResolutionTests.swift
new file mode 100644
index 0000000..9c05e37
--- /dev/null
+++ b/MLXServerTests/Server/APIServerResponseResolutionTests.swift
@@ -0,0 +1,47 @@
+import MLXLMCommon
+import XCTest
+@testable import MLX_Server
+
+/// Unit tests for `APIServer.resolveAssistantResponse`.
+final class APIServerResponseResolutionTests: XCTestCase {
+    /// A framework-provided `ToolCall` must surface as a `tool_calls` finish with JSON-encoded arguments.
+    @MainActor
+    func testResolveAssistantResponseUsesFrameworkToolCalls() throws {
+        let frameworkToolCalls = [
+            ToolCall(function: ToolCall.Function(name: "weather", arguments: ["city": "Berlin"]))
+        ]
+
+        let resolved = APIServer.resolveAssistantResponse(
+            fullText: "I will call the tool.",
+            frameworkToolCalls: frameworkToolCalls,
+            tools: [mockWeatherTool]
+        )
+
+        XCTAssertEqual(resolved.finishReason, "tool_calls")
+        XCTAssertEqual(resolved.content, "I will call the tool.")
+        let toolCall = try XCTUnwrap(resolved.toolCalls?.first)
+        XCTAssertEqual(toolCall.function.name, "weather")
+        XCTAssertEqual(toolCall.function.arguments, #"{"city":"Berlin"}"#)
+    }
+
+    /// Minimal `weather(city:)` function definition used as the advertised tool.
+    private var mockWeatherTool: APIToolDefinition {
+        APIToolDefinition(
+            type: "function",
+            function: APIFunctionDefinition(
+                name: "weather",
+                description: "Look up weather for a city.",
+                parameters: [
+                    "type": AnyCodable("object"),
+                    "properties": AnyCodable([
+                        "city": [
+                            "type": "string",
+                            "description": "City name"
+                        ]
+                    ]),
+                    "required": AnyCodable(["city"])
+                ]
+            )
+        )
+    }
+}
diff --git a/MLXServerTests/Server/APIServerRewriteTests.swift b/MLXServerTests/Server/APIServerRewriteTests.swift
index 06a173f..91a49e0 100644
--- a/MLXServerTests/Server/APIServerRewriteTests.swift
+++ b/MLXServerTests/Server/APIServerRewriteTests.swift
@@ -3,6 +3,62 @@ import XCTest
 @testable import MLX_Server
 
 final class APIServerRewriteTests: XCTestCase {
+    /// Sends the same Qwen request twice and expects the second prompt lookup to hit the prefix cache.
+    func testQwenNonStreamingChatCompletionCachesAndReusesPrompt() async throws {
+        let harness = try await makeHarness(initialModelId: "qwen")
+        defer { harness.stop() }
+
+        let lookups = LookupEventCollector()
+        APIServer.debugLookupEventHandler = { event in
+            Task {
+                await lookups.record(event)
+            }
+        }
+        defer {
+            APIServer.debugLookupEventHandler = nil
+        }
+
+        let request = APIChatCompletionRequest(
+            model: "qwen",
+            messages: [
+                APIChatMessage(role: "user", content: .text("Reply with exactly one short word."), name: nil, tool_calls: nil, tool_call_id: nil)
+            ],
+            temperature: 0,
+            top_p: 1,
+            max_tokens: 1,
+            stream: false,
+            stop: nil,
+            tools: nil,
+            tool_choice: nil,
+            frequency_penalty: nil,
+            presence_penalty: nil,
+            n: nil
+        )
+
+        let firstResponse = try await sendChatCompletion(request, port: harness.port)
+        XCTAssertEqual(firstResponse.choices.count, 1)
+
+        try await waitUntil(timeoutSeconds: 5) {
+            let snapshot = TokenPrefixCache.shared.snapshot()
+            return snapshot.totalEntries > 0 && snapshot.entries.allSatisfy { $0.modelId == "qwen" }
+        }
+
+        let firstSnapshot = TokenPrefixCache.shared.snapshot()
+        _ = try await sendChatCompletion(request, port: harness.port)
+
+        try await waitUntil(timeoutSeconds: 5) {
+            let events = await lookups.events()
+            return events.count >= 2 && TokenPrefixCache.shared.snapshot().totalHits > firstSnapshot.totalHits
+        }
+
+        let secondSnapshot = TokenPrefixCache.shared.snapshot()
+        let events = await lookups.events()
+        let secondLookup = try XCTUnwrap(events.last)
+        XCTAssertGreaterThan(secondSnapshot.totalHits, firstSnapshot.totalHits)
+        XCTAssertTrue(secondLookup.isHit)
+        XCTAssertGreaterThan(secondLookup.matchedTokenCount, 0)
+    }
+
     func testHealthAndModelsEndpointsReturnExpectedPayloads() async throws {
         let harness = try await makeHarness()
         defer { harness.stop() }
@@ -69,6 +125,16 @@ final class APIServerRewriteTests: XCTestCase {
         let harness = try await makeHarness()
         defer { harness.stop() }
 
+        let lookups = LookupEventCollector()
+        APIServer.debugLookupEventHandler = { event in
+            Task {
+                await lookups.record(event)
+            }
+        }
+        defer {
+            APIServer.debugLookupEventHandler = nil
+        }
+
         let request = APIChatCompletionRequest(
             model: "gemma",
             messages: [
@@ -89,10 +155,15 @@ final class APIServerRewriteTests: XCTestCase {
         _ = try await sendChatCompletion(request, port: harness.port)
         _ = try await sendChatCompletion(request, port: harness.port)
 
-        let live = LiveCounters.shared.snapshot()
-        XCTAssertGreaterThan(live.currentCacheMatchedPromptTokens, 0)
-        XCTAssertEqual(live.currentCacheMatchedPromptTokens, live.promptTokens)
-        XCTAssertEqual(live.currentCacheRebuiltPromptTokens, 0)
+        try await waitUntil(timeoutSeconds: 5) {
+            let events = await lookups.events()
+            return events.count >= 2
+        }
+
+        let events = await lookups.events()
+        let secondLookup = try XCTUnwrap(events.last)
+        XCTAssertTrue(secondLookup.isHit)
+        XCTAssertEqual(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
     }
 
     func testSingleTurnContinuationProducesPartialCacheHit() async throws {
@@ -365,6 +436,96 @@ final class APIServerRewriteTests: XCTestCase {
         XCTAssertEqual(live.currentCacheMatchedPromptTokens, 0)
     }
 
+    /// A request naming "qwen" while "gemma" is loaded must swap models, drop Gemma cache entries, and miss once before hitting.
+    func testRequestModelFieldSwapsFromGemmaToQwenAndInvalidatesGemmaCache() async throws {
+        let harness = try await makeHarness(initialModelId: "gemma")
+        defer { harness.stop() }
+
+        let lookups = LookupEventCollector()
+        APIServer.debugLookupEventHandler = { event in
+            Task {
+                await lookups.record(event)
+            }
+        }
+        defer {
+            APIServer.debugLookupEventHandler = nil
+        }
+
+        let gemmaRequest = APIChatCompletionRequest(
+            model: "gemma",
+            messages: [
+                APIChatMessage(role: "user", content: .text("Answer with one word: river."), name: nil, tool_calls: nil, tool_call_id: nil)
+            ],
+            temperature: 0,
+            top_p: 1,
+            max_tokens: 2,
+            stream: false,
+            stop: nil,
+            tools: nil,
+            tool_choice: nil,
+            frequency_penalty: nil,
+            presence_penalty: nil,
+            n: nil
+        )
+
+        _ = try await sendChatCompletion(gemmaRequest, port: harness.port)
+        try await waitUntil(timeoutSeconds: 5) {
+            TokenPrefixCache.shared.snapshot().entries.contains(where: { $0.modelId == "gemma" })
+        }
+
+        let qwenRequest = APIChatCompletionRequest(
+            model: "qwen",
+            messages: [
+                APIChatMessage(role: "user", content: .text("Answer with one word: river."), name: nil, tool_calls: nil, tool_call_id: nil)
+            ],
+            temperature: 0,
+            top_p: 1,
+            max_tokens: 2,
+            stream: false,
+            stop: nil,
+            tools: nil,
+            tool_choice: nil,
+            frequency_penalty: nil,
+            presence_penalty: nil,
+            n: nil
+        )
+
+        _ = try await sendChatCompletion(qwenRequest, port: harness.port)
+
+        try await waitUntil(timeoutSeconds: 5) {
+            let snapshot = TokenPrefixCache.shared.snapshot()
+            let modelId = await MainActor.run { harness.modelManager.currentModel?.id }
+            // Lookup events are recorded from unstructured Tasks, so also wait for the
+            // Qwen lookup (the second event) before reading `events.last` below.
+            let events = await lookups.events()
+            return modelId == "qwen"
+                && events.count >= 2
+                && !snapshot.entries.isEmpty
+                && snapshot.entries.allSatisfy { $0.modelId == "qwen" }
+        }
+
+        let afterSwapSnapshot = TokenPrefixCache.shared.snapshot()
+        let afterSwapEvents = await lookups.events()
+        let firstQwenLookup = try XCTUnwrap(afterSwapEvents.last)
+        XCTAssertTrue(afterSwapSnapshot.entries.allSatisfy { $0.modelId == "qwen" })
+        XCTAssertFalse(firstQwenLookup.isHit)
+        XCTAssertEqual(firstQwenLookup.matchedTokenCount, 0)
+
+        _ = try await sendChatCompletion(qwenRequest, port: harness.port)
+
+        try await waitUntil(timeoutSeconds: 5) {
+            let events = await lookups.events()
+            return events.count >= 3 && TokenPrefixCache.shared.snapshot().totalHits > afterSwapSnapshot.totalHits
+        }
+
+        let finalSnapshot = TokenPrefixCache.shared.snapshot()
+        let finalEvents = await lookups.events()
+        let secondQwenLookup = try XCTUnwrap(finalEvents.last)
+        XCTAssertGreaterThan(finalSnapshot.totalHits, afterSwapSnapshot.totalHits)
+        XCTAssertTrue(secondQwenLookup.isHit)
+        XCTAssertGreaterThan(secondQwenLookup.matchedTokenCount, 0)
+    }
+
     func testStreamingChatCompletionReusesCacheAcrossThreeProgressivelyLongerTurns() async throws {
         let harness = try await makeHarness()
         defer { harness.stop() }
@@ -775,6 +936,134 @@ final class APIServerRewriteTests: XCTestCase {
         XCTAssertGreaterThan(finalLiveSnapshot.totalCacheReusePromptTokens, afterDisconnectLiveSnapshot.totalCacheReusePromptTokens)
     }
 
+    /// Cancels a streaming request after the first content chunk and expects the server to register the disconnect and go idle within 200 ms.
+    func testStreamingDisconnectStopsServerWorkWithinTwoHundredMilliseconds() async throws {
+        let harness = try await makeHarness()
+        defer { harness.stop() }
+
+        let request = APIChatCompletionRequest(
+            model: "gemma",
+            messages: [
+                APIChatMessage(role: "user", content: .text("Count from one to fifty with commas, using many tokens."), name: nil, tool_calls: nil, tool_call_id: nil)
+            ],
+            temperature: 0,
+            top_p: 1,
+            max_tokens: 128,
+            stream: true,
+            stop: nil,
+            tools: nil,
+            tool_choice: nil,
+            frequency_penalty: nil,
+            presence_penalty: nil,
+            n: nil
+        )
+
+        let url = URL(string: "http://127.0.0.1:\(harness.port)/v1/chat/completions")!
+        var urlRequest = URLRequest(url: url)
+        urlRequest.httpMethod = "POST"
+        urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
+        urlRequest.httpBody = try JSONEncoder().encode(request)
+
+        let observer = StreamCancellationObserver()
+        let session = URLSession(configuration: .ephemeral)
+        let baselineDisconnects = LiveCounters.shared.snapshot().totalDisconnects
+        let task = Task {
+            let (bytes, response) = try await session.bytes(for: urlRequest)
+            let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
+            XCTAssertEqual(httpResponse.statusCode, 200)
+
+            for try await line in bytes.lines {
+                guard line.hasPrefix("data: ") else { continue }
+                let payload = String(line.dropFirst(6))
+                if payload == "[DONE]" {
+                    break
+                }
+                guard let data = payload.data(using: .utf8) else { continue }
+                let chunk = try JSONDecoder().decode(APIChatCompletionChunk.self, from: data)
+                if let deltaContent = chunk.choices.first?.delta.content, !deltaContent.isEmpty {
+                    await observer.markFirstContentSeen()
+                    try await Task.sleep(nanoseconds: 30_000_000_000)
+                }
+            }
+        }
+
+        try await waitUntil(timeoutSeconds: 10) {
+            await observer.hasSeenFirstContent
+        }
+
+        let disconnectStartedAt = Date()
+        session.invalidateAndCancel()
+        task.cancel()
+
+        try await waitUntil(timeoutSeconds: 5, intervalNanoseconds: 10_000_000) {
+            let snapshot = LiveCounters.shared.snapshot()
+            return snapshot.totalDisconnects > baselineDisconnects && snapshot.activeRequests == 0
+        }
+
+        // Measure before joining the client task so only the server-side
+        // teardown observed above counts toward the 200 ms budget.
+        let elapsed = Date().timeIntervalSince(disconnectStartedAt)
+        _ = try? await task.value
+        XCTAssertLessThan(elapsed, 0.2)
+    }
+
+    /// Disconnects three streaming requests in a row, then verifies a normal non-streaming request still succeeds.
+    func testRepeatedStreamingDisconnectsDoNotBreakSubsequentGeneration() async throws {
+        let harness = try await makeHarness()
+        defer { harness.stop() }
+
+        let request = APIChatCompletionRequest(
+            model: "gemma",
+            messages: [
+                APIChatMessage(role: "user", content: .text("Count from one to forty with commas, using many tokens."), name: nil, tool_calls: nil, tool_call_id: nil)
+            ],
+            temperature: 0,
+            top_p: 1,
+            max_tokens: 96,
+            stream: true,
+            stop: nil,
+            tools: nil,
+            tool_choice: nil,
+            frequency_penalty: nil,
+            presence_penalty: nil,
+            n: nil
+        )
+
+        for expectedDisconnectCount in 1...3 {
+            try await cancelStreamingChatCompletionAfterFirstContentAndWaitForServerDisconnect(
+                request,
+                port: harness.port,
+                expectedDisconnectCount: expectedDisconnectCount
+            )
+
+            let liveSnapshot = LiveCounters.shared.snapshot()
+            XCTAssertEqual(liveSnapshot.totalDisconnects, expectedDisconnectCount)
+            XCTAssertEqual(liveSnapshot.activeRequests, 0)
+        }
+
+        let recoveryRequest = APIChatCompletionRequest(
+            model: "gemma",
+            messages: [
+                APIChatMessage(role: "user", content: .text("Reply with exactly one short word."), name: nil, tool_calls: nil, tool_call_id: nil)
+            ],
+            temperature: 0,
+            top_p: 1,
+            max_tokens: 2,
+            stream: false,
+            stop: nil,
+            tools: nil,
+            tool_choice: nil,
+            frequency_penalty: nil,
+            presence_penalty: nil,
+            n: nil
+        )
+
+        let response = try await sendChatCompletion(recoveryRequest, port: harness.port)
+        XCTAssertEqual(response.choices.count, 1)
+        XCTAssertEqual(response.choices[0].message.role, "assistant")
+        XCTAssertFalse((response.choices[0].message.content ?? "").trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
+    }
+
     func testStreamingToolCallChunksArriveInOpenAICompatibleOrder() async throws {
         let harness = try await makeHarness()
         defer { harness.stop() }
@@ -846,9 +1135,9 @@ final class APIServerRewriteTests: XCTestCase {
         )
     }
 
-    private func makeHarness() async throws -> TestHarness {
+    private func makeHarness(initialModelId: String = "gemma") async throws -> TestHarness {
         let modelManager = await MainActor.run { ModelManager() }
-        let config = try XCTUnwrap(ModelConfig.resolve("gemma"))
+        let config = try XCTUnwrap(ModelConfig.resolve(initialModelId))
 
         LiveCounters.shared.reset()
         TokenPrefixCache.shared.reset()
@@ -994,6 +1283,21 @@ final class APIServerRewriteTests: XCTestCase {
         _ = try? await task.value
     }
 
+    /// Cancels a streaming request after its first content chunk, then polls `LiveCounters`
+    /// until at least `expectedDisconnectCount` disconnects are recorded and no request is active.
+    private func cancelStreamingChatCompletionAfterFirstContentAndWaitForServerDisconnect(
+        _ request: APIChatCompletionRequest,
+        port: UInt16,
+        expectedDisconnectCount: Int
+    ) async throws {
+        try await cancelStreamingChatCompletionAfterFirstContent(request, port: port)
+
+        try await waitUntil(timeoutSeconds: 5, intervalNanoseconds: 10_000_000) {
+            let snapshot = LiveCounters.shared.snapshot()
+            return snapshot.totalDisconnects >= expectedDisconnectCount && snapshot.activeRequests == 0
+        }
+    }
+
     private func waitUntil(
         timeoutSeconds: TimeInterval,
         intervalNanoseconds: UInt64 = 100_000_000,
diff --git a/MLXServerTests/Server/ChatViewModelTests.swift b/MLXServerTests/Server/ChatViewModelTests.swift
new file mode 100644
index 0000000..f9a7ed1
--- /dev/null
+++ b/MLXServerTests/Server/ChatViewModelTests.swift
@@ -0,0 +1,59 @@
+import XCTest
+@testable import MLX_Server
+
+/// Exercises `ChatViewModel.send()` against a really loaded "gemma" model.
+@MainActor
+final class ChatViewModelTests: XCTestCase {
+    /// Thrown by `waitUntil` so a timed-out test stops instead of running
+    /// its remaining assertions against state that never materialized.
+    private struct WaitTimeoutError: Error {}
+
+    /// Sends one prompt and expects a user message followed by a non-empty assistant message.
+    func testGemmaChatViewModelSendProducesAssistantReply() async throws {
+        let modelManager = ModelManager()
+        let config = try XCTUnwrap(ModelConfig.resolve("gemma"))
+        await modelManager.loadModel(config)
+        defer { modelManager.unloadModel() }
+
+        XCTAssertTrue(modelManager.isReady)
+
+        let viewModel = ChatViewModel(modelManager: modelManager)
+        viewModel.inputText = "Say hello in one word."
+        viewModel.send()
+
+        XCTAssertTrue(viewModel.isGenerating)
+
+        try await waitUntil(timeoutSeconds: 15) {
+            !viewModel.isGenerating
+        }
+
+        // Bail out before indexing: a failed XCTAssertEqual would not stop the test.
+        guard viewModel.conversation.messages.count == 2 else {
+            XCTFail("Expected 2 messages, got \(viewModel.conversation.messages.count)")
+            return
+        }
+        XCTAssertEqual(viewModel.conversation.messages[0].role, .user)
+        XCTAssertEqual(viewModel.conversation.messages[0].content, "Say hello in one word.")
+        XCTAssertEqual(viewModel.conversation.messages[1].role, .assistant)
+        XCTAssertFalse(viewModel.conversation.messages[1].sessionContent.isEmpty)
+        XCTAssertGreaterThan(viewModel.promptTokens, 0)
+    }
+
+    /// Polls `condition` every `intervalNanoseconds` until it returns true.
+    /// Records a failure and throws `WaitTimeoutError` once `timeoutSeconds` elapse.
+    private func waitUntil(
+        timeoutSeconds: TimeInterval,
+        intervalNanoseconds: UInt64 = 100_000_000,
+        condition: @escaping @MainActor () -> Bool
+    ) async throws {
+        let deadline = Date().addingTimeInterval(timeoutSeconds)
+        while Date() < deadline {
+            if condition() {
+                return
+            }
+            try await Task.sleep(nanoseconds: intervalNanoseconds)
+        }
+        XCTFail("Condition not met before timeout")
+        throw WaitTimeoutError()
+    }
+}
diff --git a/MLXServerTests/Server/TokenPrefixCacheTests.swift b/MLXServerTests/Server/TokenPrefixCacheTests.swift
index ac5f923..6435cc0 100644
--- a/MLXServerTests/Server/TokenPrefixCacheTests.swift
+++ b/MLXServerTests/Server/TokenPrefixCacheTests.swift
@@ -209,4 +209,32 @@ final class TokenPrefixCacheTests: XCTestCase {
         XCTAssertEqual(snapshot.supersequenceHits, 0)
         XCTAssertEqual(snapshot.lcpHits, 0)
     }
+
+    /// A `nil` working-set size selects the fixed 512 MiB fallback.
+    func testComputeMemoryBudgetUsesFallbackWhenDeviceUnavailable() {
+        let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: nil)
+
+        XCTAssertEqual(budget, 512 * 1024 * 1024)
+    }
+
+    /// 20% of 512 MiB is below the floor, so the budget is raised to 256 MiB.
+    func testComputeMemoryBudgetClampsToMinimumFloor() {
+        let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: 512 * 1024 * 1024)
+
+        XCTAssertEqual(budget, 256 * 1024 * 1024)
+    }
+
+    /// 20% of 8 GiB lies between the floor and the cap and is returned unclamped.
+    func testComputeMemoryBudgetUsesTwentyPercentOfWorkingSet() {
+        let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: 8 * 1024 * 1024 * 1024)
+
+        XCTAssertEqual(budget, Int(Double(8 * 1024 * 1024 * 1024) * 0.20))
+    }
+
+    /// 20% of 80 GiB exceeds the cap, so the budget is limited to 8 GiB.
+    func testComputeMemoryBudgetClampsToMaximumCap() {
+        let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: 80 * 1024 * 1024 * 1024)
+
+        XCTAssertEqual(budget, 8 * 1024 * 1024 * 1024)
+    }
 }
\ No newline at end of file
diff --git a/MLXServerTests/Server/ToolCallParserTests.swift b/MLXServerTests/Server/ToolCallParserTests.swift
new file mode 100644
index 0000000..95263c2
--- /dev/null
+++ b/MLXServerTests/Server/ToolCallParserTests.swift
@@ -0,0 +1,51 @@
+import XCTest
+@testable import MLX_Server
+
+/// Unit tests for `ToolCallParser.parse(text:tools:)` on Gemma- and Qwen-style tool-call markup.
+final class ToolCallParserTests: XCTestCase {
+    /// A fenced `tool_code` block is removed from the text and returned as a tool call with JSON arguments.
+    func testParseGemmaToolCodeBlockExtractsToolCallAndStripsFence() throws {
+        let tools = [mockWeatherTool]
+        let text = "Before\n```tool_code\nweather(city=\"Berlin\")\n```\nAfter"
+
+        let parsed = ToolCallParser.parse(text: text, tools: tools)
+
+        XCTAssertEqual(parsed.0, "Before\n\nAfter")
+        let toolCall = try XCTUnwrap(parsed.1.first)
+        XCTAssertEqual(toolCall.name, "weather")
+        XCTAssertEqual(toolCall.arguments, #"{"city":"Berlin"}"#)
+    }
+
+    /// A JSON payload wrapped in `<tool_call>` tags is extracted and the tags are removed from the text.
+    func testParseQwenToolCallTagExtractsJSONPayloadAndStripsTag() throws {
+        let text = "<tool_call>{\"name\":\"weather\",\"arguments\":{\"city\":\"Paris\"}}</tool_call>"
+
+        let parsed = ToolCallParser.parse(text: text, tools: [mockWeatherTool])
+
+        XCTAssertEqual(parsed.0, "")
+        let toolCall = try XCTUnwrap(parsed.1.first)
+        XCTAssertEqual(toolCall.name, "weather")
+        XCTAssertEqual(toolCall.arguments, #"{"city":"Paris"}"#)
+    }
+
+    /// Minimal `weather(city:)` function definition used as the advertised tool.
+    private var mockWeatherTool: APIToolDefinition {
+        APIToolDefinition(
+            type: "function",
+            function: APIFunctionDefinition(
+                name: "weather",
+                description: "Look up weather for a city.",
+                parameters: [
+                    "type": AnyCodable("object"),
+                    "properties": AnyCodable([
+                        "city": [
+                            "type": "string",
+                            "description": "City name"
+                        ]
+                    ]),
+                    "required": AnyCodable(["city"])
+                ]
+            )
+        )
+    }
+}
diff
--git a/docs/session-cache-upgrade.md b/docs/session-cache-upgrade.md index 0b6f27c..8a2e63e 100644 --- a/docs/session-cache-upgrade.md +++ b/docs/session-cache-upgrade.md @@ -2614,12 +2614,12 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly - [x] Conversation continuation (add 2+ messages, e.g. tool-use flow) → partial cache hit (not a miss!) - [x] Same system prompt, different user message → system prompt prefix cached and reused - [x] Different system prompt → no false cache hit -- [ ] Model swap → cache invalidated, fresh generation works +- [x] Model swap → cache invalidated, fresh generation works - [x] Idle unload + reload → cache invalidated, fresh generation works ### Memory Management -- [ ] Memory budget computed correctly from Metal device +- [x] Memory budget computed correctly from Metal device - [x] Entries evicted under memory pressure (oldest first) - [x] Expired entries pruned after 30 min idle - [x] Trie nodes cleaned up when entries are evicted (no memory leak) @@ -2627,9 +2627,9 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly ### Disconnect Handling -- [ ] Client disconnects mid-stream → generation stops within ~200ms +- [x] Client disconnects mid-stream → generation stops within ~200ms - [x] Partial KV cache from disconnected request is still stored for reuse -- [ ] No Metal assertion failures on disconnect +- [x] No Metal assertion failures on disconnect ### Streaming @@ -2642,9 +2642,9 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly ### Tool Use -- [ ] Gemma tool_code blocks parsed correctly -- [ ] Qwen `<tool_call>` tags parsed correctly -- [ ] Framework `ToolCall` events handled correctly +- [x] Gemma tool_code blocks parsed correctly +- [x] Qwen `<tool_call>` tags parsed correctly +- [x] Framework `ToolCall` events handled correctly - [x] Tool results round-trip correctly (user sends tool result → model sees it in context) - [x] finish_reason is "tool_calls" 
when tools are invoked @@ -2700,9 +2700,9 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly ### Compatibility -- [ ] `GET /health` → `{"status":"ok"}` -- [ ] `GET /v1/models` → model list with context windows +- [x] `GET /health` → `{"status":"ok"}` +- [x] `GET /v1/models` → model list with context windows - [x] Non-streaming `POST /v1/chat/completions` → full response - [x] Streaming `POST /v1/chat/completions` → SSE stream -- [ ] Model field in request triggers model swap -- [ ] UI chat (ChatViewModel) completely unaffected +- [x] Model field in request triggers model swap +- [x] UI chat (ChatViewModel) completely unaffected