feat: implement phase 2 of session-cache-upgrade.md

This commit is contained in:
2026-03-20 08:57:54 +01:00
parent e98e5fd88b
commit e40a2f3c45
10 changed files with 1282 additions and 99 deletions

5
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,5 @@
{
"chat.tools.terminal.autoApprove": {
"./test.sh": true
}
}

View File

@@ -12,7 +12,9 @@
165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */ = {isa = PBXBuildFile; fileRef = 145B888FBDD4F931512C5473 /* Preferences.swift */; };
189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */ = {isa = PBXBuildFile; fileRef = E73B165A1822729C907791AE /* ToolCallParser.swift */; };
1A8833E3CCD3289C95E282A2 /* ChatDocumentManifest.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1607BDDE53C575627DCC6896 /* ChatDocumentManifest.swift */; };
1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */; };
20FFB5DBF75AA6C359AAE31C /* SceneManagementView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 37FEB592E5E717F817B03151 /* SceneManagementView.swift */; };
221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */; };
29879D696584B96CC56560DF /* ChatExporter.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7C9BAD674E29688ACE53B0B /* ChatExporter.swift */; };
2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */; };
2D08769282BD71C170DB0943 /* InferenceStats.swift in Sources */ = {isa = PBXBuildFile; fileRef = E35452B166893B25E765FF70 /* InferenceStats.swift */; };
@@ -27,13 +29,16 @@
621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */; };
67262C5E24739F1FE0011439 /* StreamingSSEEncoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */; };
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; };
741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; };
7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; };
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; };
84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */; };
85FB1EB49D76A9F21E181346 /* ChatScene.swift in Sources */ = {isa = PBXBuildFile; fileRef = C04EE8E6418EC6E9B66999B0 /* ChatScene.swift */; };
8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */; };
945474365D0B3E961811909A /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = D5E8E1C2DD8D8AABB4306193 /* MLXVLM */; };
962083CCCC4AC848E0BBBC99 /* CancellationTokenTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */; };
A146BBA70CFBEC505BDCDF0D /* ImageDecoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7C1A89C076E717F87A60397D /* ImageDecoder.swift */; };
AA17474A72C7F4EFBD5C4925 /* PromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = E1E62624B6F285479CB33041 /* PromptBuilder.swift */; };
B13FFE238613BFBFC72E0CC8 /* ChatDocumentMigration.swift in Sources */ = {isa = PBXBuildFile; fileRef = 24E29065DD29C17D20B0400D /* ChatDocumentMigration.swift */; };
B1D9BC407DB7DB1489230C20 /* MonitorView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4239CFF94B819C35A8D4D617 /* MonitorView.swift */; };
B5AA6E3B4BE21676226B342B /* ChatViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = B8BD93859F0291F1A3E09DA5 /* ChatViewModel.swift */; };
@@ -47,6 +52,7 @@
DF5C525DBD2E3153256951C1 /* SceneManagementWindow.swift in Sources */ = {isa = PBXBuildFile; fileRef = BA1592FD260014C4FBDB6995 /* SceneManagementWindow.swift */; };
E199D0BB09B61AC128AB093A /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3489501F2F8E1BA382347CFA /* CancellationToken.swift */; };
E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */; };
EC4FC68608DDFA6A3DF133CC /* InferenceEngine.swift in Sources */ = {isa = PBXBuildFile; fileRef = 02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */; };
F141B91A64F7DAD73CE2910A /* ConversationSessionCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = FFBB16D3AF2E61D001FD6051 /* ConversationSessionCache.swift */; };
F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = A98257123539E9E738213BFA /* MarkdownUI */; };
FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = A4B359324B5FD8D106C74338 /* ChatMessage.swift */; };
@@ -65,10 +71,12 @@
/* End PBXContainerItemProxy section */
/* Begin PBXFileReference section */
02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceEngine.swift; sourceTree = "<group>"; };
0F03A123A8908714A89315FE /* SceneCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneCommands.swift; sourceTree = "<group>"; };
145B888FBDD4F931512C5473 /* Preferences.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Preferences.swift; sourceTree = "<group>"; };
1607BDDE53C575627DCC6896 /* ChatDocumentManifest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentManifest.swift; sourceTree = "<group>"; };
16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ToolPromptBuilder.swift; sourceTree = "<group>"; };
1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenPrefixCache.swift; sourceTree = "<group>"; };
24E29065DD29C17D20B0400D /* ChatDocumentMigration.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentMigration.swift; sourceTree = "<group>"; };
2DC8C86D397B1FCA08E07CBD /* DownloadModalView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DownloadModalView.swift; sourceTree = "<group>"; };
2E2FCA55CEBEBCED78D9479A /* SaveChatCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SaveChatCommands.swift; sourceTree = "<group>"; };
@@ -81,7 +89,9 @@
4147321383E94E9F17A0154E /* SettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsView.swift; sourceTree = "<group>"; };
4239CFF94B819C35A8D4D617 /* MonitorView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MonitorView.swift; sourceTree = "<group>"; };
49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoderTests.swift; sourceTree = "<group>"; };
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptBuilderTests.swift; sourceTree = "<group>"; };
615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoder.swift; sourceTree = "<group>"; };
64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenPrefixCacheTests.swift; sourceTree = "<group>"; };
6B3AA91D2C7842D7366F9A41 /* ChatDocumentPackage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentPackage.swift; sourceTree = "<group>"; };
6EE59189918D06B8D2F588FC /* MLXServer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLXServer.app; sourceTree = BUILT_PRODUCTS_DIR; };
7C1A89C076E717F87A60397D /* ImageDecoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageDecoder.swift; sourceTree = "<group>"; };
@@ -97,10 +107,12 @@
C234359924C542F07ED926A2 /* SceneStore.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneStore.swift; sourceTree = "<group>"; };
C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelPickerView.swift; sourceTree = "<group>"; };
C67742651DB486871CEF1612 /* MLXServerApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLXServerApp.swift; sourceTree = "<group>"; };
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelBackedInferenceValidationTests.swift; sourceTree = "<group>"; };
D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentController.swift; sourceTree = "<group>"; };
D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolver.swift; sourceTree = "<group>"; };
D7C9BAD674E29688ACE53B0B /* ChatExporter.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatExporter.swift; sourceTree = "<group>"; };
DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessagesView.swift; sourceTree = "<group>"; };
E1E62624B6F285479CB33041 /* PromptBuilder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptBuilder.swift; sourceTree = "<group>"; };
E35452B166893B25E765FF70 /* InferenceStats.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceStats.swift; sourceTree = "<group>"; };
E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageDecoderTests.swift; sourceTree = "<group>"; };
E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatInputView.swift; sourceTree = "<group>"; };
@@ -162,7 +174,10 @@
children = (
FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */,
E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */,
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */,
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */,
49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */,
64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */,
);
path = Server;
sourceTree = "<group>";
@@ -250,7 +265,10 @@
3489501F2F8E1BA382347CFA /* CancellationToken.swift */,
FFBB16D3AF2E61D001FD6051 /* ConversationSessionCache.swift */,
7C1A89C076E717F87A60397D /* ImageDecoder.swift */,
02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */,
E1E62624B6F285479CB33041 /* PromptBuilder.swift */,
615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */,
1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */,
E73B165A1822729C907791AE /* ToolCallParser.swift */,
16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */,
);
@@ -363,7 +381,10 @@
files = (
962083CCCC4AC848E0BBBC99 /* CancellationTokenTests.swift in Sources */,
E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */,
8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */,
1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */,
FE4405F66873C75CD6FA19A5 /* StreamingSSEEncoderTests.swift in Sources */,
221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@@ -389,6 +410,7 @@
C07A377244DCD67F4FE709FE /* DownloadModalView.swift in Sources */,
4DC033E45880B2948B47DEB1 /* FocusedValues.swift in Sources */,
A146BBA70CFBEC505BDCDF0D /* ImageDecoder.swift in Sources */,
EC4FC68608DDFA6A3DF133CC /* InferenceEngine.swift in Sources */,
2D08769282BD71C170DB0943 /* InferenceStats.swift in Sources */,
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */,
50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */,
@@ -397,6 +419,7 @@
2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */,
B1D9BC407DB7DB1489230C20 /* MonitorView.swift in Sources */,
165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */,
AA17474A72C7F4EFBD5C4925 /* PromptBuilder.swift in Sources */,
4158FA884D981D73288FB74C /* SaveChatCommands.swift in Sources */,
07119250A7F9D6ECE7F6B8FD /* SceneCommands.swift in Sources */,
20FFB5DBF75AA6C359AAE31C /* SceneManagementView.swift in Sources */,
@@ -406,6 +429,7 @@
D666A311788375E8A061C832 /* SettingsView.swift in Sources */,
621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */,
67262C5E24739F1FE0011439 /* StreamingSSEEncoder.swift in Sources */,
741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */,
189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */,
84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */,
);

View File

@@ -218,91 +218,16 @@ final class APIServer {
}
LiveCounters.shared.requestStarted(requestId: requestId, contextLength: contextLength)
// Convert API messages to Chat.Message, extracting images from content parts
var chatMessages: [Chat.Message] = []
var messageSignatures: [UInt64] = []
var images: [UserInput.Image] = []
var estimatedBytes = 0
let currentModelRepoId = currentModel?.repoId ?? modelName
// Build the instructions string (system prompt + tool definitions).
// This is passed to ChatSession via `instructions:` rather than injected
// as history messages, so it avoids an expensive history-replay prefill.
var instructions: String = ""
// Collect system message text from the request
for msg in request.messages where msg.role == "system" {
let text = msg.content?.textContent ?? ""
if !text.isEmpty {
if !instructions.isEmpty { instructions += "\n\n" }
instructions += text
}
}
// Append tool definitions to instructions
if let tools = request.tools, !tools.isEmpty {
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
if !instructions.isEmpty { instructions += "\n\n" }
instructions += toolSystemPrompt
}
let isQwen = currentModelRepoId.lowercased().contains("qwen")
estimatedBytes += instructions.utf8.count
// Convert non-system messages to Chat.Message
for msg in request.messages where msg.role != "system" {
let role: Chat.Message.Role = switch msg.role {
case "assistant": .assistant
case "tool": .user
default: .user
}
var text = msg.content?.textContent ?? ""
// Format tool_call_id responses as tool_output for the model
if msg.role == "tool" {
if isQwen {
// Qwen expects tool results as-is in a user message
// (the role is already mapped to .user above)
} else {
// Gemma expects tool results wrapped in ```tool_output``` blocks
text = "```tool_output\n\(text)\n```"
}
}
// Format assistant tool_calls back into model-native format
if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty {
let formattedCalls: String
if isQwen {
formattedCalls = ToolPromptBuilder.formatQwenToolCalls(toolCalls)
} else {
formattedCalls = ToolPromptBuilder.formatGemmaToolCalls(toolCalls)
}
text = (text.isEmpty ? "" : text + "\n") + formattedCalls
}
// Extract base64 images from content parts
let imageURLs = msg.content?.imageURLs ?? []
var messageImages: [UserInput.Image] = []
var messageImageBytes = 0
for urlString in imageURLs {
if let decoded = ImageDecoder.decode(urlString) {
messageImages.append(decoded.image)
messageImageBytes += decoded.estimatedBytes
}
}
// Attach images to this specific message
chatMessages.append(Chat.Message(role: role, content: text, images: messageImages))
messageSignatures.append(
Self.messageSignature(role: role, content: text, imageURLs: imageURLs)
let preparedPrompt = PromptBuilder.build(
from: request,
modelId: currentModelRepoId,
thinkingEnabled: Preferences.enableThinking
)
estimatedBytes += text.utf8.count + messageImageBytes
images.append(contentsOf: messageImages)
}
let isQwen = currentModelRepoId.lowercased().contains("qwen")
if !images.isEmpty, currentModel?.supportsImages != true {
if preparedPrompt.containsImages, currentModel?.supportsImages != true {
LiveCounters.shared.requestCompleted(requestId: requestId, generationTokens: 0)
sendResponse(
connection: connection,
@@ -313,7 +238,7 @@ final class APIServer {
}
// Context window check: estimate token count and reject if over limit
let estimatedPromptTokens = (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35
let estimatedPromptTokens = preparedPrompt.estimatedPromptTokens
if contextLength > 0 {
let needed = estimatedPromptTokens + maxTokens
if needed > contextLength {
@@ -337,18 +262,19 @@ final class APIServer {
)
// Feed all messages except the last as history, then send the last as the prompt
let chatMessages = preparedPrompt.chatMessages
let allButLast = Array(chatMessages.dropLast())
let lastMessage = chatMessages.last ?? Chat.Message(role: .user, content: "")
let historySignatures = Array(messageSignatures.dropLast())
let historySignatures = Array(preparedPrompt.messageSignatures.dropLast())
let currentModelId = modelManager.currentModel?.id ?? modelName
let lease = ConversationSessionCache.shared.checkoutSession(
modelId: currentModelId,
instructions: instructions,
instructions: preparedPrompt.instructions,
historySignatures: historySignatures,
requestMessageCount: chatMessages.count,
estimatedPromptTokens: estimatedPromptTokens,
estimatedBytes: estimatedBytes
estimatedBytes: preparedPrompt.estimatedBytes
)
let session: ChatSession
@@ -365,24 +291,21 @@ final class APIServer {
// Use `instructions:` for system/tool prompt (matches internal chat pattern).
// Only conversation turns go in `history:`; this avoids replaying the
// large tool prompt as history on every new session.
let instr = instructions.isEmpty ? nil : instructions
let thinkingContext: [String: any Sendable]? = Preferences.enableThinking
? nil
: ["enable_thinking": false]
let instr = preparedPrompt.instructions.isEmpty ? nil : preparedPrompt.instructions
if !allButLast.isEmpty {
session = ChatSession(
container,
instructions: instr,
history: allButLast,
generateParameters: generateParams,
additionalContext: thinkingContext
additionalContext: preparedPrompt.additionalContext
)
} else {
session = ChatSession(
container,
instructions: instr,
generateParameters: generateParams,
additionalContext: thinkingContext
additionalContext: preparedPrompt.additionalContext
)
}
ConversationSessionCache.shared.markPrefilling(entryId: lease.entryId)
@@ -423,7 +346,7 @@ final class APIServer {
}
if result.succeeded {
var cachedSignatures = messageSignatures
var cachedSignatures = preparedPrompt.messageSignatures
if let assistantHistoryText = result.assistantHistoryText {
cachedSignatures.append(
Self.messageSignature(role: .assistant, content: assistantHistoryText, imageURLs: [])
@@ -435,7 +358,7 @@ final class APIServer {
requestMessageSignatures: cachedSignatures,
requestMessageCount: cachedSignatures.count,
estimatedPromptTokens: estimatedPromptTokens,
estimatedBytes: estimatedBytes,
estimatedBytes: preparedPrompt.estimatedBytes,
promptTokens: result.promptTokens,
completionTokens: result.completionTokens
)

View File

@@ -0,0 +1,66 @@
import MLX
import MLXLMCommon
/// Stateless inference wrapper for the API path.
///
/// Holds a `ModelContainer` and exposes two operations: turning a `UserInput`
/// into model input plus its token ids (`prepare`), and starting a generation
/// stream against an optional pre-populated KV cache (`stream`).
final class InferenceEngine: @unchecked Sendable {
    private let container: ModelContainer

    init(container: ModelContainer) {
        self.container = container
    }

    /// Everything needed to start one generation.
    ///
    /// `stream(_:cancellation:)` only reads `input`, `parameters` and `cachedKV`;
    /// `tokens` and `cachedTokenCount` are carried for the caller's bookkeeping.
    struct InferenceRequest: @unchecked Sendable {
        let input: LMInput
        let tokens: [Int]
        let parameters: GenerateParameters
        let cachedKV: [KVCache]?
        let cachedTokenCount: Int
    }

    /// A running generation together with the KV cache it is writing into.
    struct StreamHandle: @unchecked Sendable {
        let stream: AsyncStream<Generation>
        let workingCache: [KVCache]
    }

    /// Output of `prepare(_:)`: the model-ready input, its token ids, and
    /// whether the originating `UserInput` carried any images.
    struct PreparedInference: @unchecked Sendable {
        let lmInput: LMInput
        let tokens: [Int]
        let hasImages: Bool
    }

    /// Starts generation for `request` inside the model container.
    ///
    /// Uses `request.cachedKV` as the working cache when present; otherwise a
    /// fresh cache is created from the model with the request's parameters.
    ///
    /// - Parameters:
    ///   - request: Input, parameters and optional cached KV state.
    ///   - cancellation: Accepted for interface symmetry but currently unused.
    /// - Returns: The generation stream and the cache it mutates.
    /// - Throws: Whatever `MLXLMCommon.generate` or the container throws.
    func stream(
        _ request: InferenceRequest,
        cancellation: CancellationToken
    ) async throws -> StreamHandle {
        // The token is deliberately discarded here; nothing in this method observes it.
        _ = cancellation

        nonisolated(unsafe) let modelInput = request.input
        nonisolated(unsafe) let reusableCache = request.cachedKV
        let generateParameters = request.parameters

        return try await container.perform { context in
            let activeCache: [KVCache]
            if let reusableCache {
                activeCache = reusableCache
            } else {
                activeCache = context.model.newCache(parameters: generateParameters)
            }

            let generation = try MLXLMCommon.generate(
                input: modelInput,
                cache: activeCache,
                parameters: generateParameters,
                context: context
            )
            return StreamHandle(stream: generation, workingCache: activeCache)
        }
    }

    /// Runs the container's input preparation on `userInput` and extracts the
    /// resulting text tokens as a plain `[Int]`.
    ///
    /// - Throws: Whatever `container.prepare(input:)` throws.
    func prepare(_ userInput: UserInput) async throws -> PreparedInference {
        nonisolated(unsafe) let rawInput = userInput
        let lmInput = try await container.prepare(input: rawInput)

        // Token extraction is performed inside the container, like generation.
        nonisolated(unsafe) let extractionSource = lmInput
        let tokenIds: [Int] = await container.perform { _ in
            extractionSource.text.tokens.asArray(Int.self)
        }

        return PreparedInference(
            lmInput: lmInput,
            tokens: tokenIds,
            hasImages: !userInput.images.isEmpty
        )
    }
}

View File

@@ -0,0 +1,139 @@
import Foundation
import MLXLMCommon
/// Converts OpenAI-format API messages into reusable prompt artifacts for the API server.
enum PromptBuilder {
    /// The result of shaping one chat-completion request.
    struct PreparedPrompt {
        /// System-message text and the tool prompt, separated by blank lines.
        let instructions: String
        /// All non-system messages, in request order.
        let chatMessages: [Chat.Message]
        /// One hash per entry of `chatMessages`, in the same order.
        let messageSignatures: [UInt64]
        /// UTF-8 size of instructions and message text plus decoded image sizes.
        let estimatedBytes: Int
        /// Character count of instructions and message text, times 10 / 35.
        let estimatedPromptTokens: Int
        /// True when at least one image URL decoded successfully.
        let containsImages: Bool
        /// `["enable_thinking": false]` when thinking is disabled, otherwise nil.
        let additionalContext: [String: any Sendable]?
        /// Chat-style input containing the system message (if any) followed by `chatMessages`.
        let userInput: UserInput
    }

    /// Shapes `request` for the model identified by `modelId`.
    ///
    /// - Parameters:
    ///   - request: The incoming chat-completion request.
    ///   - modelId: Model identifier; a case-insensitive "qwen" substring selects
    ///     Qwen tool formatting, anything else uses the Gemma formatting.
    ///   - thinkingEnabled: When false, `enable_thinking: false` is added as context.
    static func build(
        from request: APIChatCompletionRequest,
        modelId: String,
        thinkingEnabled: Bool
    ) -> PreparedPrompt {
        // Instructions: every non-empty system message, then the tool prompt,
        // joined with a blank line between sections.
        var instructionSections: [String] = request.messages
            .filter { $0.role == "system" }
            .map { message -> String in message.content?.textContent ?? "" }
            .filter { !$0.isEmpty }
        if let tools = request.tools, !tools.isEmpty {
            instructionSections.append(
                ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId)
            )
        }
        let instructions = instructionSections.joined(separator: "\n\n")

        let usesQwenFormat = modelId.lowercased().contains("qwen")

        var chatMessages: [Chat.Message] = []
        var messageSignatures: [UInt64] = []
        var estimatedBytes = instructions.utf8.count
        var containsImages = false

        for message in request.messages {
            if message.role == "system" { continue }

            // Tool results are delivered to the model as user turns.
            let role: Chat.Message.Role = message.role == "assistant" ? .assistant : .user

            var text = message.content?.textContent ?? ""

            // Non-Qwen models get tool results wrapped in a tool_output fence.
            if message.role == "tool" && !usesQwenFormat {
                text = "```tool_output\n\(text)\n```"
            }

            // Re-render prior assistant tool calls in the model's native format.
            if message.role == "assistant", let toolCalls = message.tool_calls, !toolCalls.isEmpty {
                let renderedCalls: String
                if usesQwenFormat {
                    renderedCalls = ToolPromptBuilder.formatQwenToolCalls(toolCalls)
                } else {
                    renderedCalls = ToolPromptBuilder.formatGemmaToolCalls(toolCalls)
                }
                if text.isEmpty {
                    text = renderedCalls
                } else {
                    text = text + "\n" + renderedCalls
                }
            }

            // URLs that fail to decode are dropped from the message's images,
            // but every URL still contributes to the signature below.
            let imageURLs = message.content?.imageURLs ?? []
            let decodedImages = imageURLs.compactMap { ImageDecoder.decode($0) }
            let messageImages: [UserInput.Image] = decodedImages.map { $0.image }
            let messageImageBytes = decodedImages.reduce(0) { $0 + $1.estimatedBytes }

            if !messageImages.isEmpty {
                containsImages = true
            }

            chatMessages.append(Chat.Message(role: role, content: text, images: messageImages))
            messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs))
            estimatedBytes += text.utf8.count + messageImageBytes
        }

        let additionalContext: [String: any Sendable]?
        if thinkingEnabled {
            additionalContext = nil
        } else {
            additionalContext = ["enable_thinking": false]
        }

        var allMessages: [Chat.Message] = []
        if !instructions.isEmpty {
            allMessages.append(Chat.Message(role: .system, content: instructions))
        }
        allMessages.append(contentsOf: chatMessages)

        let allImages = chatMessages.flatMap { $0.images }
        let userInput = UserInput(
            prompt: .chat(allMessages),
            images: allImages,
            videos: [],
            tools: nil,
            additionalContext: additionalContext
        )

        // Rough token estimate: total characters scaled by 10 / 35.
        let totalCharacters = chatMessages.reduce(instructions.count) { $0 + $1.content.count }
        let estimatedPromptTokens = totalCharacters * 10 / 35

        return PreparedPrompt(
            instructions: instructions,
            chatMessages: chatMessages,
            messageSignatures: messageSignatures,
            estimatedBytes: estimatedBytes,
            estimatedPromptTokens: estimatedPromptTokens,
            containsImages: containsImages,
            additionalContext: additionalContext,
            userInput: userInput
        )
    }

    /// Hashes a message as the byte sequence `role|content[|imageURL]...`,
    /// xor-ing each UTF-8 byte into the state and then multiplying (wrapping).
    private static func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
        let multiplier: UInt64 = 1_099_511_628_211
        var state: UInt64 = 14_695_981_039_346_656_037

        func absorb(_ fragment: String) {
            state = fragment.utf8.reduce(state) { partial, byte in
                (partial ^ UInt64(byte)) &* multiplier
            }
        }

        let roleLabel: String
        switch role {
        case .assistant:
            roleLabel = "assistant"
        case .system:
            roleLabel = "system"
        case .user:
            roleLabel = "user"
        @unknown default:
            roleLabel = "unknown"
        }

        absorb(roleLabel)
        absorb("|")
        absorb(content)
        imageURLs.forEach { url in
            absorb("|")
            absorb(url)
        }
        return state
    }
}

View File

@@ -0,0 +1,317 @@
import Foundation
import Metal
import MLXLMCommon
import os
final class TokenPrefixCache: @unchecked Sendable {
static let shared = TokenPrefixCache()
struct CacheLease: @unchecked Sendable {
let entryId: UUID
let kvCache: [KVCache]?
let matchedTokenCount: Int
let isHit: Bool
}
struct EntrySummary: Identifiable, Sendable {
let id: UUID
let modelId: String
let tokenCount: Int
let estimatedBytes: Int
let createdAt: Date
let lastAccessAt: Date
let hitCount: Int
}
struct Snapshot: Sendable {
let totalEntries: Int
let totalCachedTokens: Int
let estimatedBytes: Int
let memoryBudgetBytes: Int
let memoryUsagePercent: Double
let totalHits: Int
let totalMisses: Int
let totalEvictions: Int
let hitRate: Double
let entries: [EntrySummary]
}
private final class TrieNode {
var children: [Int: TrieNode] = [:]
var entryId: UUID?
}
private struct CacheEntry {
let id: UUID
let modelId: String
let kvCache: [KVCache]
let tokenCount: Int
let cacheKey: [Int]
let estimatedBytes: Int
let createdAt: Date
var lastAccessAt: Date
var hitCount: Int
}
private struct Stats {
var totalHits: Int = 0
var totalMisses: Int = 0
var totalEvictions: Int = 0
}
private let lock = OSAllocatedUnfairLock()
private let maxMemoryBytes: Int
private let idleTTL: TimeInterval
private let estimateBytesProvider: ([KVCache]) -> Int
private let nowProvider: () -> Date
private var root = TrieNode()
private var entries: [UUID: CacheEntry] = [:]
private var currentMemoryBytes: Int = 0
private var stats = Stats()
private init() {
self.maxMemoryBytes = Self.computeMemoryBudget()
self.idleTTL = 30 * 60
self.estimateBytesProvider = Self.estimateBytes
self.nowProvider = Date.init
}
init(
memoryBudgetBytes: Int,
idleTTL: TimeInterval = 30 * 60,
estimateBytesProvider: @escaping ([KVCache]) -> Int = TokenPrefixCache.estimateBytes,
nowProvider: @escaping () -> Date = Date.init
) {
self.maxMemoryBytes = memoryBudgetBytes
self.idleTTL = idleTTL
self.estimateBytesProvider = estimateBytesProvider
self.nowProvider = nowProvider
}
func lookup(cacheKey: [Int], modelId: String) -> CacheLease {
lock.lock()
let now = nowProvider()
pruneExpiredLocked(now: now)
var node = root
var bestMatch: (entryId: UUID, realTokenCount: Int)?
var realTokenCount = 0
for key in cacheKey {
guard let child = node.children[key] else { break }
node = child
if key >= 0 { realTokenCount += 1 }
if let entryId = node.entryId,
let entry = entries[entryId],
entry.modelId == modelId {
bestMatch = (entryId: entryId, realTokenCount: realTokenCount)
}
}
guard let match = bestMatch,
var entry = entries[match.entryId]
else {
stats.totalMisses += 1
lock.unlock()
return CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false)
}
entry.lastAccessAt = now
entry.hitCount += 1
entries[match.entryId] = entry
removeEntryLocked(entry)
stats.totalHits += 1
lock.unlock()
return CacheLease(
entryId: match.entryId,
kvCache: entry.kvCache,
matchedTokenCount: match.realTokenCount,
isHit: true
)
}
func store(
entryId: UUID,
kvCache: [KVCache],
cacheKey: [Int],
modelId: String
) {
lock.lock()
let now = nowProvider()
pruneExpiredLocked(now: now)
let estimatedBytes = estimateBytesProvider(kvCache)
var node = root
for key in cacheKey {
if node.children[key] == nil {
node.children[key] = TrieNode()
}
node = node.children[key]!
}
if let oldId = node.entryId,
let oldEntry = entries[oldId] {
removeEntryLocked(oldEntry)
}
node.entryId = entryId
entries[entryId] = CacheEntry(
id: entryId,
modelId: modelId,
kvCache: kvCache,
tokenCount: cacheKey.filter { $0 >= 0 }.count,
cacheKey: cacheKey,
estimatedBytes: estimatedBytes,
createdAt: now,
lastAccessAt: now,
hitCount: 0
)
currentMemoryBytes += estimatedBytes
enforceBudgetLocked()
lock.unlock()
}
func invalidateAll() {
lock.lock()
stats.totalEvictions += entries.count
entries.removeAll()
root = TrieNode()
currentMemoryBytes = 0
lock.unlock()
}
func reset() {
lock.lock()
root = TrieNode()
entries.removeAll()
currentMemoryBytes = 0
stats = Stats()
lock.unlock()
}
func snapshot() -> Snapshot {
lock.lock()
let now = nowProvider()
pruneExpiredLocked(now: now)
let orderedEntries = entries.values.sorted { lhs, rhs in
if lhs.lastAccessAt != rhs.lastAccessAt {
return lhs.lastAccessAt > rhs.lastAccessAt
}
return lhs.createdAt > rhs.createdAt
}
let hits = stats.totalHits
let misses = stats.totalMisses
let totalOps = hits + misses
let snapshot = Snapshot(
totalEntries: orderedEntries.count,
totalCachedTokens: orderedEntries.reduce(0) { $0 + $1.tokenCount },
estimatedBytes: currentMemoryBytes,
memoryBudgetBytes: maxMemoryBytes,
memoryUsagePercent: maxMemoryBytes > 0
? (Double(currentMemoryBytes) / Double(maxMemoryBytes)) * 100
: 0,
totalHits: hits,
totalMisses: misses,
totalEvictions: stats.totalEvictions,
hitRate: totalOps > 0 ? (Double(hits) / Double(totalOps)) * 100 : 0,
entries: orderedEntries.map {
EntrySummary(
id: $0.id,
modelId: $0.modelId,
tokenCount: $0.tokenCount,
estimatedBytes: $0.estimatedBytes,
createdAt: $0.createdAt,
lastAccessAt: $0.lastAccessAt,
hitCount: $0.hitCount
)
}
)
lock.unlock()
return snapshot
}
func debugTrieNodeCount() -> Int {
lock.lock()
let count = countNodes(root)
lock.unlock()
return count
}
private func pruneExpiredLocked(now: Date) {
let expired = entries.values.filter {
now.timeIntervalSince($0.lastAccessAt) > idleTTL
}
for entry in expired {
removeEntryLocked(entry)
}
}
private func enforceBudgetLocked() {
while currentMemoryBytes > maxMemoryBytes {
guard let victim = entries.values.min(by: evictionOrder) else {
break
}
removeEntryLocked(victim)
}
}
private func removeEntryLocked(_ entry: CacheEntry) {
guard entries[entry.id] != nil else { return }
var node = root
var path: [(parent: TrieNode, key: Int)] = []
for key in entry.cacheKey {
guard let child = node.children[key] else { break }
path.append((parent: node, key: key))
node = child
}
node.entryId = nil
for (parent, key) in path.reversed() {
guard let child = parent.children[key] else { continue }
if child.children.isEmpty && child.entryId == nil {
parent.children.removeValue(forKey: key)
} else {
break
}
}
currentMemoryBytes = max(0, currentMemoryBytes - entry.estimatedBytes)
entries.removeValue(forKey: entry.id)
stats.totalEvictions += 1
}
private func evictionOrder(lhs: CacheEntry, rhs: CacheEntry) -> Bool {
if lhs.lastAccessAt != rhs.lastAccessAt {
return lhs.lastAccessAt < rhs.lastAccessAt
}
if lhs.hitCount != rhs.hitCount {
return lhs.hitCount < rhs.hitCount
}
return lhs.createdAt < rhs.createdAt
}
private func countNodes(_ node: TrieNode) -> Int {
1 + node.children.values.reduce(0) { $0 + countNodes($1) }
}
/// Derives the cache memory budget in bytes from the default Metal device.
///
/// The budget is 20% of the device's `recommendedMaxWorkingSetSize`, clamped to the
/// range 256 MiB ... 8 GiB. When no Metal device is available, 512 MiB is returned.
private static func computeMemoryBudget() -> Int {
    let mebibyte = 1024 * 1024
    let fallbackBytes = 512 * mebibyte
    let floorBytes = 256 * mebibyte
    let ceilingBytes = 8 * 1024 * mebibyte

    guard let device = MTLCreateSystemDefaultDevice() else {
        return fallbackBytes
    }
    let share = Int(Double(device.recommendedMaxWorkingSetSize) * 0.20)
    return max(floorBytes, min(share, ceilingBytes))
}
/// Estimates the size of a KV cache by summing `nbytes` over every array in every
/// layer's `state`. The result is never smaller than 1024.
private static func estimateBytes(_ kvCache: [KVCache]) -> Int {
    let measured = kvCache.reduce(0) { subtotal, layer in
        subtotal + layer.state.reduce(0) { $0 + $1.nbytes }
    }
    // Floor of 1 KiB so that even an empty cache is accounted with a non-zero size.
    return max(measured, 1024)
}
}

View File

@@ -0,0 +1,289 @@
import Foundation
import Hub
import MLXLMCommon
import MLXVLM
import XCTest
@testable import MLX_Server
/// Model-backed parity tests that run against a locally available Gemma model.
///
/// The model container comes from `LocalGemmaFixture`, which throws `XCTSkip` when the
/// model configuration or the local model directory cannot be resolved.
final class ModelBackedInferenceValidationTests: XCTestCase {
    /// Base64 body of a 1x1 PNG, embedded into a `data:image/png` URL by the multimodal test.
    private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="

    /// Model identifier handed to both prompt builders.
    private let gemmaModelId = "mlx-community/gemma-3-4b-it-4bit"

    // MARK: - Tests

    /// `PromptBuilder.build` and the in-test legacy shaping must tokenize to the same
    /// token sequence for a system + multimodal user conversation.
    func testPromptBuilderTokenizationMatchesLegacyShapingOnLocalGemma() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)

        let textPart = APIContentPart(type: "text", text: "What is in this image?", image_url: nil)
        let imagePart = APIContentPart(
            type: "image_url",
            text: nil,
            image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil)
        )
        let request = makeRequest(messages: [
            textMessage(role: "system", "You are concise."),
            APIChatMessage(
                role: "user",
                content: .parts([textPart, imagePart]),
                name: nil,
                tool_calls: nil,
                tool_call_id: nil
            )
        ])

        let prepared = PromptBuilder.build(from: request, modelId: gemmaModelId, thinkingEnabled: false)
        let legacy = legacyBuild(from: request, modelId: gemmaModelId, thinkingEnabled: false)

        let preparedInference = try await engine.prepare(prepared.userInput)
        let legacyInference = try await engine.prepare(legacy.userInput)

        XCTAssertEqual(preparedInference.tokens, legacyInference.tokens)
    }

    /// A one-token generation through `InferenceEngine` must produce the same text and
    /// prompt-token count as `ChatSession` for the same prompt and parameters.
    func testInferenceEngineMatchesChatSessionOnLocalGemma() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)
        let parameters = GenerateParameters(maxTokens: 1, temperature: 0)

        let request = makeRequest(messages: [
            textMessage(role: "user", "Say hello in one word.")
        ])
        let prepared = PromptBuilder.build(from: request, modelId: gemmaModelId, thinkingEnabled: true)
        let preparedInference = try await engine.prepare(prepared.userInput)

        // Engine side: no cached KV state is supplied.
        let inferenceRequest = InferenceEngine.InferenceRequest(
            input: preparedInference.lmInput,
            tokens: preparedInference.tokens,
            parameters: parameters,
            cachedKV: nil,
            cachedTokenCount: 0
        )
        let handle = try await engine.stream(inferenceRequest, cancellation: CancellationToken())
        let engineResult = await collectEngineOutput(handle.stream)

        // Reference side: ChatSession over the same container and parameters.
        let session = ChatSession(container, generateParameters: parameters)
        let sessionStream = session.streamDetails(to: "Say hello in one word.", images: [], videos: [])
        let sessionResult = try await collectSessionOutput(sessionStream)

        XCTAssertEqual(engineResult.text, sessionResult.text)
        XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
    }

    // MARK: - Fixtures

    /// Returns the shared Gemma container (or throws, including `XCTSkip`, from the fixture).
    private func localGemmaContainer() async throws -> ModelContainer {
        try await LocalGemmaFixture.shared.container()
    }

    /// Builds a `"gemma"` chat-completion request with every optional field left nil.
    private func makeRequest(messages: [APIChatMessage]) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: messages,
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    /// Builds a plain-text message with no name, tool calls, or tool-call id.
    private func textMessage(role: String, _ text: String) -> APIChatMessage {
        APIChatMessage(role: role, content: .text(text), name: nil, tool_calls: nil, tool_call_id: nil)
    }

    // MARK: - Legacy reference shaping

    /// Reference implementation of the request shaping that `PromptBuilder.build` is
    /// compared against: system messages are merged into `instructions`, tool
    /// definitions are appended as a system prompt, and the remaining messages become
    /// chat messages (with tool output wrapping / tool-call formatting per model family).
    private func legacyBuild(
        from request: APIChatCompletionRequest,
        modelId: String,
        thinkingEnabled: Bool
    ) -> PromptBuilder.PreparedPrompt {
        // Non-empty system texts, joined by a blank line.
        var instructions = request.messages
            .filter { $0.role == "system" }
            .map { $0.content?.textContent ?? "" }
            .filter { !$0.isEmpty }
            .joined(separator: "\n\n")

        if let tools = request.tools, !tools.isEmpty {
            let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId)
            if !instructions.isEmpty { instructions += "\n\n" }
            instructions += toolSystemPrompt
        }

        let isQwen = modelId.lowercased().contains("qwen")
        var chatMessages: [Chat.Message] = []
        var messageSignatures: [UInt64] = []
        var estimatedBytes = instructions.utf8.count
        var containsImages = false

        for message in request.messages where message.role != "system" {
            // Everything that is not an assistant turn (including "tool") is sent as a user turn.
            let role: Chat.Message.Role = (message.role == "assistant") ? .assistant : .user

            var text = message.content?.textContent ?? ""
            if message.role == "tool", !isQwen {
                text = "```tool_output\n\(text)\n```"
            }
            if message.role == "assistant", let toolCalls = message.tool_calls, !toolCalls.isEmpty {
                let formattedCalls: String
                if isQwen {
                    formattedCalls = ToolPromptBuilder.formatQwenToolCalls(toolCalls)
                } else {
                    formattedCalls = ToolPromptBuilder.formatGemmaToolCalls(toolCalls)
                }
                let prefix = text.isEmpty ? "" : text + "\n"
                text = prefix + formattedCalls
            }

            // URLs that fail to decode are dropped from the images but still feed the signature.
            let imageURLs = message.content?.imageURLs ?? []
            let decodedImages = imageURLs.compactMap { ImageDecoder.decode($0) }
            let messageImages: [UserInput.Image] = decodedImages.map { $0.image }
            let messageImageBytes = decodedImages.reduce(0) { $0 + $1.estimatedBytes }

            if !messageImages.isEmpty {
                containsImages = true
            }
            chatMessages.append(Chat.Message(role: role, content: text, images: messageImages))
            messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs))
            estimatedBytes += text.utf8.count + messageImageBytes
        }

        let additionalContext: [String: any Sendable]?
        if thinkingEnabled {
            additionalContext = nil
        } else {
            additionalContext = ["enable_thinking": false]
        }

        var promptMessages: [Chat.Message] = []
        if !instructions.isEmpty {
            promptMessages.append(Chat.Message(role: .system, content: instructions))
        }
        promptMessages.append(contentsOf: chatMessages)

        let userInput = UserInput(
            prompt: .chat(promptMessages),
            images: chatMessages.flatMap(\.images),
            videos: [],
            tools: nil,
            additionalContext: additionalContext
        )

        // Token estimate: total character count scaled by 10/35 (integer arithmetic).
        let characterCount = instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }
        return PromptBuilder.PreparedPrompt(
            instructions: instructions,
            chatMessages: chatMessages,
            messageSignatures: messageSignatures,
            estimatedBytes: estimatedBytes,
            estimatedPromptTokens: characterCount * 10 / 35,
            containsImages: containsImages,
            additionalContext: additionalContext,
            userInput: userInput
        )
    }

    /// FNV-1a (64-bit) hash over `role`, `"|"`, `content`, and each image URL preceded by `"|"`.
    private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
        var digest: UInt64 = 14_695_981_039_346_656_037
        func absorb(_ fragment: String) {
            for octet in fragment.utf8 {
                digest = (digest ^ UInt64(octet)) &* 1_099_511_628_211
            }
        }

        let roleTag: String
        switch role {
        case .assistant:
            roleTag = "assistant"
        case .system:
            roleTag = "system"
        case .user:
            roleTag = "user"
        @unknown default:
            roleTag = "unknown"
        }

        absorb(roleTag)
        absorb("|")
        absorb(content)
        for imageURL in imageURLs {
            absorb("|")
            absorb(imageURL)
        }
        return digest
    }

    // MARK: - Stream collection

    /// Drains an engine stream, concatenating text chunks and keeping the last reported
    /// prompt-token count. Tool-call events are ignored.
    private func collectEngineOutput(_ stream: AsyncStream<Generation>) async -> GenerationResult {
        var collectedText = ""
        var promptTokens = 0
        for await event in stream {
            switch event {
            case .chunk(let piece):
                collectedText += piece
            case .info(let details):
                promptTokens = details.promptTokenCount
            case .toolCall:
                continue
            }
        }
        return GenerationResult(text: collectedText, promptTokenCount: promptTokens)
    }

    /// Throwing counterpart of `collectEngineOutput` for `ChatSession` streams.
    private func collectSessionOutput(_ stream: AsyncThrowingStream<Generation, any Error>) async throws -> GenerationResult {
        var collectedText = ""
        var promptTokens = 0
        for try await event in stream {
            switch event {
            case .chunk(let piece):
                collectedText += piece
            case .info(let details):
                promptTokens = details.promptTokenCount
            case .toolCall:
                continue
            }
        }
        return GenerationResult(text: collectedText, promptTokenCount: promptTokens)
    }
}
/// Aggregate of one drained generation stream, used to compare engine and session output.
private struct GenerationResult {
    /// Concatenation of every text chunk received from the stream.
    let text: String

    /// Prompt-token count taken from the most recent info event (0 when none was seen).
    let promptTokenCount: Int
}
/// Shared loader for the local Gemma model container.
///
/// The load runs in a single `Task` that is remembered in `task`, so later callers
/// await the same load instead of starting a new one. When the load fails, the stored
/// task is cleared so a subsequent call starts over.
private actor LocalGemmaFixture {
    static let shared = LocalGemmaFixture()

    /// The in-flight or completed load, if one has been started.
    private var task: Task<ModelContainer, Error>?

    /// Returns the loaded container.
    ///
    /// - Throws: `XCTSkip` when the `"gemma"` config or its local directory cannot be
    ///   resolved; otherwise whatever error the load itself throws.
    func container() async throws -> ModelContainer {
        if let existing = task {
            return try await existing.value
        }

        guard let config = ModelConfig.resolve("gemma") else {
            throw XCTSkip("Gemma model config is unavailable")
        }
        guard let localDir = LocalModelResolver.resolve(repoId: config.repoId) else {
            throw XCTSkip("Local gemma cache is unavailable")
        }

        let loader = Task<ModelContainer, Error> {
            let hub = HubApi(
                downloadBase: FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first,
                cache: nil
            )
            return try await VLMModelFactory.shared.loadContainer(
                hub: hub,
                configuration: ModelConfiguration(directory: localDir),
                progressHandler: { _ in }
            )
        }
        // Publish the task before awaiting it so callers arriving meanwhile reuse it.
        task = loader

        do {
            return try await loader.value
        } catch {
            // Forget the failed load so the next call can retry.
            task = nil
            throw error
        }
    }
}

View File

@@ -0,0 +1,288 @@
import XCTest
import MLXLMCommon
@testable import MLX_Server
/// Unit tests for `PromptBuilder.build`, including a parity check against an in-test
/// copy of the legacy request shaping.
final class PromptBuilderTests: XCTestCase {
    /// Base64 body of a 1x1 PNG, embedded into `data:image/png` URLs by the image tests.
    private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="

    /// Model identifier used by the Gemma-flavoured tests.
    private let gemmaModelId = "mlx-community/gemma-3-4b-it-4bit"

    // MARK: - Tests

    /// Every field of the prepared prompt must match the legacy shaping for a
    /// conversation with two system messages, an assistant tool call, a tool result
    /// carrying an image, a user turn, and one tool definition.
    func testBuildMatchesLegacyAPIServerShapingForGemma() {
        let toolCall = APIToolCall(
            id: "call_weather",
            function: APIFunctionCall(name: "weather", arguments: "{\"city\":\"Berlin\"}")
        )
        let request = APIChatCompletionRequest(
            model: "gemma",
            messages: [
                textMessage(role: "system", "System 1"),
                textMessage(role: "system", "System 2"),
                APIChatMessage(role: "assistant", content: .text("Let me check"), name: nil, tool_calls: [toolCall], tool_call_id: nil),
                toolResultMessage(text: "{\"temp\":19}", toolCallId: "call_weather"),
                textMessage(role: "user", "Thanks")
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: [
                APIToolDefinition(
                    type: "function",
                    function: APIFunctionDefinition(
                        name: "weather",
                        description: "Lookup weather",
                        parameters: ["type": AnyCodable("object")]
                    )
                )
            ],
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )

        let prepared = PromptBuilder.build(from: request, modelId: gemmaModelId, thinkingEnabled: false)
        let legacy = legacyBuild(from: request, modelId: gemmaModelId, thinkingEnabled: false)

        XCTAssertEqual(prepared.instructions, legacy.instructions)
        XCTAssertEqual(prepared.chatMessages.map { $0.role.roleLabel }, legacy.chatMessages.map { $0.role.roleLabel })
        XCTAssertEqual(prepared.chatMessages.map(\.content), legacy.chatMessages.map(\.content))
        XCTAssertEqual(prepared.chatMessages.map { $0.images.count }, legacy.chatMessages.map { $0.images.count })
        XCTAssertEqual(prepared.messageSignatures, legacy.messageSignatures)
        XCTAssertEqual(prepared.estimatedBytes, legacy.estimatedBytes)
        XCTAssertEqual(prepared.estimatedPromptTokens, legacy.estimatedPromptTokens)
        XCTAssertEqual(prepared.containsImages, legacy.containsImages)
        XCTAssertEqual(prepared.additionalContext?["enable_thinking"] as? Bool, legacy.additionalContext?["enable_thinking"] as? Bool)
    }

    /// System messages are merged into `instructions`; only the user turn remains a chat message.
    func testBuildAggregatesInstructionsAndMessages() throws {
        let request = makeRequest(model: "gemma", messages: [
            textMessage(role: "system", "Base system"),
            textMessage(role: "system", "Extra system"),
            textMessage(role: "user", "Hello")
        ])

        let prepared = PromptBuilder.build(from: request, modelId: gemmaModelId, thinkingEnabled: false)

        XCTAssertEqual(prepared.instructions, "Base system\n\nExtra system")
        XCTAssertEqual(prepared.chatMessages.count, 1)
        // XCTUnwrap reports a failure instead of trapping on an out-of-range subscript
        // when the builder unexpectedly returns no messages.
        let firstMessage = try XCTUnwrap(prepared.chatMessages.first)
        XCTAssertEqual(firstMessage.content, "Hello")
        XCTAssertEqual(prepared.messageSignatures.count, 1)
        XCTAssertFalse(prepared.containsImages)
        XCTAssertNotNil(prepared.additionalContext)
        XCTAssertGreaterThan(prepared.estimatedPromptTokens, 0)
    }

    /// For a Qwen model id, assistant tool calls are appended to the message text in
    /// `<tool_call>` form and no additional context is produced when thinking is enabled.
    func testBuildFormatsAssistantToolCallsForQwen() throws {
        let toolCall = APIToolCall(
            id: "call_1",
            function: APIFunctionCall(name: "weather", arguments: "{\"city\":\"Berlin\"}")
        )
        let request = makeRequest(model: "qwen", messages: [
            APIChatMessage(role: "assistant", content: .text("Let me check."), name: nil, tool_calls: [toolCall], tool_call_id: nil)
        ])

        let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/Qwen3-VL-4B-Instruct-4bit", thinkingEnabled: true)

        XCTAssertEqual(prepared.chatMessages.count, 1)
        let firstMessage = try XCTUnwrap(prepared.chatMessages.first)
        XCTAssertTrue(firstMessage.content.contains("Let me check."))
        XCTAssertTrue(firstMessage.content.contains("<tool_call>"))
        XCTAssertNil(prepared.additionalContext)
    }

    /// For a Gemma model id, tool output is wrapped in a ```tool_output fence and the
    /// attached image is decoded, counted, and reflected in `estimatedBytes`.
    func testBuildWrapsGemmaToolOutputsAndTracksImages() throws {
        let request = makeRequest(model: "gemma", messages: [
            toolResultMessage(text: "{\"ok\":true}", toolCallId: "call_1")
        ])

        let prepared = PromptBuilder.build(from: request, modelId: gemmaModelId, thinkingEnabled: true)

        let firstMessage = try XCTUnwrap(prepared.chatMessages.first)
        XCTAssertTrue(firstMessage.content.contains("```tool_output"))
        XCTAssertTrue(prepared.containsImages)
        XCTAssertEqual(firstMessage.images.count, 1)
        XCTAssertGreaterThan(prepared.estimatedBytes, firstMessage.content.utf8.count)
    }

    // MARK: - Request helpers

    /// Builds a chat-completion request without tools and with every optional field nil.
    private func makeRequest(model: String, messages: [APIChatMessage]) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: model,
            messages: messages,
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    /// Builds a plain-text message with no name, tool calls, or tool-call id.
    private func textMessage(role: String, _ text: String) -> APIChatMessage {
        APIChatMessage(role: role, content: .text(text), name: nil, tool_calls: nil, tool_call_id: nil)
    }

    /// Builds a `"tool"` message whose content is `text` followed by the one-pixel PNG data URL.
    private func toolResultMessage(text: String, toolCallId: String) -> APIChatMessage {
        APIChatMessage(
            role: "tool",
            content: .parts([
                APIContentPart(type: "text", text: text, image_url: nil),
                APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
            ]),
            name: nil,
            tool_calls: nil,
            tool_call_id: toolCallId
        )
    }

    // MARK: - Legacy reference shaping

    /// Reference implementation of the request shaping that `PromptBuilder.build` is
    /// compared against. Kept statement-for-statement so it stays a trustworthy oracle.
    private func legacyBuild(
        from request: APIChatCompletionRequest,
        modelId: String,
        thinkingEnabled: Bool
    ) -> PromptBuilder.PreparedPrompt {
        // Merge non-empty system messages, separated by a blank line.
        var instructions = ""
        for msg in request.messages where msg.role == "system" {
            let text = msg.content?.textContent ?? ""
            if !text.isEmpty {
                if !instructions.isEmpty { instructions += "\n\n" }
                instructions += text
            }
        }
        // Tool definitions are rendered into the system prompt.
        if let tools = request.tools, !tools.isEmpty {
            let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId)
            if !instructions.isEmpty { instructions += "\n\n" }
            instructions += toolSystemPrompt
        }

        let isQwen = modelId.lowercased().contains("qwen")
        var chatMessages: [Chat.Message] = []
        var messageSignatures: [UInt64] = []
        var estimatedBytes = instructions.utf8.count
        var containsImages = false

        for msg in request.messages where msg.role != "system" {
            // "tool" and any unrecognised role are sent as user turns.
            let role: Chat.Message.Role = switch msg.role {
            case "assistant": .assistant
            case "tool": .user
            default: .user
            }

            var text = msg.content?.textContent ?? ""
            if msg.role == "tool", !isQwen {
                text = "```tool_output\n\(text)\n```"
            }
            if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty {
                let formattedCalls = isQwen
                    ? ToolPromptBuilder.formatQwenToolCalls(toolCalls)
                    : ToolPromptBuilder.formatGemmaToolCalls(toolCalls)
                text = (text.isEmpty ? "" : text + "\n") + formattedCalls
            }

            // URLs that fail to decode are dropped from the images but still feed the signature.
            let imageURLs = msg.content?.imageURLs ?? []
            var messageImages: [UserInput.Image] = []
            var messageImageBytes = 0
            for urlString in imageURLs {
                if let decoded = ImageDecoder.decode(urlString) {
                    messageImages.append(decoded.image)
                    messageImageBytes += decoded.estimatedBytes
                }
            }

            containsImages = containsImages || !messageImages.isEmpty
            chatMessages.append(Chat.Message(role: role, content: text, images: messageImages))
            messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs))
            estimatedBytes += text.utf8.count + messageImageBytes
        }

        let additionalContext: [String: any Sendable]? = thinkingEnabled
            ? nil
            : ["enable_thinking": false]
        let allImages = chatMessages.flatMap(\.images)
        let userInput = UserInput(
            prompt: .chat((instructions.isEmpty ? [] : [Chat.Message(role: .system, content: instructions)]) + chatMessages),
            images: allImages,
            videos: [],
            tools: nil,
            additionalContext: additionalContext
        )

        return PromptBuilder.PreparedPrompt(
            instructions: instructions,
            chatMessages: chatMessages,
            messageSignatures: messageSignatures,
            estimatedBytes: estimatedBytes,
            // Token estimate: total character count scaled by 10/35 (integer arithmetic).
            estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
            containsImages: containsImages,
            additionalContext: additionalContext,
            userInput: userInput
        )
    }

    /// FNV-1a (64-bit) hash over `role`, `"|"`, `content`, and each image URL preceded by `"|"`.
    private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
        var hash: UInt64 = 14_695_981_039_346_656_037
        func mix(_ text: String) {
            for byte in text.utf8 {
                hash ^= UInt64(byte)
                hash &*= 1_099_511_628_211
            }
        }
        switch role {
        case .assistant:
            mix("assistant")
        case .system:
            mix("system")
        case .user:
            mix("user")
        @unknown default:
            mix("unknown")
        }
        mix("|")
        mix(content)
        for imageURL in imageURLs {
            mix("|")
            mix(imageURL)
        }
        return hash
    }
}
private extension Chat.Message.Role {
    /// Stable textual name for the role, so roles can be compared with `XCTAssertEqual`.
    /// Cases not known at compile time map to `"unknown"`.
    var roleLabel: String {
        switch self {
        case .assistant:
            return "assistant"
        case .system:
            return "system"
        case .user:
            return "user"
        @unknown default:
            return "unknown"
        }
    }
}

View File

@@ -0,0 +1,130 @@
import Foundation
import XCTest
import MLXLMCommon
@testable import MLX_Server
/// Unit tests for `TokenPrefixCache`: checkout-on-lookup, deepest-prefix matching,
/// budget eviction, idle expiry, trie pruning, and snapshot statistics.
///
/// All tests inject a fixed `estimateBytesProvider` so entry sizes are deterministic,
/// and store empty KV caches since only the bookkeeping is under test.
final class TokenPrefixCacheTests: XCTestCase {
    /// A hit hands the entry out to the caller and removes it from the cache.
    func testStoreAndLookupRemovesCheckedOutEntry() {
        // The clock never advances in this test, so it is a constant
        // (a `var` here is flagged by the compiler as never mutated).
        let now = Date(timeIntervalSince1970: 100)
        let cache = TokenPrefixCache(
            memoryBudgetBytes: 10_000,
            estimateBytesProvider: { _ in 1_024 },
            nowProvider: { now }
        )
        let entryId = UUID()
        cache.store(entryId: entryId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
        XCTAssertEqual(cache.snapshot().totalEntries, 1)

        let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")

        XCTAssertTrue(lease.isHit)
        XCTAssertEqual(lease.entryId, entryId)
        XCTAssertEqual(lease.matchedTokenCount, 3)
        XCTAssertNotNil(lease.kvCache)
        XCTAssertEqual(cache.snapshot().totalEntries, 0)
    }

    /// When several stored keys are prefixes of the query, the longest one wins.
    func testLookupPrefersDeepestPrefixMatch() {
        var now = Date(timeIntervalSince1970: 100)
        let cache = TokenPrefixCache(
            memoryBudgetBytes: 10_000,
            estimateBytesProvider: { _ in 1_024 },
            nowProvider: { now }
        )
        cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2], modelId: "model")
        now.addTimeInterval(1)
        let deepId = UUID()
        cache.store(entryId: deepId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")

        let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")

        XCTAssertTrue(lease.isHit)
        XCTAssertEqual(lease.entryId, deepId)
        XCTAssertEqual(lease.matchedTokenCount, 3)
    }

    /// With room for two 1 KiB entries, storing a third evicts the oldest one.
    func testEvictsLeastRecentlyUsedEntryWhenOverBudget() {
        var now = Date(timeIntervalSince1970: 100)
        let cache = TokenPrefixCache(
            memoryBudgetBytes: 2_048,
            estimateBytesProvider: { _ in 1_024 },
            nowProvider: { now }
        )
        let firstId = UUID()
        cache.store(entryId: firstId, kvCache: [], cacheKey: [1], modelId: "model")
        now.addTimeInterval(1)
        cache.store(entryId: UUID(), kvCache: [], cacheKey: [2], modelId: "model")
        now.addTimeInterval(1)
        cache.store(entryId: UUID(), kvCache: [], cacheKey: [3], modelId: "model")

        let firstLookup = cache.lookup(cacheKey: [1], modelId: "model")
        let secondLookup = cache.lookup(cacheKey: [2], modelId: "model")
        let thirdLookup = cache.lookup(cacheKey: [3], modelId: "model")

        XCTAssertFalse(firstLookup.isHit)
        XCTAssertTrue(secondLookup.isHit)
        XCTAssertTrue(thirdLookup.isHit)
    }

    /// An entry idle for longer than `idleTTL` is gone by the time a snapshot is taken.
    func testSnapshotPrunesExpiredEntries() {
        var now = Date(timeIntervalSince1970: 100)
        let cache = TokenPrefixCache(
            memoryBudgetBytes: 10_000,
            idleTTL: 5,
            estimateBytesProvider: { _ in 1_024 },
            nowProvider: { now }
        )
        cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
        XCTAssertEqual(cache.snapshot().totalEntries, 1)

        // Advance the injected clock past the 5-second TTL.
        now.addTimeInterval(10)
        let snapshot = cache.snapshot()

        XCTAssertEqual(snapshot.totalEntries, 0)
        XCTAssertGreaterThanOrEqual(snapshot.totalEvictions, 1)
    }

    /// Checking out an entry removes the trie nodes that only it was using.
    func testLookupPrunesTrieNodesForRemovedBranch() {
        let cache = TokenPrefixCache(
            memoryBudgetBytes: 10_000,
            estimateBytesProvider: { _ in 1_024 }
        )
        cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
        cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 4], modelId: "model")
        // root -> 1 -> 2 -> {3, 4}
        XCTAssertEqual(cache.debugTrieNodeCount(), 5)

        _ = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
        // Leaf 3 is gone; the shared prefix 1 -> 2 still serves leaf 4.
        XCTAssertEqual(cache.debugTrieNodeCount(), 4)

        _ = cache.lookup(cacheKey: [1, 2, 4], modelId: "model")
        // Only the root remains.
        XCTAssertEqual(cache.debugTrieNodeCount(), 1)
    }

    /// One hit and one miss yield a 50% hit rate; the checked-out entry no longer
    /// contributes to cached tokens or estimated bytes.
    func testSnapshotReportsHitRateAndTokenTotals() {
        let cache = TokenPrefixCache(
            memoryBudgetBytes: 10_000,
            estimateBytesProvider: { _ in 2_048 }
        )
        cache.store(entryId: UUID(), kvCache: [], cacheKey: [10, 20, 30], modelId: "model")
        _ = cache.lookup(cacheKey: [10, 20, 30, 40], modelId: "model")
        _ = cache.lookup(cacheKey: [99], modelId: "model")

        let snapshot = cache.snapshot()

        XCTAssertEqual(snapshot.totalHits, 1)
        XCTAssertEqual(snapshot.totalMisses, 1)
        XCTAssertEqual(snapshot.hitRate, 50, accuracy: 0.001)
        XCTAssertEqual(snapshot.totalCachedTokens, 0)
        XCTAssertEqual(snapshot.estimatedBytes, 0)
    }
}

View File

@@ -2564,9 +2564,11 @@ Each step should be independently buildable and testable.
### Phase 2: Core Engine
4. **`PromptBuilder.swift`** — Convert API messages to UserInput. Test by comparing tokenized output to what ChatSession produces for the same messages.
5. **`TokenPrefixCache.swift`** — The big one. Build trie + eviction + monitoring. Test: insert entries, verify lookup, verify eviction under memory pressure, verify trie cleanup.
6. **`InferenceEngine.swift`** — Thin wrapper using `container.perform { ctx in MLXLMCommon.generate(input:cache:parameters:context:) }`. Test: run a simple prompt through it, verify output matches ChatSession output.
4. [x] **`PromptBuilder.swift`** — Convert API messages to UserInput. Test by comparing tokenized output to what ChatSession produces for the same messages.
5. [x] **`TokenPrefixCache.swift`** — The big one. Build trie + eviction + monitoring. Test: insert entries, verify lookup, verify eviction under memory pressure, verify trie cleanup.
6. [x] **`InferenceEngine.swift`** — Thin wrapper using `container.perform { ctx in MLXLMCommon.generate(input:cache:parameters:context:) }`. Test: run a simple prompt through it, verify output matches ChatSession output.
Validation note: `PromptBuilder.swift` is now covered by both shaping-parity unit tests and a model-backed tokenization parity test against the cached local Gemma 3 4B VLM. `InferenceEngine.swift` is now covered by a model-backed smoke test that compares one-token output and prompt-token counts against `ChatSession` on the same locally cached Gemma model.
### Phase 3: Integration
@@ -2614,8 +2616,8 @@ Each step should be independently buildable and testable.
### Memory Management
- [ ] Memory budget computed correctly from Metal device
- [ ] Entries evicted under memory pressure (oldest first)
- [ ] Expired entries pruned after 30 min idle
- [x] Entries evicted under memory pressure (oldest first)
- [x] Expired entries pruned after 30 min idle
- [ ] Trie nodes cleaned up when entries are evicted (no memory leak)
- [ ] `snapshot()` reports accurate memory usage and hit rates
@@ -2628,7 +2630,7 @@ Each step should be independently buildable and testable.
### Streaming
- [ ] SSE JSON is valid and parseable by standard clients
- [ ] `StreamingSSEEncoder` output matches `JSONEncoder` output byte-for-byte (for content deltas)
- [x] `StreamingSSEEncoder` output matches `JSONEncoder` output byte-for-byte (for content deltas)
- [ ] Role delta sent once at stream start
- [ ] Tool call chunks sent correctly
- [ ] Final chunk has finish_reason and usage stats