import Foundation
import Hub
import MLXLMCommon
import MLXVLM
import XCTest
@testable import MLX_Server

/// Validation tests that exercise `PromptBuilder`, `InferenceEngine` and
/// `TokenPrefixCache` against a Gemma model container obtained from
/// `LocalGemmaFixture`. The fixture throws `XCTSkip` when the model
/// configuration or the local model directory cannot be resolved.
final class ModelBackedInferenceValidationTests: XCTestCase {
    /// Base64 payload embedded in a `data:image/png` URL by the multimodal test.
    private let onePixelPNGBase64 =
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="

    /// Model identifier passed to `PromptBuilder.build` and `legacyBuild` in every test.
    private let gemmaModelId = "mlx-community/gemma-3-4b-it-4bit"

    // MARK: - Tests

    /// The tokens produced from `PromptBuilder.build` must equal the tokens
    /// produced from the in-file `legacyBuild` for a system + text/image request.
    func testPromptBuilderTokenizationMatchesLegacyShapingOnLocalGemma() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)
        let request = makeRequest(messages: [
            textMessage(role: "system", "You are concise."),
            APIChatMessage(
                role: "user",
                content: .parts([
                    APIContentPart(type: "text", text: "What is in this image?", image_url: nil),
                    APIContentPart(
                        type: "image_url",
                        text: nil,
                        image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil)
                    ),
                ]),
                name: nil,
                tool_calls: nil,
                tool_call_id: nil
            ),
        ])

        let prepared = PromptBuilder.build(from: request, modelId: gemmaModelId, thinkingEnabled: false)
        let legacy = legacyBuild(from: request, modelId: gemmaModelId, thinkingEnabled: false)
        let preparedInference = try await engine.prepare(prepared.userInput)
        let legacyInference = try await engine.prepare(legacy.userInput)

        XCTAssertEqual(preparedInference.tokens, legacyInference.tokens)
    }

    /// Streaming through `InferenceEngine` and through `ChatSession` with the
    /// same prompt and parameters must yield the same text and prompt token count.
    func testInferenceEngineMatchesChatSessionOnLocalGemma() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)
        let parameters = GenerateParameters(maxTokens: 1, temperature: 0)
        let prompt = "Say hello in one word."
        let request = makeRequest(messages: [textMessage(role: "user", prompt)])

        let prepared = PromptBuilder.build(from: request, modelId: gemmaModelId, thinkingEnabled: true)
        let preparedInference = try await engine.prepare(prepared.userInput)
        let handle = try await engine.stream(
            InferenceEngine.InferenceRequest(
                input: preparedInference.lmInput,
                tokens: preparedInference.tokens,
                parameters: parameters,
                cachedKV: nil,
                cachedTokenCount: 0
            ),
            cancellation: CancellationToken()
        )
        let engineResult = await collectEngineOutput(handle.stream)

        let session = ChatSession(container, generateParameters: parameters)
        let sessionResult = try await collectSessionOutput(
            session.streamDetails(to: prompt, images: [], videos: [])
        )

        XCTAssertEqual(engineResult.text, sessionResult.text)
        XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
    }

    /// Two prompts sharing a system message but differing in the user message
    /// must produce a partial (non-zero, shorter than the stored key) cache match.
    func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)
        let first = PromptBuilder.build(
            from: systemUserRequest(system: "You are terse and literal.", user: "Respond with one word for cat."),
            modelId: gemmaModelId,
            thinkingEnabled: true
        )
        let second = PromptBuilder.build(
            from: systemUserRequest(system: "You are terse and literal.", user: "Respond with one word for dog."),
            modelId: gemmaModelId,
            thinkingEnabled: true
        )

        let firstPrepared = try await engine.prepare(first.userInput)
        let secondPrepared = try await engine.prepare(second.userInput)

        let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
        cache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma")
        let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")

        XCTAssertTrue(lease.isHit)
        XCTAssertGreaterThan(lease.matchedTokenCount, 0)
        XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
    }

    /// Same expectation as the token-only variant, but the stored entry carries
    /// the working cache of a real generation, trimmed back to the prompt length.
    func testStoredLiveGemmaCacheSupportsSameSystemDifferentUserLCPReuse() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)
        let first = PromptBuilder.build(
            from: systemUserRequest(system: "You are terse and literal.", user: "Respond with one word for cat."),
            modelId: gemmaModelId,
            thinkingEnabled: true
        )
        let second = PromptBuilder.build(
            from: systemUserRequest(system: "You are terse and literal.", user: "Respond with one word for dog."),
            modelId: gemmaModelId,
            thinkingEnabled: true
        )

        let firstPrepared = try await engine.prepare(first.userInput)
        let secondPrepared = try await engine.prepare(second.userInput)
        let handle = try await engine.stream(
            InferenceEngine.InferenceRequest(
                input: firstPrepared.lmInput,
                tokens: firstPrepared.tokens,
                parameters: GenerateParameters(maxTokens: 2, temperature: 0),
                cachedKV: nil,
                cachedTokenCount: 0
            ),
            cancellation: CancellationToken()
        )
        // Drain the stream before inspecting the working cache.
        _ = await collectEngineOutput(handle.stream)
        trimCacheToPrompt(handle.workingCache, promptTokenCount: firstPrepared.tokens.count)

        let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
        cache.store(entryId: UUID(), kvCache: handle.workingCache, cacheKey: firstPrepared.tokens, modelId: "gemma")
        let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")

        XCTAssertTrue(lease.isHit)
        XCTAssertGreaterThan(lease.matchedTokenCount, 0)
        XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
    }

    /// Prompts with different system messages (and the same user message) must
    /// not be reported as a cache hit.
    ///
    /// NOTE(review): the method name says "CanFalseHit" while the assertion
    /// requires a miss — confirm which behaviour is intended.
    func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)
        let first = PromptBuilder.build(
            from: systemUserRequest(system: "System Alpha Unique Tokens", user: "Answer in one word: tree."),
            modelId: gemmaModelId,
            thinkingEnabled: true
        )
        let second = PromptBuilder.build(
            from: systemUserRequest(system: "Completely Different Beta Markers", user: "Answer in one word: tree."),
            modelId: gemmaModelId,
            thinkingEnabled: true
        )

        let firstPrepared = try await engine.prepare(first.userInput)
        let secondPrepared = try await engine.prepare(second.userInput)

        let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
        cache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma")
        let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")

        XCTAssertFalse(lease.isHit)
    }

    // MARK: - Request helpers

    /// Builds a chat completion request for model `"gemma"` with the given
    /// messages and every optional request field set to `nil`.
    private func makeRequest(messages: [APIChatMessage]) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: messages,
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    /// Builds a plain-text message for `role` with no name or tool metadata.
    private func textMessage(role: String, _ text: String) -> APIChatMessage {
        APIChatMessage(role: role, content: .text(text), name: nil, tool_calls: nil, tool_call_id: nil)
    }

    /// Builds a two-message request: one system message followed by one user message.
    private func systemUserRequest(system: String, user: String) -> APIChatCompletionRequest {
        makeRequest(messages: [
            textMessage(role: "system", system),
            textMessage(role: "user", user),
        ])
    }

    // MARK: - Fixtures

    /// Returns the shared Gemma container from `LocalGemmaFixture`.
    private func localGemmaContainer() async throws -> ModelContainer {
        try await LocalGemmaFixture.shared.container()
    }

    /// Trims every cache layer whose offset exceeds `promptTokenCount` back to
    /// that count, asserting that the layer is trimmable and that the full
    /// excess was removed.
    private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) {
        for layer in cache {
            let excess = layer.offset - promptTokenCount
            guard excess > 0 else { continue }
            XCTAssertTrue(layer.isTrimmable)
            XCTAssertEqual(layer.trim(excess), excess)
        }
    }

    // MARK: - Legacy prompt shaping

    /// Reference prompt construction kept in the test target so that
    /// `PromptBuilder.build` can be compared against it.
    ///
    /// - Parameters:
    ///   - request: The chat completion request to convert.
    ///   - modelId: Model identifier; a case-insensitive "qwen" substring selects
    ///     the Qwen tool formatting, anything else uses the Gemma formatting.
    ///   - thinkingEnabled: When `false`, `["enable_thinking": false]` is passed
    ///     as additional context; when `true`, no additional context is passed.
    /// - Returns: A `PromptBuilder.PreparedPrompt` assembled from the request.
    private func legacyBuild(
        from request: APIChatCompletionRequest,
        modelId: String,
        thinkingEnabled: Bool
    ) -> PromptBuilder.PreparedPrompt {
        // All non-empty system messages are merged into a single instructions string.
        var instructions = request.messages
            .filter { $0.role == "system" }
            .map { $0.content?.textContent ?? "" }
            .filter { !$0.isEmpty }
            .joined(separator: "\n\n")

        if let tools = request.tools, !tools.isEmpty {
            let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId)
            if !instructions.isEmpty {
                instructions += "\n\n"
            }
            instructions += toolSystemPrompt
        }

        let isQwen = modelId.lowercased().contains("qwen")
        var chatMessages: [Chat.Message] = []
        var messageSignatures: [UInt64] = []
        var estimatedBytes = instructions.utf8.count
        var containsImages = false

        for msg in request.messages where msg.role != "system" {
            // Only "assistant" keeps its own role; "tool" and every other role map to `.user`.
            let role: Chat.Message.Role = msg.role == "assistant" ? .assistant : .user

            var text = msg.content?.textContent ?? ""
            if msg.role == "tool", !isQwen {
                text = "```tool_output\n\(text)\n```"
            }
            if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty {
                let formattedCalls = isQwen
                    ? ToolPromptBuilder.formatQwenToolCalls(toolCalls)
                    : ToolPromptBuilder.formatGemmaToolCalls(toolCalls)
                text = (text.isEmpty ? "" : text + "\n") + formattedCalls
            }

            let imageURLs = msg.content?.imageURLs ?? []
            var messageImages: [UserInput.Image] = []
            var messageImageBytes = 0
            // URLs that fail to decode are skipped.
            for urlString in imageURLs {
                guard let decoded = ImageDecoder.decode(urlString) else { continue }
                messageImages.append(decoded.image)
                messageImageBytes += decoded.estimatedBytes
            }

            containsImages = containsImages || !messageImages.isEmpty
            chatMessages.append(Chat.Message(role: role, content: text, images: messageImages))
            messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs))
            estimatedBytes += text.utf8.count + messageImageBytes
        }

        let additionalContext: [String: any Sendable]? = thinkingEnabled ? nil : ["enable_thinking": false]
        let allImages = chatMessages.flatMap(\.images)
        let systemMessages = instructions.isEmpty ? [] : [Chat.Message(role: .system, content: instructions)]
        let allMessages = systemMessages + chatMessages
        let userInput = UserInput(
            prompt: .chat(allMessages),
            images: allImages,
            videos: [],
            tools: nil,
            additionalContext: additionalContext
        )

        let characterCount = instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }
        return PromptBuilder.PreparedPrompt(
            instructions: instructions,
            chatMessages: chatMessages,
            messageSignatures: messageSignatures,
            estimatedBytes: estimatedBytes,
            estimatedPromptTokens: characterCount * 10 / 35,
            containsImages: containsImages,
            additionalContext: additionalContext,
            userInput: userInput
        )
    }

    /// Hashes a message's role, content and image URLs into a 64-bit value by
    /// XOR-ing each UTF-8 byte into the hash and multiplying with wrapping
    /// arithmetic (the constants are the 64-bit FNV offset basis and prime).
    /// Fields are separated by a `"|"` byte.
    private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
        var hash: UInt64 = 14_695_981_039_346_656_037

        func mix(_ text: String) {
            for byte in text.utf8 {
                hash ^= UInt64(byte)
                hash &*= 1_099_511_628_211
            }
        }

        switch role {
        case .assistant: mix("assistant")
        case .system: mix("system")
        case .user: mix("user")
        @unknown default: mix("unknown")
        }
        mix("|")
        mix(content)
        for imageURL in imageURLs {
            mix("|")
            mix(imageURL)
        }
        return hash
    }

    // MARK: - Stream collection

    /// Drains an engine stream, concatenating text chunks and keeping the
    /// prompt token count from the most recent `.info` event.
    ///
    /// NOTE(review): the element type is assumed to be MLXLMCommon's
    /// `Generation` (it matches the `.chunk` / `.info` / `.toolCall` cases
    /// handled here) — confirm against `InferenceEngine.stream`.
    private func collectEngineOutput(_ stream: AsyncStream<Generation>) async -> GenerationResult {
        var text = ""
        var promptTokenCount = 0
        for await generation in stream {
            switch generation {
            case .chunk(let chunk):
                text += chunk
            case .info(let info):
                promptTokenCount = info.promptTokenCount
            case .toolCall:
                break
            }
        }
        return GenerationResult(text: text, promptTokenCount: promptTokenCount)
    }

    /// Drains a throwing session stream with the same accumulation rules as
    /// `collectEngineOutput`, rethrowing any error the stream produces.
    private func collectSessionOutput(
        _ stream: AsyncThrowingStream<Generation, Error>
    ) async throws -> GenerationResult {
        var text = ""
        var promptTokenCount = 0
        for try await generation in stream {
            switch generation {
            case .chunk(let chunk):
                text += chunk
            case .info(let info):
                promptTokenCount = info.promptTokenCount
            case .toolCall:
                break
            }
        }
        return GenerationResult(text: text, promptTokenCount: promptTokenCount)
    }
}

/// Text and prompt token count collected from one generation stream.
private struct GenerationResult {
    let text: String
    let promptTokenCount: Int
}

/// Loads the Gemma container at most once at a time and shares it between tests.
private actor LocalGemmaFixture {
    static let shared = LocalGemmaFixture()

    /// The in-flight or completed load; reset to `nil` when a load fails.
    private var task: Task<ModelContainer, Error>?

    /// Returns the model container, starting a load if none has been started.
    ///
    /// - Throws: `XCTSkip` when `ModelConfig.resolve("gemma")` or
    ///   `LocalModelResolver.resolve(repoId:)` returns `nil`; otherwise any
    ///   error thrown while loading the container.
    func container() async throws -> ModelContainer {
        if let task {
            return try await task.value
        }

        guard let config = ModelConfig.resolve("gemma") else {
            throw XCTSkip("Gemma model config is unavailable")
        }
        guard let localDir = LocalModelResolver.resolve(repoId: config.repoId) else {
            throw XCTSkip("Local gemma cache is unavailable")
        }

        let loadTask = Task {
            let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first
            let hub = HubApi(downloadBase: cachesDir, cache: nil)
            return try await VLMModelFactory.shared.loadContainer(
                hub: hub,
                configuration: ModelConfiguration(directory: localDir),
                progressHandler: { _ in }
            )
        }
        task = loadTask

        do {
            return try await loadTask.value
        } catch {
            // Clear the stored task so a later call starts a fresh load.
            task = nil
            throw error
        }
    }
}