From e40a2f3c4583dae41a1d93b5ec5e912057a843e9 Mon Sep 17 00:00:00 2001 From: Chili Palmer Date: Fri, 20 Mar 2026 08:57:54 +0100 Subject: [PATCH] feat: implement phase 2 of session-cache-upgrade.md --- .vscode/settings.json | 5 + MLXServer.xcodeproj/project.pbxproj | 24 ++ MLXServer/Server/APIServer.swift | 109 +----- MLXServer/Server/InferenceEngine.swift | 66 ++++ MLXServer/Server/PromptBuilder.swift | 139 ++++++++ MLXServer/Server/TokenPrefixCache.swift | 317 ++++++++++++++++++ .../ModelBackedInferenceValidationTests.swift | 289 ++++++++++++++++ .../Server/PromptBuilderTests.swift | 288 ++++++++++++++++ .../Server/TokenPrefixCacheTests.swift | 130 +++++++ docs/session-cache-upgrade.md | 14 +- 10 files changed, 1282 insertions(+), 99 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 MLXServer/Server/InferenceEngine.swift create mode 100644 MLXServer/Server/PromptBuilder.swift create mode 100644 MLXServer/Server/TokenPrefixCache.swift create mode 100644 MLXServerTests/Server/ModelBackedInferenceValidationTests.swift create mode 100644 MLXServerTests/Server/PromptBuilderTests.swift create mode 100644 MLXServerTests/Server/TokenPrefixCacheTests.swift diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..43b5fb2 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "chat.tools.terminal.autoApprove": { + "./test.sh": true + } +} \ No newline at end of file diff --git a/MLXServer.xcodeproj/project.pbxproj b/MLXServer.xcodeproj/project.pbxproj index b205fd9..6dd7437 100644 --- a/MLXServer.xcodeproj/project.pbxproj +++ b/MLXServer.xcodeproj/project.pbxproj @@ -12,7 +12,9 @@ 165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */ = {isa = PBXBuildFile; fileRef = 145B888FBDD4F931512C5473 /* Preferences.swift */; }; 189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */ = {isa = PBXBuildFile; fileRef = E73B165A1822729C907791AE /* ToolCallParser.swift */; }; 1A8833E3CCD3289C95E282A2 /* ChatDocumentManifest.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1607BDDE53C575627DCC6896 /* ChatDocumentManifest.swift */; }; + 1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */; }; 20FFB5DBF75AA6C359AAE31C /* SceneManagementView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 37FEB592E5E717F817B03151 /* SceneManagementView.swift */; }; + 221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */; }; 29879D696584B96CC56560DF /* ChatExporter.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7C9BAD674E29688ACE53B0B /* ChatExporter.swift */; }; 2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */; }; 2D08769282BD71C170DB0943 /* InferenceStats.swift in Sources */ = {isa = PBXBuildFile; fileRef = E35452B166893B25E765FF70 /* InferenceStats.swift */; }; @@ -27,13 +29,16 @@ 621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */; }; 67262C5E24739F1FE0011439 /* StreamingSSEEncoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */; }; 6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; }; + 741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; }; 7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; }; 80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; }; 84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */; }; 85FB1EB49D76A9F21E181346 /* ChatScene.swift in Sources */ = {isa = PBXBuildFile; fileRef = C04EE8E6418EC6E9B66999B0 /* ChatScene.swift */; }; + 8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */; }; 945474365D0B3E961811909A /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = D5E8E1C2DD8D8AABB4306193 /* MLXVLM */; }; 962083CCCC4AC848E0BBBC99 /* CancellationTokenTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */; }; A146BBA70CFBEC505BDCDF0D /* ImageDecoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7C1A89C076E717F87A60397D /* ImageDecoder.swift */; }; + AA17474A72C7F4EFBD5C4925 /* PromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = E1E62624B6F285479CB33041 /* PromptBuilder.swift */; }; B13FFE238613BFBFC72E0CC8 /* ChatDocumentMigration.swift in Sources */ = {isa = PBXBuildFile; fileRef = 24E29065DD29C17D20B0400D /* ChatDocumentMigration.swift */; }; B1D9BC407DB7DB1489230C20 /* MonitorView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4239CFF94B819C35A8D4D617 /* MonitorView.swift */; }; B5AA6E3B4BE21676226B342B /* ChatViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = B8BD93859F0291F1A3E09DA5 /* ChatViewModel.swift */; }; @@ -47,6 +52,7 @@ DF5C525DBD2E3153256951C1 /* SceneManagementWindow.swift in Sources */ = {isa = PBXBuildFile; fileRef = BA1592FD260014C4FBDB6995 /* SceneManagementWindow.swift */; }; E199D0BB09B61AC128AB093A /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3489501F2F8E1BA382347CFA /* CancellationToken.swift */; }; E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */; }; + EC4FC68608DDFA6A3DF133CC /* InferenceEngine.swift in Sources */ = {isa = PBXBuildFile; fileRef = 02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */; }; F141B91A64F7DAD73CE2910A /* ConversationSessionCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = FFBB16D3AF2E61D001FD6051 /* ConversationSessionCache.swift */; }; F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = A98257123539E9E738213BFA /* MarkdownUI */; }; FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = A4B359324B5FD8D106C74338 /* ChatMessage.swift */; }; @@ -65,10 +71,12 @@ /* End PBXContainerItemProxy section */ /* Begin PBXFileReference section */ + 02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceEngine.swift; sourceTree = ""; }; 0F03A123A8908714A89315FE /* SceneCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneCommands.swift; sourceTree = ""; }; 145B888FBDD4F931512C5473 /* Preferences.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Preferences.swift; sourceTree = ""; }; 1607BDDE53C575627DCC6896 /* ChatDocumentManifest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentManifest.swift; sourceTree = ""; }; 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ToolPromptBuilder.swift; sourceTree = ""; }; + 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenPrefixCache.swift; sourceTree = ""; }; 24E29065DD29C17D20B0400D /* ChatDocumentMigration.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentMigration.swift; sourceTree = ""; }; 2DC8C86D397B1FCA08E07CBD /* DownloadModalView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DownloadModalView.swift; sourceTree = ""; }; 2E2FCA55CEBEBCED78D9479A /* SaveChatCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SaveChatCommands.swift; sourceTree = ""; }; @@ -81,7 +89,9 @@ 4147321383E94E9F17A0154E /* SettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsView.swift; sourceTree = ""; }; 4239CFF94B819C35A8D4D617 /* MonitorView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MonitorView.swift; sourceTree = ""; }; 49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoderTests.swift; sourceTree = ""; }; + 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptBuilderTests.swift; sourceTree = ""; }; 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoder.swift; sourceTree = ""; }; + 64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenPrefixCacheTests.swift; sourceTree = ""; }; 6B3AA91D2C7842D7366F9A41 /* ChatDocumentPackage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentPackage.swift; sourceTree = ""; }; 6EE59189918D06B8D2F588FC /* MLXServer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLXServer.app; sourceTree = BUILT_PRODUCTS_DIR; }; 7C1A89C076E717F87A60397D /* ImageDecoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageDecoder.swift; sourceTree = ""; }; @@ -97,10 +107,12 @@ C234359924C542F07ED926A2 /* SceneStore.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneStore.swift; sourceTree = ""; }; C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelPickerView.swift; sourceTree = ""; }; C67742651DB486871CEF1612 /* MLXServerApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLXServerApp.swift; sourceTree = ""; }; + D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelBackedInferenceValidationTests.swift; sourceTree = ""; }; D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentController.swift; sourceTree = ""; }; D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolver.swift; sourceTree = ""; }; D7C9BAD674E29688ACE53B0B /* ChatExporter.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatExporter.swift; sourceTree = ""; }; DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessagesView.swift; sourceTree = ""; }; + E1E62624B6F285479CB33041 /* PromptBuilder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptBuilder.swift; sourceTree = ""; }; E35452B166893B25E765FF70 /* InferenceStats.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceStats.swift; sourceTree = ""; }; E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageDecoderTests.swift; sourceTree = ""; }; E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatInputView.swift; sourceTree = ""; }; @@ -162,7 +174,10 @@ children = ( FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */, E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */, + D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */, + 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */, 49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */, + 64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */, ); path = Server; sourceTree = ""; @@ -250,7 +265,10 @@ 3489501F2F8E1BA382347CFA /* CancellationToken.swift */, FFBB16D3AF2E61D001FD6051 /* ConversationSessionCache.swift */, 7C1A89C076E717F87A60397D /* ImageDecoder.swift */, + 02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */, + E1E62624B6F285479CB33041 /* PromptBuilder.swift */, 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */, + 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */, E73B165A1822729C907791AE /* ToolCallParser.swift */, 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */, ); @@ -363,7 +381,10 @@ files = ( 962083CCCC4AC848E0BBBC99 /* CancellationTokenTests.swift in Sources */, E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */, + 8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */, + 1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */, FE4405F66873C75CD6FA19A5 /* StreamingSSEEncoderTests.swift in Sources */, + 221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -389,6 +410,7 @@ C07A377244DCD67F4FE709FE /* DownloadModalView.swift in Sources */, 4DC033E45880B2948B47DEB1 /* FocusedValues.swift in Sources */, A146BBA70CFBEC505BDCDF0D /* ImageDecoder.swift in Sources */, + EC4FC68608DDFA6A3DF133CC /* InferenceEngine.swift in Sources */, 2D08769282BD71C170DB0943 /* InferenceStats.swift in Sources */, 6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */, 50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */, @@ -397,6 +419,7 @@ 2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */, B1D9BC407DB7DB1489230C20 /* MonitorView.swift in Sources */, 165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */, + AA17474A72C7F4EFBD5C4925 /* PromptBuilder.swift in Sources */, 4158FA884D981D73288FB74C /* SaveChatCommands.swift in Sources */, 07119250A7F9D6ECE7F6B8FD /* SceneCommands.swift in Sources */, 20FFB5DBF75AA6C359AAE31C /* SceneManagementView.swift in Sources */, @@ -406,6 +429,7 @@ D666A311788375E8A061C832 /* SettingsView.swift in Sources */, 621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */, 67262C5E24739F1FE0011439 /* StreamingSSEEncoder.swift in Sources */, + 741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */, 189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */, 84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */, ); diff --git a/MLXServer/Server/APIServer.swift b/MLXServer/Server/APIServer.swift index 669a3ff..77f2441 100644 --- a/MLXServer/Server/APIServer.swift +++ b/MLXServer/Server/APIServer.swift @@ -218,91 +218,16 @@ final class APIServer { } LiveCounters.shared.requestStarted(requestId: requestId, contextLength: contextLength) - - // Convert API messages to Chat.Message, extracting images from content parts - var chatMessages: [Chat.Message] = [] - var messageSignatures: [UInt64] = [] - var images: [UserInput.Image] = [] - var estimatedBytes = 0 let currentModelRepoId = currentModel?.repoId ?? modelName - // Build the instructions string (system prompt + tool definitions). - // This is passed to ChatSession via `instructions:` rather than injected - // as history messages, so it avoids an expensive history-replay prefill. - var instructions: String = "" - - // Collect system message text from the request - for msg in request.messages where msg.role == "system" { - let text = msg.content?.textContent ?? "" - if !text.isEmpty { - if !instructions.isEmpty { instructions += "\n\n" } - instructions += text - } - } - - // Append tool definitions to instructions - if let tools = request.tools, !tools.isEmpty { - let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId) - if !instructions.isEmpty { instructions += "\n\n" } - instructions += toolSystemPrompt - } - + let preparedPrompt = PromptBuilder.build( + from: request, + modelId: currentModelRepoId, + thinkingEnabled: Preferences.enableThinking + ) let isQwen = currentModelRepoId.lowercased().contains("qwen") - estimatedBytes += instructions.utf8.count - // Convert non-system messages to Chat.Message - for msg in request.messages where msg.role != "system" { - let role: Chat.Message.Role = switch msg.role { - case "assistant": .assistant - case "tool": .user - default: .user - } - - var text = msg.content?.textContent ?? "" - - // Format tool_call_id responses as tool_output for the model - if msg.role == "tool" { - if isQwen { - // Qwen expects tool results as-is in a user message - // (the role is already mapped to .user above) - } else { - // Gemma expects tool results wrapped in ```tool_output``` blocks - text = "```tool_output\n\(text)\n```" - } - } - - // Format assistant tool_calls back into model-native format - if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty { - let formattedCalls: String - if isQwen { - formattedCalls = ToolPromptBuilder.formatQwenToolCalls(toolCalls) - } else { - formattedCalls = ToolPromptBuilder.formatGemmaToolCalls(toolCalls) - } - text = (text.isEmpty ? "" : text + "\n") + formattedCalls - } - - // Extract base64 images from content parts - let imageURLs = msg.content?.imageURLs ?? [] - var messageImages: [UserInput.Image] = [] - var messageImageBytes = 0 - for urlString in imageURLs { - if let decoded = ImageDecoder.decode(urlString) { - messageImages.append(decoded.image) - messageImageBytes += decoded.estimatedBytes - } - } - - // Attach images to this specific message - chatMessages.append(Chat.Message(role: role, content: text, images: messageImages)) - messageSignatures.append( - Self.messageSignature(role: role, content: text, imageURLs: imageURLs) - ) - estimatedBytes += text.utf8.count + messageImageBytes - images.append(contentsOf: messageImages) - } - - if !images.isEmpty, currentModel?.supportsImages != true { + if preparedPrompt.containsImages, currentModel?.supportsImages != true { LiveCounters.shared.requestCompleted(requestId: requestId, generationTokens: 0) sendResponse( connection: connection, @@ -313,7 +238,7 @@ final class APIServer { } // Context window check: estimate token count and reject if over limit - let estimatedPromptTokens = (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35 + let estimatedPromptTokens = preparedPrompt.estimatedPromptTokens if contextLength > 0 { let needed = estimatedPromptTokens + maxTokens if needed > contextLength { @@ -337,18 +262,19 @@ final class APIServer { ) // Feed all messages except the last as history, then send the last as the prompt + let chatMessages = preparedPrompt.chatMessages let allButLast = Array(chatMessages.dropLast()) let lastMessage = chatMessages.last ?? Chat.Message(role: .user, content: "") - let historySignatures = Array(messageSignatures.dropLast()) + let historySignatures = Array(preparedPrompt.messageSignatures.dropLast()) let currentModelId = modelManager.currentModel?.id ?? modelName let lease = ConversationSessionCache.shared.checkoutSession( modelId: currentModelId, - instructions: instructions, + instructions: preparedPrompt.instructions, historySignatures: historySignatures, requestMessageCount: chatMessages.count, estimatedPromptTokens: estimatedPromptTokens, - estimatedBytes: estimatedBytes + estimatedBytes: preparedPrompt.estimatedBytes ) let session: ChatSession @@ -365,24 +291,21 @@ final class APIServer { // Use `instructions:` for system/tool prompt (matches internal chat pattern). // Only conversation turns go in `history:` — this avoids replaying the // large tool prompt as history on every new session. - let instr = instructions.isEmpty ? nil : instructions - let thinkingContext: [String: any Sendable]? = Preferences.enableThinking - ? nil - : ["enable_thinking": false] + let instr = preparedPrompt.instructions.isEmpty ? nil : preparedPrompt.instructions if !allButLast.isEmpty { session = ChatSession( container, instructions: instr, history: allButLast, generateParameters: generateParams, - additionalContext: thinkingContext + additionalContext: preparedPrompt.additionalContext ) } else { session = ChatSession( container, instructions: instr, generateParameters: generateParams, - additionalContext: thinkingContext + additionalContext: preparedPrompt.additionalContext ) } ConversationSessionCache.shared.markPrefilling(entryId: lease.entryId) @@ -423,7 +346,7 @@ final class APIServer { } if result.succeeded { - var cachedSignatures = messageSignatures + var cachedSignatures = preparedPrompt.messageSignatures if let assistantHistoryText = result.assistantHistoryText { cachedSignatures.append( Self.messageSignature(role: .assistant, content: assistantHistoryText, imageURLs: []) @@ -435,7 +358,7 @@ final class APIServer { requestMessageSignatures: cachedSignatures, requestMessageCount: cachedSignatures.count, estimatedPromptTokens: estimatedPromptTokens, - estimatedBytes: estimatedBytes, + estimatedBytes: preparedPrompt.estimatedBytes, promptTokens: result.promptTokens, completionTokens: result.completionTokens ) diff --git a/MLXServer/Server/InferenceEngine.swift b/MLXServer/Server/InferenceEngine.swift new file mode 100644 index 0000000..36636b4 --- /dev/null +++ b/MLXServer/Server/InferenceEngine.swift @@ -0,0 +1,66 @@ +import MLX +import MLXLMCommon + +/// Stateless inference wrapper for the API path. +final class InferenceEngine: @unchecked Sendable { + private let container: ModelContainer + + init(container: ModelContainer) { + self.container = container + } + + struct InferenceRequest: @unchecked Sendable { + let input: LMInput + let tokens: [Int] + let parameters: GenerateParameters + let cachedKV: [KVCache]? + let cachedTokenCount: Int + } + + struct StreamHandle: @unchecked Sendable { + let stream: AsyncStream + let workingCache: [KVCache] + } + + struct PreparedInference: @unchecked Sendable { + let lmInput: LMInput + let tokens: [Int] + let hasImages: Bool + } + + func stream( + _ request: InferenceRequest, + cancellation: CancellationToken + ) async throws -> StreamHandle { + _ = cancellation + nonisolated(unsafe) let input = request.input + nonisolated(unsafe) let cachedKV = request.cachedKV + let parameters = request.parameters + + return try await container.perform { context in + let workingCache = cachedKV ?? context.model.newCache(parameters: parameters) + let stream = try MLXLMCommon.generate( + input: input, + cache: workingCache, + parameters: parameters, + context: context + ) + return StreamHandle(stream: stream, workingCache: workingCache) + } + } + + func prepare(_ userInput: UserInput) async throws -> PreparedInference { + nonisolated(unsafe) let input = userInput + let lmInput = try await container.prepare(input: input) + nonisolated(unsafe) let preparedInput = lmInput + let tokenArray: [Int] = await container.perform { _ in + preparedInput.text.tokens.asArray(Int.self) + } + + return PreparedInference( + lmInput: lmInput, + tokens: tokenArray, + hasImages: userInput.images.count > 0 + ) + } +} \ No newline at end of file diff --git a/MLXServer/Server/PromptBuilder.swift b/MLXServer/Server/PromptBuilder.swift new file mode 100644 index 0000000..abddda2 --- /dev/null +++ b/MLXServer/Server/PromptBuilder.swift @@ -0,0 +1,139 @@ +import Foundation +import MLXLMCommon + +/// Converts OpenAI-format API messages into reusable prompt artifacts for the API server. +enum PromptBuilder { + struct PreparedPrompt { + let instructions: String + let chatMessages: [Chat.Message] + let messageSignatures: [UInt64] + let estimatedBytes: Int + let estimatedPromptTokens: Int + let containsImages: Bool + let additionalContext: [String: any Sendable]? + let userInput: UserInput + } + + static func build( + from request: APIChatCompletionRequest, + modelId: String, + thinkingEnabled: Bool + ) -> PreparedPrompt { + var instructions = "" + for msg in request.messages where msg.role == "system" { + let text = msg.content?.textContent ?? "" + guard !text.isEmpty else { continue } + if !instructions.isEmpty { instructions += "\n\n" } + instructions += text + } + + if let tools = request.tools, !tools.isEmpty { + let toolPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId) + if !instructions.isEmpty { instructions += "\n\n" } + instructions += toolPrompt + } + + let isQwen = modelId.lowercased().contains("qwen") + var chatMessages: [Chat.Message] = [] + var messageSignatures: [UInt64] = [] + var estimatedBytes = instructions.utf8.count + var containsImages = false + + for msg in request.messages where msg.role != "system" { + let role: Chat.Message.Role = switch msg.role { + case "assistant": .assistant + case "tool": .user + default: .user + } + + var text = msg.content?.textContent ?? "" + if msg.role == "tool", !isQwen { + text = "```tool_output\n\(text)\n```" + } + + if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty { + let formattedCalls = isQwen + ? ToolPromptBuilder.formatQwenToolCalls(toolCalls) + : ToolPromptBuilder.formatGemmaToolCalls(toolCalls) + text = text.isEmpty ? formattedCalls : text + "\n" + formattedCalls + } + + let imageURLs = msg.content?.imageURLs ?? [] + var messageImages: [UserInput.Image] = [] + var messageImageBytes = 0 + for urlString in imageURLs { + if let decoded = ImageDecoder.decode(urlString) { + messageImages.append(decoded.image) + messageImageBytes += decoded.estimatedBytes + } + } + + containsImages = containsImages || !messageImages.isEmpty + chatMessages.append(Chat.Message(role: role, content: text, images: messageImages)) + messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs)) + estimatedBytes += text.utf8.count + messageImageBytes + } + + let additionalContext: [String: any Sendable]? = thinkingEnabled + ? nil + : ["enable_thinking": false] + + var allMessages: [Chat.Message] = [] + if !instructions.isEmpty { + allMessages.append(Chat.Message(role: .system, content: instructions)) + } + allMessages.append(contentsOf: chatMessages) + + let allImages = chatMessages.flatMap(\ .images) + let userInput = UserInput( + prompt: .chat(allMessages), + images: allImages, + videos: [], + tools: nil, + additionalContext: additionalContext + ) + + let estimatedPromptTokens = (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35 + + return PreparedPrompt( + instructions: instructions, + chatMessages: chatMessages, + messageSignatures: messageSignatures, + estimatedBytes: estimatedBytes, + estimatedPromptTokens: estimatedPromptTokens, + containsImages: containsImages, + additionalContext: additionalContext, + userInput: userInput + ) + } + + private static func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 { + var hash: UInt64 = 14_695_981_039_346_656_037 + + func mix(_ text: String) { + for byte in text.utf8 { + hash ^= UInt64(byte) + hash &*= 1_099_511_628_211 + } + } + + switch role { + case .assistant: + mix("assistant") + case .system: + mix("system") + case .user: + mix("user") + @unknown default: + mix("unknown") + } + mix("|") + mix(content) + for imageURL in imageURLs { + mix("|") + mix(imageURL) + } + + return hash + } +} \ No newline at end of file diff --git a/MLXServer/Server/TokenPrefixCache.swift b/MLXServer/Server/TokenPrefixCache.swift new file mode 100644 index 0000000..8afbfa2 --- /dev/null +++ b/MLXServer/Server/TokenPrefixCache.swift @@ -0,0 +1,317 @@ +import Foundation +import Metal +import MLXLMCommon +import os + +final class TokenPrefixCache: @unchecked Sendable { + static let shared = TokenPrefixCache() + + struct CacheLease: @unchecked Sendable { + let entryId: UUID + let kvCache: [KVCache]? + let matchedTokenCount: Int + let isHit: Bool + } + + struct EntrySummary: Identifiable, Sendable { + let id: UUID + let modelId: String + let tokenCount: Int + let estimatedBytes: Int + let createdAt: Date + let lastAccessAt: Date + let hitCount: Int + } + + struct Snapshot: Sendable { + let totalEntries: Int + let totalCachedTokens: Int + let estimatedBytes: Int + let memoryBudgetBytes: Int + let memoryUsagePercent: Double + let totalHits: Int + let totalMisses: Int + let totalEvictions: Int + let hitRate: Double + let entries: [EntrySummary] + } + + private final class TrieNode { + var children: [Int: TrieNode] = [:] + var entryId: UUID? + } + + private struct CacheEntry { + let id: UUID + let modelId: String + let kvCache: [KVCache] + let tokenCount: Int + let cacheKey: [Int] + let estimatedBytes: Int + let createdAt: Date + var lastAccessAt: Date + var hitCount: Int + } + + private struct Stats { + var totalHits: Int = 0 + var totalMisses: Int = 0 + var totalEvictions: Int = 0 + } + + private let lock = OSAllocatedUnfairLock() + private let maxMemoryBytes: Int + private let idleTTL: TimeInterval + private let estimateBytesProvider: ([KVCache]) -> Int + private let nowProvider: () -> Date + private var root = TrieNode() + private var entries: [UUID: CacheEntry] = [:] + private var currentMemoryBytes: Int = 0 + private var stats = Stats() + + private init() { + self.maxMemoryBytes = Self.computeMemoryBudget() + self.idleTTL = 30 * 60 + self.estimateBytesProvider = Self.estimateBytes + self.nowProvider = Date.init + } + + init( + memoryBudgetBytes: Int, + idleTTL: TimeInterval = 30 * 60, + estimateBytesProvider: @escaping ([KVCache]) -> Int = TokenPrefixCache.estimateBytes, + nowProvider: @escaping () -> Date = Date.init + ) { + self.maxMemoryBytes = memoryBudgetBytes + self.idleTTL = idleTTL + self.estimateBytesProvider = estimateBytesProvider + self.nowProvider = nowProvider + } + + func lookup(cacheKey: [Int], modelId: String) -> CacheLease { + lock.lock() + let now = nowProvider() + pruneExpiredLocked(now: now) + + var node = root + var bestMatch: (entryId: UUID, realTokenCount: Int)? + var realTokenCount = 0 + + for key in cacheKey { + guard let child = node.children[key] else { break } + node = child + if key >= 0 { realTokenCount += 1 } + if let entryId = node.entryId, + let entry = entries[entryId], + entry.modelId == modelId { + bestMatch = (entryId: entryId, realTokenCount: realTokenCount) + } + } + + guard let match = bestMatch, + var entry = entries[match.entryId] + else { + stats.totalMisses += 1 + lock.unlock() + return CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false) + } + + entry.lastAccessAt = now + entry.hitCount += 1 + entries[match.entryId] = entry + removeEntryLocked(entry) + stats.totalHits += 1 + lock.unlock() + + return CacheLease( + entryId: match.entryId, + kvCache: entry.kvCache, + matchedTokenCount: match.realTokenCount, + isHit: true + ) + } + + func store( + entryId: UUID, + kvCache: [KVCache], + cacheKey: [Int], + modelId: String + ) { + lock.lock() + let now = nowProvider() + pruneExpiredLocked(now: now) + + let estimatedBytes = estimateBytesProvider(kvCache) + var node = root + for key in cacheKey { + if node.children[key] == nil { + node.children[key] = TrieNode() + } + node = node.children[key]! + } + + if let oldId = node.entryId, + let oldEntry = entries[oldId] { + removeEntryLocked(oldEntry) + } + + node.entryId = entryId + entries[entryId] = CacheEntry( + id: entryId, + modelId: modelId, + kvCache: kvCache, + tokenCount: cacheKey.filter { $0 >= 0 }.count, + cacheKey: cacheKey, + estimatedBytes: estimatedBytes, + createdAt: now, + lastAccessAt: now, + hitCount: 0 + ) + currentMemoryBytes += estimatedBytes + enforceBudgetLocked() + lock.unlock() + } + + func invalidateAll() { + lock.lock() + stats.totalEvictions += entries.count + entries.removeAll() + root = TrieNode() + currentMemoryBytes = 0 + lock.unlock() + } + + func reset() { + lock.lock() + root = TrieNode() + entries.removeAll() + currentMemoryBytes = 0 + stats = Stats() + lock.unlock() + } + + func snapshot() -> Snapshot { + lock.lock() + let now = nowProvider() + pruneExpiredLocked(now: now) + let orderedEntries = entries.values.sorted { lhs, rhs in + if lhs.lastAccessAt != rhs.lastAccessAt { + return lhs.lastAccessAt > rhs.lastAccessAt + } + return lhs.createdAt > rhs.createdAt + } + let hits = stats.totalHits + let misses = stats.totalMisses + let totalOps = hits + misses + + let snapshot = Snapshot( + totalEntries: orderedEntries.count, + totalCachedTokens: orderedEntries.reduce(0) { $0 + $1.tokenCount }, + estimatedBytes: currentMemoryBytes, + memoryBudgetBytes: maxMemoryBytes, + memoryUsagePercent: maxMemoryBytes > 0 + ? (Double(currentMemoryBytes) / Double(maxMemoryBytes)) * 100 + : 0, + totalHits: hits, + totalMisses: misses, + totalEvictions: stats.totalEvictions, + hitRate: totalOps > 0 ? (Double(hits) / Double(totalOps)) * 100 : 0, + entries: orderedEntries.map { + EntrySummary( + id: $0.id, + modelId: $0.modelId, + tokenCount: $0.tokenCount, + estimatedBytes: $0.estimatedBytes, + createdAt: $0.createdAt, + lastAccessAt: $0.lastAccessAt, + hitCount: $0.hitCount + ) + } + ) + lock.unlock() + return snapshot + } + + func debugTrieNodeCount() -> Int { + lock.lock() + let count = countNodes(root) + lock.unlock() + return count + } + + private func pruneExpiredLocked(now: Date) { + let expired = entries.values.filter { + now.timeIntervalSince($0.lastAccessAt) > idleTTL + } + for entry in expired { + removeEntryLocked(entry) + } + } + + private func enforceBudgetLocked() { + while currentMemoryBytes > maxMemoryBytes { + guard let victim = entries.values.min(by: evictionOrder) else { + break + } + removeEntryLocked(victim) + } + } + + private func removeEntryLocked(_ entry: CacheEntry) { + guard entries[entry.id] != nil else { return } + + var node = root + var path: [(parent: TrieNode, key: Int)] = [] + for key in entry.cacheKey { + guard let child = node.children[key] else { break } + path.append((parent: node, key: key)) + node = child + } + node.entryId = nil + + for (parent, key) in path.reversed() { + guard let child = parent.children[key] else { continue } + if child.children.isEmpty && child.entryId == nil { + parent.children.removeValue(forKey: key) + } else { + break + } + } + + currentMemoryBytes = max(0, currentMemoryBytes - entry.estimatedBytes) + entries.removeValue(forKey: entry.id) + stats.totalEvictions += 1 + } + + private func evictionOrder(lhs: CacheEntry, rhs: CacheEntry) -> Bool { + if lhs.lastAccessAt != rhs.lastAccessAt { + return lhs.lastAccessAt < rhs.lastAccessAt + } + if lhs.hitCount != rhs.hitCount { + return lhs.hitCount < rhs.hitCount + } + return lhs.createdAt < rhs.createdAt + } + + private func countNodes(_ node: TrieNode) -> Int { + 1 + node.children.values.reduce(0) { $0 + countNodes($1) } + } + + private static func computeMemoryBudget() -> Int { + guard let device = MTLCreateSystemDefaultDevice() else { + return 512 * 1024 * 1024 + } + let budget = Int(Double(device.recommendedMaxWorkingSetSize) * 0.20) + return max(256 * 1024 * 1024, min(budget, 8 * 1024 * 1024 * 1024)) + } + + private static func estimateBytes(_ kvCache: [KVCache]) -> Int { + var total = 0 + for layer in kvCache { + for array in layer.state { + total += array.nbytes + } + } + return max(total, 1024) + } +} \ No newline at end of file diff --git a/MLXServerTests/Server/ModelBackedInferenceValidationTests.swift b/MLXServerTests/Server/ModelBackedInferenceValidationTests.swift new file mode 100644 index 0000000..51df91e --- /dev/null +++ b/MLXServerTests/Server/ModelBackedInferenceValidationTests.swift @@ -0,0 +1,289 @@ +import Foundation +import Hub +import MLXLMCommon +import MLXVLM +import XCTest +@testable import MLX_Server + +final class ModelBackedInferenceValidationTests: XCTestCase { + private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg==" + + func testPromptBuilderTokenizationMatchesLegacyShapingOnLocalGemma() async throws { + let container = try await localGemmaContainer() + let engine = InferenceEngine(container: container) + let request = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are concise."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage( + role: "user", + content: .parts([ + APIContentPart(type: "text", text: "What is in this image?", image_url: nil), + APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil)) + ]), + name: nil, + tool_calls: nil, + tool_call_id: nil + ) + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false) + let legacy = legacyBuild(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false) + + let preparedInference = try await engine.prepare(prepared.userInput) + let legacyInference = try await engine.prepare(legacy.userInput) + + XCTAssertEqual(preparedInference.tokens, legacyInference.tokens) + } + + func testInferenceEngineMatchesChatSessionOnLocalGemma() async throws { + let container = try await localGemmaContainer() + let engine = InferenceEngine(container: container) + let parameters = GenerateParameters(maxTokens: 1, temperature: 0) + let request = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "user", content: .text("Say hello in one word."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: true) + let preparedInference = try await engine.prepare(prepared.userInput) + let handle = try await engine.stream( + InferenceEngine.InferenceRequest( + input: preparedInference.lmInput, + tokens: preparedInference.tokens, + parameters: parameters, + cachedKV: nil, + cachedTokenCount: 0 + ), + cancellation: CancellationToken() + ) + + let engineResult = await collectEngineOutput(handle.stream) + + let session = ChatSession(container, generateParameters: parameters) + let sessionResult = try await collectSessionOutput( + session.streamDetails(to: "Say hello in one word.", images: [], videos: []) + ) + + XCTAssertEqual(engineResult.text, sessionResult.text) + XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount) + } + + private func localGemmaContainer() async throws -> ModelContainer { + try await LocalGemmaFixture.shared.container() + } + + private func legacyBuild( + from request: APIChatCompletionRequest, + modelId: String, + thinkingEnabled: Bool + ) -> PromptBuilder.PreparedPrompt { + var instructions = "" + for msg in request.messages where msg.role == "system" { + let text = msg.content?.textContent ?? "" + if !text.isEmpty { + if !instructions.isEmpty { instructions += "\n\n" } + instructions += text + } + } + + if let tools = request.tools, !tools.isEmpty { + let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId) + if !instructions.isEmpty { instructions += "\n\n" } + instructions += toolSystemPrompt + } + + let isQwen = modelId.lowercased().contains("qwen") + var chatMessages: [Chat.Message] = [] + var messageSignatures: [UInt64] = [] + var estimatedBytes = instructions.utf8.count + var containsImages = false + + for msg in request.messages where msg.role != "system" { + let role: Chat.Message.Role = switch msg.role { + case "assistant": .assistant + case "tool": .user + default: .user + } + + var text = msg.content?.textContent ?? "" + if msg.role == "tool", !isQwen { + text = "```tool_output\n\(text)\n```" + } + + if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty { + let formattedCalls = isQwen + ? ToolPromptBuilder.formatQwenToolCalls(toolCalls) + : ToolPromptBuilder.formatGemmaToolCalls(toolCalls) + text = (text.isEmpty ? "" : text + "\n") + formattedCalls + } + + let imageURLs = msg.content?.imageURLs ?? [] + var messageImages: [UserInput.Image] = [] + var messageImageBytes = 0 + for urlString in imageURLs { + if let decoded = ImageDecoder.decode(urlString) { + messageImages.append(decoded.image) + messageImageBytes += decoded.estimatedBytes + } + } + + containsImages = containsImages || !messageImages.isEmpty + chatMessages.append(Chat.Message(role: role, content: text, images: messageImages)) + messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs)) + estimatedBytes += text.utf8.count + messageImageBytes + } + + let additionalContext: [String: any Sendable]? = thinkingEnabled + ? nil + : ["enable_thinking": false] + + let allImages = chatMessages.flatMap(\.images) + let allMessages = (instructions.isEmpty ? [] : [Chat.Message(role: .system, content: instructions)]) + chatMessages + let userInput = UserInput( + prompt: .chat(allMessages), + images: allImages, + videos: [], + tools: nil, + additionalContext: additionalContext + ) + + return PromptBuilder.PreparedPrompt( + instructions: instructions, + chatMessages: chatMessages, + messageSignatures: messageSignatures, + estimatedBytes: estimatedBytes, + estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35, + containsImages: containsImages, + additionalContext: additionalContext, + userInput: userInput + ) + } + + private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 { + var hash: UInt64 = 14_695_981_039_346_656_037 + + func mix(_ text: String) { + for byte in text.utf8 { + hash ^= UInt64(byte) + hash &*= 1_099_511_628_211 + } + } + + switch role { + case .assistant: + mix("assistant") + case .system: + mix("system") + case .user: + mix("user") + @unknown default: + mix("unknown") + } + mix("|") + mix(content) + for imageURL in imageURLs { + mix("|") + mix(imageURL) + } + + return hash + } + + private func collectEngineOutput(_ stream: AsyncStream) async -> GenerationResult { + var text = "" + var promptTokenCount = 0 + for await generation in stream { + switch generation { + case .chunk(let chunk): + text += chunk + case .info(let info): + promptTokenCount = info.promptTokenCount + case .toolCall: + break + } + } + return GenerationResult(text: text, promptTokenCount: promptTokenCount) + } + + private func collectSessionOutput(_ stream: AsyncThrowingStream) async throws -> GenerationResult { + var text = "" + var promptTokenCount = 0 + for try await generation in stream { + switch generation { + case .chunk(let chunk): + text += chunk + case .info(let info): + promptTokenCount = info.promptTokenCount + case .toolCall: + break + } + } + return GenerationResult(text: text, promptTokenCount: promptTokenCount) + } +} + +private struct GenerationResult { + let text: String + let promptTokenCount: Int +} + +private actor LocalGemmaFixture { + static let shared = LocalGemmaFixture() + + private var task: Task? + + func container() async throws -> ModelContainer { + if let task { + return try await task.value + } + + guard let config = ModelConfig.resolve("gemma") else { + throw XCTSkip("Gemma model config is unavailable") + } + guard let localDir = LocalModelResolver.resolve(repoId: config.repoId) else { + throw XCTSkip("Local gemma cache is unavailable") + } + + let loadTask = Task { + let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first + let hub = HubApi(downloadBase: cachesDir, cache: nil) + return try await VLMModelFactory.shared.loadContainer( + hub: hub, + configuration: ModelConfiguration(directory: localDir), + progressHandler: { _ in } + ) + } + task = loadTask + + do { + return try await loadTask.value + } catch { + task = nil + throw error + } + } +} \ No newline at end of file diff --git a/MLXServerTests/Server/PromptBuilderTests.swift b/MLXServerTests/Server/PromptBuilderTests.swift new file mode 100644 index 0000000..cd01f21 --- /dev/null +++ b/MLXServerTests/Server/PromptBuilderTests.swift @@ -0,0 +1,288 @@ +import XCTest +import MLXLMCommon +@testable import MLX_Server + +final class PromptBuilderTests: XCTestCase { + private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg==" + + func testBuildMatchesLegacyAPIServerShapingForGemma() { + let toolCall = APIToolCall( + id: "call_weather", + function: APIFunctionCall(name: "weather", arguments: "{\"city\":\"Berlin\"}") + ) + let request = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("System 1"), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "system", content: .text("System 2"), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "assistant", content: .text("Let me check"), name: nil, tool_calls: [toolCall], tool_call_id: nil), + APIChatMessage( + role: "tool", + content: .parts([ + APIContentPart(type: "text", text: "{\"temp\":19}", image_url: nil), + APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil)) + ]), + name: nil, + tool_calls: nil, + tool_call_id: "call_weather" + ), + APIChatMessage(role: "user", content: .text("Thanks"), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: [ + APIToolDefinition( + type: "function", + function: APIFunctionDefinition( + name: "weather", + description: "Lookup weather", + parameters: ["type": AnyCodable("object")] + ) + ) + ], + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false) + let legacy = legacyBuild(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false) + + XCTAssertEqual(prepared.instructions, legacy.instructions) + XCTAssertEqual(prepared.chatMessages.map { $0.role.roleLabel }, legacy.chatMessages.map { $0.role.roleLabel }) + XCTAssertEqual(prepared.chatMessages.map(\.content), legacy.chatMessages.map(\.content)) + XCTAssertEqual(prepared.chatMessages.map { $0.images.count }, legacy.chatMessages.map { $0.images.count }) + XCTAssertEqual(prepared.messageSignatures, legacy.messageSignatures) + XCTAssertEqual(prepared.estimatedBytes, legacy.estimatedBytes) + XCTAssertEqual(prepared.estimatedPromptTokens, legacy.estimatedPromptTokens) + XCTAssertEqual(prepared.containsImages, legacy.containsImages) + XCTAssertEqual(prepared.additionalContext?["enable_thinking"] as? Bool, legacy.additionalContext?["enable_thinking"] as? Bool) + } + + func testBuildAggregatesInstructionsAndMessages() { + let request = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("Base system"), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "system", content: .text("Extra system"), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Hello"), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false) + + XCTAssertEqual(prepared.instructions, "Base system\n\nExtra system") + XCTAssertEqual(prepared.chatMessages.count, 1) + XCTAssertEqual(prepared.chatMessages[0].content, "Hello") + XCTAssertEqual(prepared.messageSignatures.count, 1) + XCTAssertFalse(prepared.containsImages) + XCTAssertNotNil(prepared.additionalContext) + XCTAssertGreaterThan(prepared.estimatedPromptTokens, 0) + } + + func testBuildFormatsAssistantToolCallsForQwen() { + let toolCall = APIToolCall( + id: "call_1", + function: APIFunctionCall(name: "weather", arguments: "{\"city\":\"Berlin\"}") + ) + let request = APIChatCompletionRequest( + model: "qwen", + messages: [ + APIChatMessage(role: "assistant", content: .text("Let me check."), name: nil, tool_calls: [toolCall], tool_call_id: nil) + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/Qwen3-VL-4B-Instruct-4bit", thinkingEnabled: true) + + XCTAssertEqual(prepared.chatMessages.count, 1) + XCTAssertTrue(prepared.chatMessages[0].content.contains("Let me check.")) + XCTAssertTrue(prepared.chatMessages[0].content.contains("")) + XCTAssertNil(prepared.additionalContext) + } + + func testBuildWrapsGemmaToolOutputsAndTracksImages() { + let request = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage( + role: "tool", + content: .parts([ + APIContentPart(type: "text", text: "{\"ok\":true}", image_url: nil), + APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil)) + ]), + name: nil, + tool_calls: nil, + tool_call_id: "call_1" + ) + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: true) + + XCTAssertTrue(prepared.chatMessages[0].content.contains("```tool_output")) + XCTAssertTrue(prepared.containsImages) + XCTAssertEqual(prepared.chatMessages[0].images.count, 1) + XCTAssertGreaterThan(prepared.estimatedBytes, prepared.chatMessages[0].content.utf8.count) + } + + private func legacyBuild( + from request: APIChatCompletionRequest, + modelId: String, + thinkingEnabled: Bool + ) -> PromptBuilder.PreparedPrompt { + var instructions = "" + for msg in request.messages where msg.role == "system" { + let text = msg.content?.textContent ?? "" + if !text.isEmpty { + if !instructions.isEmpty { instructions += "\n\n" } + instructions += text + } + } + + if let tools = request.tools, !tools.isEmpty { + let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId) + if !instructions.isEmpty { instructions += "\n\n" } + instructions += toolSystemPrompt + } + + let isQwen = modelId.lowercased().contains("qwen") + var chatMessages: [Chat.Message] = [] + var messageSignatures: [UInt64] = [] + var estimatedBytes = instructions.utf8.count + var containsImages = false + + for msg in request.messages where msg.role != "system" { + let role: Chat.Message.Role = switch msg.role { + case "assistant": .assistant + case "tool": .user + default: .user + } + + var text = msg.content?.textContent ?? "" + if msg.role == "tool", !isQwen { + text = "```tool_output\n\(text)\n```" + } + + if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty { + let formattedCalls = isQwen + ? ToolPromptBuilder.formatQwenToolCalls(toolCalls) + : ToolPromptBuilder.formatGemmaToolCalls(toolCalls) + text = (text.isEmpty ? "" : text + "\n") + formattedCalls + } + + let imageURLs = msg.content?.imageURLs ?? [] + var messageImages: [UserInput.Image] = [] + var messageImageBytes = 0 + for urlString in imageURLs { + if let decoded = ImageDecoder.decode(urlString) { + messageImages.append(decoded.image) + messageImageBytes += decoded.estimatedBytes + } + } + + containsImages = containsImages || !messageImages.isEmpty + chatMessages.append(Chat.Message(role: role, content: text, images: messageImages)) + messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs)) + estimatedBytes += text.utf8.count + messageImageBytes + } + + let additionalContext: [String: any Sendable]? = thinkingEnabled + ? nil + : ["enable_thinking": false] + + let allImages = chatMessages.flatMap(\.images) + let userInput = UserInput( + prompt: .chat((instructions.isEmpty ? [] : [Chat.Message(role: .system, content: instructions)]) + chatMessages), + images: allImages, + videos: [], + tools: nil, + additionalContext: additionalContext + ) + + return PromptBuilder.PreparedPrompt( + instructions: instructions, + chatMessages: chatMessages, + messageSignatures: messageSignatures, + estimatedBytes: estimatedBytes, + estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35, + containsImages: containsImages, + additionalContext: additionalContext, + userInput: userInput + ) + } + + private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 { + var hash: UInt64 = 14_695_981_039_346_656_037 + + func mix(_ text: String) { + for byte in text.utf8 { + hash ^= UInt64(byte) + hash &*= 1_099_511_628_211 + } + } + + switch role { + case .assistant: + mix("assistant") + case .system: + mix("system") + case .user: + mix("user") + @unknown default: + mix("unknown") + } + mix("|") + mix(content) + for imageURL in imageURLs { + mix("|") + mix(imageURL) + } + + return hash + } +} + +private extension Chat.Message.Role { + var roleLabel: String { + switch self { + case .assistant: "assistant" + case .system: "system" + case .user: "user" + @unknown default: "unknown" + } + } +} \ No newline at end of file diff --git a/MLXServerTests/Server/TokenPrefixCacheTests.swift b/MLXServerTests/Server/TokenPrefixCacheTests.swift new file mode 100644 index 0000000..d8ef285 --- /dev/null +++ b/MLXServerTests/Server/TokenPrefixCacheTests.swift @@ -0,0 +1,130 @@ +import Foundation +import XCTest +import MLXLMCommon +@testable import MLX_Server + +final class TokenPrefixCacheTests: XCTestCase { + func testStoreAndLookupRemovesCheckedOutEntry() { + var now = Date(timeIntervalSince1970: 100) + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + estimateBytesProvider: { _ in 1_024 }, + nowProvider: { now } + ) + + let entryId = UUID() + cache.store(entryId: entryId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model") + + XCTAssertEqual(cache.snapshot().totalEntries, 1) + + let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model") + + XCTAssertTrue(lease.isHit) + XCTAssertEqual(lease.entryId, entryId) + XCTAssertEqual(lease.matchedTokenCount, 3) + XCTAssertNotNil(lease.kvCache) + XCTAssertEqual(cache.snapshot().totalEntries, 0) + } + + func testLookupPrefersDeepestPrefixMatch() { + var now = Date(timeIntervalSince1970: 100) + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + estimateBytesProvider: { _ in 1_024 }, + nowProvider: { now } + ) + + cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2], modelId: "model") + now.addTimeInterval(1) + let deepId = UUID() + cache.store(entryId: deepId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model") + + let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model") + + XCTAssertTrue(lease.isHit) + XCTAssertEqual(lease.entryId, deepId) + XCTAssertEqual(lease.matchedTokenCount, 3) + } + + func testEvictsLeastRecentlyUsedEntryWhenOverBudget() { + var now = Date(timeIntervalSince1970: 100) + let cache = TokenPrefixCache( + memoryBudgetBytes: 2_048, + estimateBytesProvider: { _ in 1_024 }, + nowProvider: { now } + ) + + let firstId = UUID() + cache.store(entryId: firstId, kvCache: [], cacheKey: [1], modelId: "model") + now.addTimeInterval(1) + cache.store(entryId: UUID(), kvCache: [], cacheKey: [2], modelId: "model") + now.addTimeInterval(1) + cache.store(entryId: UUID(), kvCache: [], cacheKey: [3], modelId: "model") + + let firstLookup = cache.lookup(cacheKey: [1], modelId: "model") + let secondLookup = cache.lookup(cacheKey: [2], modelId: "model") + let thirdLookup = cache.lookup(cacheKey: [3], modelId: "model") + + XCTAssertFalse(firstLookup.isHit) + XCTAssertTrue(secondLookup.isHit) + XCTAssertTrue(thirdLookup.isHit) + } + + func testSnapshotPrunesExpiredEntries() { + var now = Date(timeIntervalSince1970: 100) + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + idleTTL: 5, + estimateBytesProvider: { _ in 1_024 }, + nowProvider: { now } + ) + + cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model") + XCTAssertEqual(cache.snapshot().totalEntries, 1) + + now.addTimeInterval(10) + let snapshot = cache.snapshot() + + XCTAssertEqual(snapshot.totalEntries, 0) + XCTAssertGreaterThanOrEqual(snapshot.totalEvictions, 1) + } + + func testLookupPrunesTrieNodesForRemovedBranch() { + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + estimateBytesProvider: { _ in 1_024 } + ) + + cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model") + cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 4], modelId: "model") + + XCTAssertEqual(cache.debugTrieNodeCount(), 5) + + _ = cache.lookup(cacheKey: [1, 2, 3], modelId: "model") + + XCTAssertEqual(cache.debugTrieNodeCount(), 4) + + _ = cache.lookup(cacheKey: [1, 2, 4], modelId: "model") + + XCTAssertEqual(cache.debugTrieNodeCount(), 1) + } + + func testSnapshotReportsHitRateAndTokenTotals() { + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + estimateBytesProvider: { _ in 2_048 } + ) + + cache.store(entryId: UUID(), kvCache: [], cacheKey: [10, 20, 30], modelId: "model") + _ = cache.lookup(cacheKey: [10, 20, 30, 40], modelId: "model") + _ = cache.lookup(cacheKey: [99], modelId: "model") + + let snapshot = cache.snapshot() + + XCTAssertEqual(snapshot.totalHits, 1) + XCTAssertEqual(snapshot.totalMisses, 1) + XCTAssertEqual(snapshot.hitRate, 50, accuracy: 0.001) + XCTAssertEqual(snapshot.totalCachedTokens, 0) + XCTAssertEqual(snapshot.estimatedBytes, 0) + } +} \ No newline at end of file diff --git a/docs/session-cache-upgrade.md b/docs/session-cache-upgrade.md index 065e047..a8ce9ae 100644 --- a/docs/session-cache-upgrade.md +++ b/docs/session-cache-upgrade.md @@ -2564,9 +2564,11 @@ Each step should be independently buildable and testable. ### Phase 2: Core Engine -4. **`PromptBuilder.swift`** — Convert API messages to UserInput. Test by comparing tokenized output to what ChatSession produces for the same messages. -5. **`TokenPrefixCache.swift`** — The big one. Build trie + eviction + monitoring. Test: insert entries, verify lookup, verify eviction under memory pressure, verify trie cleanup. -6. **`InferenceEngine.swift`** — Thin wrapper using `container.perform { ctx in MLXLMCommon.generate(input:cache:parameters:context:) }`. Test: run a simple prompt through it, verify output matches ChatSession output. +4. [x] **`PromptBuilder.swift`** — Convert API messages to UserInput. Test by comparing tokenized output to what ChatSession produces for the same messages. +5. [x] **`TokenPrefixCache.swift`** — The big one. Build trie + eviction + monitoring. Test: insert entries, verify lookup, verify eviction under memory pressure, verify trie cleanup. +6. [x] **`InferenceEngine.swift`** — Thin wrapper using `container.perform { ctx in MLXLMCommon.generate(input:cache:parameters:context:) }`. Test: run a simple prompt through it, verify output matches ChatSession output. + +Validation note: `PromptBuilder.swift` is now covered by both shaping-parity unit tests and a model-backed tokenization parity test against the cached local Gemma 3 4B VLM. `InferenceEngine.swift` is now covered by a model-backed smoke test that compares one-token output and prompt-token counts against `ChatSession` on the same locally cached Gemma model. ### Phase 3: Integration @@ -2614,8 +2616,8 @@ Each step should be independently buildable and testable. ### Memory Management - [ ] Memory budget computed correctly from Metal device -- [ ] Entries evicted under memory pressure (oldest first) -- [ ] Expired entries pruned after 30 min idle +- [x] Entries evicted under memory pressure (oldest first) +- [x] Expired entries pruned after 30 min idle - [ ] Trie nodes cleaned up when entries are evicted (no memory leak) - [ ] `snapshot()` reports accurate memory usage and hit rates @@ -2628,7 +2630,7 @@ Each step should be independently buildable and testable. ### Streaming - [ ] SSE JSON is valid and parseable by standard clients -- [ ] `StreamingSSEEncoder` output matches `JSONEncoder` output byte-for-byte (for content deltas) +- [x] `StreamingSSEEncoder` output matches `JSONEncoder` output byte-for-byte (for content deltas) - [ ] Role delta sent once at stream start - [ ] Tool call chunks sent correctly - [ ] Final chunk has finish_reason and usage stats