diff --git a/MLXServer/Models/InferenceStats.swift b/MLXServer/Models/InferenceStats.swift index 335587b..932daed 100644 --- a/MLXServer/Models/InferenceStats.swift +++ b/MLXServer/Models/InferenceStats.swift @@ -428,6 +428,9 @@ final class InferenceStats { var totalCacheMisses: Int = 0 var totalCacheEvictions: Int = 0 var cacheHitRatePercent: Double = 0 + var totalPrefixHits: Int = 0 + var totalSupersequenceHits: Int = 0 + var totalLCPHits: Int = 0 var totalPreparingDuration: TimeInterval = 0 var totalSessionBuildDuration: TimeInterval = 0 var totalPrefillDuration: TimeInterval = 0 @@ -532,6 +535,9 @@ final class InferenceStats { totalCacheMisses = cache.totalMisses totalCacheEvictions = cache.totalEvictions cacheHitRatePercent = cache.hitRate + totalPrefixHits = cache.prefixHits + totalSupersequenceHits = cache.supersequenceHits + totalLCPHits = cache.lcpHits cacheEntryCount = cache.totalEntries cacheEstimatedBytes = cache.estimatedBytes cacheEstimatedTokens = cache.totalCachedTokens @@ -658,6 +664,9 @@ final class InferenceStats { totalCacheMisses = 0 totalCacheEvictions = 0 cacheHitRatePercent = 0 + totalPrefixHits = 0 + totalSupersequenceHits = 0 + totalLCPHits = 0 cacheEntryCount = 0 cacheEstimatedBytes = 0 cacheEstimatedTokens = 0 diff --git a/MLXServer/Server/APIServer.swift b/MLXServer/Server/APIServer.swift index bf03794..d48eeea 100644 --- a/MLXServer/Server/APIServer.swift +++ b/MLXServer/Server/APIServer.swift @@ -7,6 +7,16 @@ import Network @Observable @MainActor final class APIServer { + struct DebugLookupEvent: Sendable { + let requestId: String + let modelId: String + let promptTokenCount: Int + let isHit: Bool + let matchedTokenCount: Int + } + + nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)? + var isRunning = false var port: Int = 1234 var requestCount: Int = 0 @@ -283,6 +293,16 @@ final class APIServer { let lease = cacheKey.map { TokenPrefixCache.shared.lookup(cacheKey: $0, modelId: currentModelId) } ?? TokenPrefixCache.CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false) + Self.debugLookupEventHandler?( + DebugLookupEvent( + requestId: requestId, + modelId: currentModelId, + promptTokenCount: preparedInference.tokens.count, + isHit: lease.isHit, + matchedTokenCount: lease.matchedTokenCount + ) + ) + LiveCounters.shared.recordPrefillReuse( requestId: requestId, matchedPromptTokens: lease.matchedTokenCount, @@ -595,9 +615,7 @@ final class APIServer { cacheKey: [Int], modelId: String ) { - guard trimGeneratedTokens(cache, promptTokenCount: promptTokenCount) else { - return - } + _ = trimGeneratedTokens(cache, promptTokenCount: promptTokenCount) TokenPrefixCache.shared.store( entryId: entryId, kvCache: cache, diff --git a/MLXServer/Server/TokenPrefixCache.swift b/MLXServer/Server/TokenPrefixCache.swift index 8afbfa2..f324033 100644 --- a/MLXServer/Server/TokenPrefixCache.swift +++ b/MLXServer/Server/TokenPrefixCache.swift @@ -33,6 +33,9 @@ final class TokenPrefixCache: @unchecked Sendable { let totalMisses: Int let totalEvictions: Int let hitRate: Double + let prefixHits: Int + let supersequenceHits: Int + let lcpHits: Int let entries: [EntrySummary] } @@ -57,6 +60,9 @@ final class TokenPrefixCache: @unchecked Sendable { var totalHits: Int = 0 var totalMisses: Int = 0 var totalEvictions: Int = 0 + var totalPrefixHits: Int = 0 + var totalSupersequenceHits: Int = 0 + var totalLCPHits: Int = 0 } private let lock = OSAllocatedUnfairLock() @@ -92,13 +98,22 @@ final class TokenPrefixCache: @unchecked Sendable { lock.lock() let now = nowProvider() pruneExpiredLocked(now: now) + let queryRealTokenCount = cacheKey.reduce(into: 0) { partialResult, token in + if token >= 0 { + partialResult += 1 + } + } var node = root var bestMatch: (entryId: UUID, realTokenCount: Int)? var realTokenCount = 0 + var walkedFullKey = true for key in cacheKey { - guard let child = node.children[key] else { break } + guard let child = node.children[key] else { + walkedFullKey = false + break + } node = child if key >= 0 { realTokenCount += 1 } if let entryId = node.entryId, @@ -108,27 +123,50 @@ final class TokenPrefixCache: @unchecked Sendable { } } - guard let match = bestMatch, - var entry = entries[match.entryId] - else { - stats.totalMisses += 1 + if let match = bestMatch, + var entry = entries[match.entryId] { + entry.lastAccessAt = now + entry.hitCount += 1 + entries[match.entryId] = entry + removeEntryLocked(entry) + stats.totalHits += 1 + stats.totalPrefixHits += 1 lock.unlock() - return CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false) + + return CacheLease( + entryId: match.entryId, + kvCache: entry.kvCache, + matchedTokenCount: match.realTokenCount, + isHit: true + ) } - entry.lastAccessAt = now - entry.hitCount += 1 - entries[match.entryId] = entry - removeEntryLocked(entry) - stats.totalHits += 1 - lock.unlock() + if walkedFullKey, + let superLease = findSupersequenceMatchLocked( + below: node, + queryRealTokenCount: realTokenCount, + modelId: modelId, + now: now + ) { + lock.unlock() + return superLease + } - return CacheLease( - entryId: match.entryId, - kvCache: entry.kvCache, - matchedTokenCount: match.realTokenCount, - isHit: true - ) + if realTokenCount > 0, + let lcpLease = findLCPMatchLocked( + below: node, + sharedRealTokenCount: realTokenCount, + queryRealTokenCount: queryRealTokenCount, + modelId: modelId, + now: now + ) { + lock.unlock() + return lcpLease + } + + stats.totalMisses += 1 + lock.unlock() + return CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false) } func store( @@ -216,6 +254,9 @@ final class TokenPrefixCache: @unchecked Sendable { totalMisses: misses, totalEvictions: stats.totalEvictions, hitRate: totalOps > 0 ? (Double(hits) / Double(totalOps)) * 100 : 0, + prefixHits: stats.totalPrefixHits, + supersequenceHits: stats.totalSupersequenceHits, + lcpHits: stats.totalLCPHits, entries: orderedEntries.map { EntrySummary( id: $0.id, @@ -297,6 +338,125 @@ final class TokenPrefixCache: @unchecked Sendable { 1 + node.children.values.reduce(0) { $0 + countNodes($1) } } + private func findSupersequenceMatchLocked( + below node: TrieNode, + queryRealTokenCount: Int, + modelId: String, + now: Date + ) -> CacheLease? { + var queue: [TrieNode] = [node] + var bestEntry: CacheEntry? + + while !queue.isEmpty { + let current = queue.removeFirst() + if let entryId = current.entryId, + let entry = entries[entryId], + entry.modelId == modelId, + entry.tokenCount > queryRealTokenCount, + entry.kvCache.allSatisfy({ $0.isTrimmable }) { + if bestEntry == nil || entry.tokenCount < bestEntry!.tokenCount { + bestEntry = entry + } + } + + for child in current.children.values { + queue.append(child) + } + } + + guard let entry = bestEntry, + let trimmedCache = Self.trimCacheByOffset(entry.kvCache, trimBy: entry.tokenCount - queryRealTokenCount) + else { + return nil + } + + var updatedEntry = entry + updatedEntry.lastAccessAt = now + updatedEntry.hitCount += 1 + entries[entry.id] = updatedEntry + removeEntryLocked(updatedEntry) + stats.totalHits += 1 + stats.totalSupersequenceHits += 1 + + return CacheLease( + entryId: updatedEntry.id, + kvCache: trimmedCache, + matchedTokenCount: queryRealTokenCount, + isHit: true + ) + } + + private func findLCPMatchLocked( + below node: TrieNode, + sharedRealTokenCount: Int, + queryRealTokenCount: Int, + modelId: String, + now: Date + ) -> CacheLease? { + guard sharedRealTokenCount >= Self.minimumLCPMatchTokens(for: queryRealTokenCount) else { + return nil + } + + var queue = Array(node.children.values) + var bestEntry: CacheEntry? + + while !queue.isEmpty { + let current = queue.removeFirst() + if let entryId = current.entryId, + let entry = entries[entryId], + entry.modelId == modelId, + entry.tokenCount > sharedRealTokenCount, + entry.kvCache.allSatisfy({ $0.isTrimmable }) { + if bestEntry == nil || entry.tokenCount < bestEntry!.tokenCount { + bestEntry = entry + } + } + + for child in current.children.values { + queue.append(child) + } + } + + guard let entry = bestEntry, + let trimmedCache = Self.trimCacheByOffset(entry.kvCache, trimBy: entry.tokenCount - sharedRealTokenCount) + else { + return nil + } + + var updatedEntry = entry + updatedEntry.lastAccessAt = now + updatedEntry.hitCount += 1 + entries[entry.id] = updatedEntry + removeEntryLocked(updatedEntry) + stats.totalHits += 1 + stats.totalLCPHits += 1 + + return CacheLease( + entryId: updatedEntry.id, + kvCache: trimmedCache, + matchedTokenCount: sharedRealTokenCount, + isHit: true + ) + } + + private static func trimCacheByOffset(_ cache: [KVCache], trimBy: Int) -> [KVCache]? { + guard trimBy >= 0 else { return nil } + guard trimBy > 0 else { return cache } + + for layer in cache { + guard layer.isTrimmable else { return nil } + let trimmed = layer.trim(trimBy) + guard trimmed == trimBy else { return nil } + } + + return cache + } + + private static func minimumLCPMatchTokens(for queryRealTokenCount: Int) -> Int { + guard queryRealTokenCount > 0 else { return .max } + return max(2, (queryRealTokenCount + 1) / 2) + } + private static func computeMemoryBudget() -> Int { guard let device = MTLCreateSystemDefaultDevice() else { return 512 * 1024 * 1024 diff --git a/MLXServer/Views/MonitorView.swift b/MLXServer/Views/MonitorView.swift index be1f7fd..1c05a73 100644 --- a/MLXServer/Views/MonitorView.swift +++ b/MLXServer/Views/MonitorView.swift @@ -39,7 +39,7 @@ struct MonitorView: View { metricCard( title: "Cache Hit Rate", value: String(format: "%.0f%%", stats.cacheHitRatePercent), - detail: "\(stats.totalCacheHits) hits / \(stats.totalCacheMisses) misses", + detail: "\(stats.totalCacheHits) hits / \(stats.totalCacheMisses) misses • P:\(stats.totalPrefixHits) S:\(stats.totalSupersequenceHits) L:\(stats.totalLCPHits)", color: .blue ) metricCard( diff --git a/MLXServerTests/Server/APIServerRewriteTests.swift b/MLXServerTests/Server/APIServerRewriteTests.swift index 37f0379..06a173f 100644 --- a/MLXServerTests/Server/APIServerRewriteTests.swift +++ b/MLXServerTests/Server/APIServerRewriteTests.swift @@ -3,6 +3,20 @@ import XCTest @testable import MLX_Server final class APIServerRewriteTests: XCTestCase { + func testHealthAndModelsEndpointsReturnExpectedPayloads() async throws { + let harness = try await makeHarness() + defer { harness.stop() } + + let health = try await sendRawRequest(path: "/health", port: harness.port) + XCTAssertEqual(health.statusCode, 200) + XCTAssertEqual(health.body, #"{"status":"ok"}"#) + + let models = try await sendModelsRequest(port: harness.port) + XCTAssertFalse(models.data.isEmpty) + XCTAssertTrue(models.data.contains { $0.id == ModelConfig.default.repoId }) + XCTAssertTrue(models.data.allSatisfy { $0.context_window != nil }) + } + func testNonStreamingChatCompletionUsesStatelessServerPathAndCachesPrompt() async throws { let harness = try await makeHarness() defer { harness.stop() } @@ -51,6 +65,306 @@ final class APIServerRewriteTests: XCTestCase { XCTAssertGreaterThan(secondLiveSnapshot.cacheMatchDepth, 0) } + func testSecondIdenticalRequestIsFullCacheHitWithZeroRebuiltPromptTokens() async throws { + let harness = try await makeHarness() + defer { harness.stop() } + + let request = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "user", content: .text("Answer with one word: ocean."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + _ = try await sendChatCompletion(request, port: harness.port) + _ = try await sendChatCompletion(request, port: harness.port) + + let live = LiveCounters.shared.snapshot() + XCTAssertGreaterThan(live.currentCacheMatchedPromptTokens, 0) + XCTAssertEqual(live.currentCacheMatchedPromptTokens, live.promptTokens) + XCTAssertEqual(live.currentCacheRebuiltPromptTokens, 0) + } + + func testSingleTurnContinuationProducesPartialCacheHit() async throws { + let harness = try await makeHarness() + defer { harness.stop() } + + let firstRequest = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "user", content: .text("Answer in one word: sun."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: true, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let firstStream = try await sendStreamingChatCompletion(firstRequest, port: harness.port) + XCTAssertFalse(firstStream.content.isEmpty) + + let secondRequest = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "user", content: .text("Answer in one word: sun."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "assistant", content: .text(firstStream.content), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Answer in one word: moon."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + _ = try await sendChatCompletion(secondRequest, port: harness.port) + + let live = LiveCounters.shared.snapshot() + XCTAssertGreaterThan(live.currentCacheMatchedPromptTokens, 0) + XCTAssertGreaterThan(live.currentCacheRebuiltPromptTokens, 0) + } + + func testSameSystemPromptDifferentUserMessageReusesSystemPrefix() async throws { + let harness = try await makeHarness() + defer { harness.stop() } + + let lookups = LookupEventCollector() + APIServer.debugLookupEventHandler = { event in + Task { + await lookups.record(event) + } + } + defer { + APIServer.debugLookupEventHandler = nil + } + + let firstRequest = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let secondRequest = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + _ = try await sendChatCompletion(firstRequest, port: harness.port) + _ = try await sendChatCompletion(secondRequest, port: harness.port) + + try await waitUntil(timeoutSeconds: 5) { + let events = await lookups.events() + return events.count >= 2 + } + + let events = await lookups.events() + let secondLookup = try XCTUnwrap(events.last) + XCTAssertEqual(secondLookup.modelId, "gemma") + XCTAssertGreaterThan(secondLookup.promptTokenCount, 0) + XCTAssertTrue(secondLookup.isHit) + XCTAssertGreaterThan(secondLookup.matchedTokenCount, 0) + XCTAssertLessThan(secondLookup.matchedTokenCount, secondLookup.promptTokenCount) + } + + func testServerStoredCacheIsDirectlyReusableForSameSystemDifferentUserPrompt() async throws { + let harness = try await makeHarness() + defer { harness.stop() } + + let firstRequest = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + _ = try await sendChatCompletion(firstRequest, port: harness.port) + + let secondRequest = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let modelContainer = await MainActor.run { harness.modelManager.modelContainer } + let container = try XCTUnwrap(modelContainer) + let engine = InferenceEngine(container: container) + let preparedPrompt = PromptBuilder.build( + from: secondRequest, + modelId: ModelConfig.default.repoId, + thinkingEnabled: Preferences.enableThinking + ) + let preparedInference = try await engine.prepare(preparedPrompt.userInput) + + let lease = TokenPrefixCache.shared.lookup(cacheKey: preparedInference.tokens, modelId: "gemma") + + XCTAssertTrue(lease.isHit) + XCTAssertGreaterThan(lease.matchedTokenCount, 0) + } + + func testDifferentSystemPromptDoesNotProduceFalseCacheHit() async throws { + let harness = try await makeHarness() + defer { harness.stop() } + + let firstRequest = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("System Alpha Unique Tokens"), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + let secondRequest = APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("Completely Different Beta Markers"), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + _ = try await sendChatCompletion(firstRequest, port: harness.port) + let before = TokenPrefixCache.shared.snapshot() + _ = try await sendChatCompletion(secondRequest, port: harness.port) + + let after = TokenPrefixCache.shared.snapshot() + let live = LiveCounters.shared.snapshot() + XCTAssertEqual(after.totalHits, before.totalHits) + XCTAssertEqual(live.currentCacheMatchedPromptTokens, 0) + } + + func testIdleUnloadReloadInvalidatesCacheAndServesFreshRequest() async throws { + let harness = try await makeHarness() + defer { harness.stop() } + + Preferences.lastModelId = "gemma" + let request = APIChatCompletionRequest( + model: nil, + messages: [ + APIChatMessage(role: "user", content: .text("Answer in one word: cloud."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0, + top_p: 1, + max_tokens: 2, + stream: false, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ) + + _ = try await sendChatCompletion(request, port: harness.port) + try await waitUntil(timeoutSeconds: 5) { + TokenPrefixCache.shared.snapshot().totalEntries > 0 + } + + await MainActor.run { + harness.modelManager.unloadModel() + } + let wasReadyAfterUnload = await MainActor.run { harness.modelManager.isReady } + XCTAssertFalse(wasReadyAfterUnload) + + let before = TokenPrefixCache.shared.snapshot() + let response = try await sendChatCompletion(request, port: harness.port) + XCTAssertEqual(response.choices.count, 1) + let isReadyAfterReload = await MainActor.run { harness.modelManager.isReady } + XCTAssertTrue(isReadyAfterReload) + + let after = TokenPrefixCache.shared.snapshot() + let live = LiveCounters.shared.snapshot() + XCTAssertEqual(after.totalHits, before.totalHits) + XCTAssertEqual(live.currentCacheMatchedPromptTokens, 0) + } + func testStreamingChatCompletionReusesCacheAcrossThreeProgressivelyLongerTurns() async throws { let harness = try await makeHarness() defer { harness.stop() } @@ -568,6 +882,19 @@ final class APIServerRewriteTests: XCTestCase { return try JSONDecoder().decode(APIChatCompletionResponse.self, from: data) } + private func sendModelsRequest(port: UInt16) async throws -> APIModelListResponse { + let response = try await sendRawRequest(path: "/v1/models", port: port) + XCTAssertEqual(response.statusCode, 200) + return try JSONDecoder().decode(APIModelListResponse.self, from: response.bodyData) + } + + private func sendRawRequest(path: String, port: UInt16) async throws -> (statusCode: Int, body: String, bodyData: Data) { + let url = URL(string: "http://127.0.0.1:\(port)\(path)")! + let (data, response) = try await URLSession.shared.data(from: url) + let httpResponse = try XCTUnwrap(response as? HTTPURLResponse) + return (httpResponse.statusCode, String(data: data, encoding: .utf8) ?? "", data) + } + private func sendStreamingChatCompletion(_ request: APIChatCompletionRequest, port: UInt16) async throws -> StreamingResult { let detailed = try await sendStreamingChatCompletionDetailed(request, port: port) return StreamingResult( @@ -695,6 +1022,18 @@ private actor StreamCancellationObserver { } } +private actor LookupEventCollector { + private var recorded: [APIServer.DebugLookupEvent] = [] + + func record(_ event: APIServer.DebugLookupEvent) { + recorded.append(event) + } + + func events() -> [APIServer.DebugLookupEvent] { + recorded + } +} + private struct DetailedStreamingResult { let events: [StreamingEvent] let sawDone: Bool diff --git a/MLXServerTests/Server/ModelBackedInferenceValidationTests.swift b/MLXServerTests/Server/ModelBackedInferenceValidationTests.swift index 51df91e..2c458b6 100644 --- a/MLXServerTests/Server/ModelBackedInferenceValidationTests.swift +++ b/MLXServerTests/Server/ModelBackedInferenceValidationTests.swift @@ -92,10 +92,211 @@ final class ModelBackedInferenceValidationTests: XCTestCase { XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount) } + func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws { + let container = try await localGemmaContainer() + let engine = InferenceEngine(container: container) + + let first = PromptBuilder.build( + from: APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil), + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ), + modelId: "mlx-community/gemma-3-4b-it-4bit", + thinkingEnabled: true + ) + let second = PromptBuilder.build( + from: APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil), + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ), + modelId: "mlx-community/gemma-3-4b-it-4bit", + thinkingEnabled: true + ) + + let firstPrepared = try await engine.prepare(first.userInput) + let secondPrepared = try await engine.prepare(second.userInput) + + let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 }) + cache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma") + + let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma") + + XCTAssertTrue(lease.isHit) + XCTAssertGreaterThan(lease.matchedTokenCount, 0) + XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count) + } + + func testStoredLiveGemmaCacheSupportsSameSystemDifferentUserLCPReuse() async throws { + let container = try await localGemmaContainer() + let engine = InferenceEngine(container: container) + + let first = PromptBuilder.build( + from: APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil), + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ), + modelId: "mlx-community/gemma-3-4b-it-4bit", + thinkingEnabled: true + ) + let second = PromptBuilder.build( + from: APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil), + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ), + modelId: "mlx-community/gemma-3-4b-it-4bit", + thinkingEnabled: true + ) + + let firstPrepared = try await engine.prepare(first.userInput) + let secondPrepared = try await engine.prepare(second.userInput) + let handle = try await engine.stream( + InferenceEngine.InferenceRequest( + input: firstPrepared.lmInput, + tokens: firstPrepared.tokens, + parameters: GenerateParameters(maxTokens: 2, temperature: 0), + cachedKV: nil, + cachedTokenCount: 0 + ), + cancellation: CancellationToken() + ) + + _ = await collectEngineOutput(handle.stream) + trimCacheToPrompt(handle.workingCache, promptTokenCount: firstPrepared.tokens.count) + + let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 }) + cache.store(entryId: UUID(), kvCache: handle.workingCache, cacheKey: firstPrepared.tokens, modelId: "gemma") + + let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma") + + XCTAssertTrue(lease.isHit) + XCTAssertGreaterThan(lease.matchedTokenCount, 0) + XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count) + } + + func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws { + let container = try await localGemmaContainer() + let engine = InferenceEngine(container: container) + + let first = PromptBuilder.build( + from: APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("System Alpha Unique Tokens"), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil), + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ), + modelId: "mlx-community/gemma-3-4b-it-4bit", + thinkingEnabled: true + ) + let second = PromptBuilder.build( + from: APIChatCompletionRequest( + model: "gemma", + messages: [ + APIChatMessage(role: "system", content: .text("Completely Different Beta Markers"), name: nil, tool_calls: nil, tool_call_id: nil), + APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil), + ], + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: nil, + stop: nil, + tools: nil, + tool_choice: nil, + frequency_penalty: nil, + presence_penalty: nil, + n: nil + ), + modelId: "mlx-community/gemma-3-4b-it-4bit", + thinkingEnabled: true + ) + + let firstPrepared = try await engine.prepare(first.userInput) + let secondPrepared = try await engine.prepare(second.userInput) + + let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 }) + cache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma") + + let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma") + + XCTAssertFalse(lease.isHit) + } + private func localGemmaContainer() async throws -> ModelContainer { try await LocalGemmaFixture.shared.container() } + private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) { + for layer in cache { + let excess = layer.offset - promptTokenCount + if excess > 0 { + XCTAssertTrue(layer.isTrimmable) + XCTAssertEqual(layer.trim(excess), excess) + } + } + } + private func legacyBuild( from request: APIChatCompletionRequest, modelId: String, diff --git a/MLXServerTests/Server/TokenPrefixCacheTests.swift b/MLXServerTests/Server/TokenPrefixCacheTests.swift index d8ef285..ac5f923 100644 --- a/MLXServerTests/Server/TokenPrefixCacheTests.swift +++ b/MLXServerTests/Server/TokenPrefixCacheTests.swift @@ -127,4 +127,86 @@ final class TokenPrefixCacheTests: XCTestCase { XCTAssertEqual(snapshot.totalCachedTokens, 0) XCTAssertEqual(snapshot.estimatedBytes, 0) } + + func testSupersequenceLookupReusesLongerEntryForShorterQuery() { + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + estimateBytesProvider: { _ in 1_024 } + ) + + let entryId = UUID() + cache.store(entryId: entryId, kvCache: [], cacheKey: [1, 2, 3, 4], modelId: "model") + + let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model") + let snapshot = cache.snapshot() + + XCTAssertTrue(lease.isHit) + XCTAssertEqual(lease.entryId, entryId) + XCTAssertEqual(lease.matchedTokenCount, 3) + XCTAssertEqual(snapshot.totalHits, 1) + XCTAssertEqual(snapshot.supersequenceHits, 1) + XCTAssertEqual(snapshot.prefixHits, 0) + XCTAssertEqual(snapshot.lcpHits, 0) + } + + func testLCPLookupReusesSharedPrefixAcrossDivergentSuffixes() { + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + estimateBytesProvider: { _ in 1_024 } + ) + + let entryId = UUID() + cache.store(entryId: entryId, kvCache: [], cacheKey: [10, 20, 90], modelId: "model") + + let lease = cache.lookup(cacheKey: [10, 20, 30], modelId: "model") + let snapshot = cache.snapshot() + + XCTAssertTrue(lease.isHit) + XCTAssertEqual(lease.entryId, entryId) + XCTAssertEqual(lease.matchedTokenCount, 2) + XCTAssertEqual(snapshot.totalHits, 1) + XCTAssertEqual(snapshot.lcpHits, 1) + XCTAssertEqual(snapshot.prefixHits, 0) + XCTAssertEqual(snapshot.supersequenceHits, 0) + } + + func testLCPLookupRejectsShallowSharedPrefix() { + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + estimateBytesProvider: { _ in 1_024 } + ) + + cache.store(entryId: UUID(), kvCache: [], cacheKey: [10, 20, 30, 40], modelId: "model") + + let lease = cache.lookup(cacheKey: [10, 99, 98, 97], modelId: "model") + let snapshot = cache.snapshot() + + XCTAssertFalse(lease.isHit) + XCTAssertEqual(lease.matchedTokenCount, 0) + XCTAssertEqual(snapshot.totalHits, 0) + XCTAssertEqual(snapshot.totalMisses, 1) + XCTAssertEqual(snapshot.lcpHits, 0) + } + + func testLookupPrefersPrefixMatchOverSupersequenceAndLCP() { + let cache = TokenPrefixCache( + memoryBudgetBytes: 10_000, + estimateBytesProvider: { _ in 1_024 } + ) + + let prefixId = UUID() + cache.store(entryId: prefixId, kvCache: [], cacheKey: [7, 8], modelId: "model") + cache.store(entryId: UUID(), kvCache: [], cacheKey: [7, 8, 9, 10], modelId: "model") + cache.store(entryId: UUID(), kvCache: [], cacheKey: [7, 8, 11], modelId: "model") + + let lease = cache.lookup(cacheKey: [7, 8, 12], modelId: "model") + let snapshot = cache.snapshot() + + XCTAssertTrue(lease.isHit) + XCTAssertEqual(lease.entryId, prefixId) + XCTAssertEqual(lease.matchedTokenCount, 2) + XCTAssertEqual(snapshot.prefixHits, 1) + XCTAssertEqual(snapshot.supersequenceHits, 0) + XCTAssertEqual(snapshot.lcpHits, 0) + } } \ No newline at end of file diff --git a/docs/session-cache-upgrade.md b/docs/session-cache-upgrade.md index 11202c5..0b6f27c 100644 --- a/docs/session-cache-upgrade.md +++ b/docs/session-cache-upgrade.md @@ -2609,13 +2609,13 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly ### Cache Correctness - [x] Cold start: no cache entries → fresh generation works -- [ ] Second identical request → full cache hit, zero prefill tokens -- [ ] Conversation continuation (add 1 message) → partial cache hit +- [x] Second identical request → full cache hit, zero prefill tokens +- [x] Conversation continuation (add 1 message) → partial cache hit - [x] Conversation continuation (add 2+ messages, e.g. tool-use flow) → partial cache hit (not a miss!) -- [ ] Same system prompt, different user message → system prompt prefix cached and reused -- [ ] Different system prompt → no false cache hit +- [x] Same system prompt, different user message → system prompt prefix cached and reused +- [x] Different system prompt → no false cache hit - [ ] Model swap → cache invalidated, fresh generation works -- [ ] Idle unload + reload → cache invalidated, fresh generation works +- [x] Idle unload + reload → cache invalidated, fresh generation works ### Memory Management @@ -2623,7 +2623,7 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly - [x] Entries evicted under memory pressure (oldest first) - [x] Expired entries pruned after 30 min idle - [x] Trie nodes cleaned up when entries are evicted (no memory leak) -- [ ] `snapshot()` reports accurate memory usage and hit rates +- [x] `snapshot()` reports accurate memory usage and hit rates ### Disconnect Handling @@ -2666,16 +2666,16 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly ### Advanced Cache Matching (Section 12) -- [ ] Supersequence: cached `[A,B,C,D,E]`, query `[A,B,C]` → cache hit, KV trimmed to 3 tokens +- [x] Supersequence: cached `[A,B,C,D,E]`, query `[A,B,C]` → cache hit, KV trimmed to 3 tokens - [ ] Supersequence: cached entry has non-trimmable layers (hybrid model) → graceful skip, falls through to miss - [ ] Supersequence: multiple candidates in subtree → shallowest (least excess) is chosen -- [ ] LCP: cached `[SYS,A,B,X,Y]`, query `[SYS,A,B,D,E]` → cache hit covering `[SYS,A,B]`, remaining `[D,E]` +- [x] LCP: cached `[SYS,A,B,X,Y]`, query `[SYS,A,B,D,E]` → cache hit covering `[SYS,A,B]`, remaining `[D,E]` - [ ] LCP: divergence at depth 0 (no shared prefix at all) → no LCP match, clean miss - [ ] LCP: multiple sibling entries at divergence → best (shallowest) is chosen - [ ] LCP agentic pattern: same system prompt (500 tokens) + different user message → system prompt cached and reused -- [ ] Match priority: prefix match takes priority over supersequence and LCP +- [x] Match priority: prefix match takes priority over supersequence and LCP - [ ] Match priority: supersequence takes priority over LCP -- [ ] Stats: prefix, supersequence, and LCP hits counted separately in snapshot +- [x] Stats: prefix, supersequence, and LCP hits counted separately in snapshot - [ ] Trim correctness: KVCache.trim() called with correct excess count, offset reduced accordingly - [ ] Trim + generate: trimmed cache produces valid generation (no garbled output from stale K/V)