fix: more hardening of cache behaviour and some fixes
This commit is contained in:
@@ -428,6 +428,9 @@ final class InferenceStats {
|
|||||||
var totalCacheMisses: Int = 0
|
var totalCacheMisses: Int = 0
|
||||||
var totalCacheEvictions: Int = 0
|
var totalCacheEvictions: Int = 0
|
||||||
var cacheHitRatePercent: Double = 0
|
var cacheHitRatePercent: Double = 0
|
||||||
|
var totalPrefixHits: Int = 0
|
||||||
|
var totalSupersequenceHits: Int = 0
|
||||||
|
var totalLCPHits: Int = 0
|
||||||
var totalPreparingDuration: TimeInterval = 0
|
var totalPreparingDuration: TimeInterval = 0
|
||||||
var totalSessionBuildDuration: TimeInterval = 0
|
var totalSessionBuildDuration: TimeInterval = 0
|
||||||
var totalPrefillDuration: TimeInterval = 0
|
var totalPrefillDuration: TimeInterval = 0
|
||||||
@@ -532,6 +535,9 @@ final class InferenceStats {
|
|||||||
totalCacheMisses = cache.totalMisses
|
totalCacheMisses = cache.totalMisses
|
||||||
totalCacheEvictions = cache.totalEvictions
|
totalCacheEvictions = cache.totalEvictions
|
||||||
cacheHitRatePercent = cache.hitRate
|
cacheHitRatePercent = cache.hitRate
|
||||||
|
totalPrefixHits = cache.prefixHits
|
||||||
|
totalSupersequenceHits = cache.supersequenceHits
|
||||||
|
totalLCPHits = cache.lcpHits
|
||||||
cacheEntryCount = cache.totalEntries
|
cacheEntryCount = cache.totalEntries
|
||||||
cacheEstimatedBytes = cache.estimatedBytes
|
cacheEstimatedBytes = cache.estimatedBytes
|
||||||
cacheEstimatedTokens = cache.totalCachedTokens
|
cacheEstimatedTokens = cache.totalCachedTokens
|
||||||
@@ -658,6 +664,9 @@ final class InferenceStats {
|
|||||||
totalCacheMisses = 0
|
totalCacheMisses = 0
|
||||||
totalCacheEvictions = 0
|
totalCacheEvictions = 0
|
||||||
cacheHitRatePercent = 0
|
cacheHitRatePercent = 0
|
||||||
|
totalPrefixHits = 0
|
||||||
|
totalSupersequenceHits = 0
|
||||||
|
totalLCPHits = 0
|
||||||
cacheEntryCount = 0
|
cacheEntryCount = 0
|
||||||
cacheEstimatedBytes = 0
|
cacheEstimatedBytes = 0
|
||||||
cacheEstimatedTokens = 0
|
cacheEstimatedTokens = 0
|
||||||
|
|||||||
@@ -7,6 +7,16 @@ import Network
|
|||||||
@Observable
|
@Observable
|
||||||
@MainActor
|
@MainActor
|
||||||
final class APIServer {
|
final class APIServer {
|
||||||
|
/// Value describing the outcome of one prefix-cache lookup, delivered to
/// `debugLookupEventHandler`.
struct DebugLookupEvent: Sendable {
    /// Identifier of the request that triggered the lookup.
    let requestId: String

    /// Model identifier the lookup was performed for.
    let modelId: String

    /// Number of tokens in the prepared prompt.
    let promptTokenCount: Int

    /// Whether the lookup produced a cache hit.
    let isHit: Bool

    /// Number of prompt tokens covered by the returned lease (0 on a miss).
    let matchedTokenCount: Int
}
|
||||||
|
|
||||||
|
/// Optional debugging hook that receives a `DebugLookupEvent` for cache lookups.
///
/// Declared `nonisolated(unsafe)`, so the compiler performs no data-race checking on
/// reads or writes of this property; callers are responsible for not mutating it while
/// another thread may be reading it.
nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)?
|
||||||
|
|
||||||
var isRunning = false
|
var isRunning = false
|
||||||
var port: Int = 1234
|
var port: Int = 1234
|
||||||
var requestCount: Int = 0
|
var requestCount: Int = 0
|
||||||
@@ -283,6 +293,16 @@ final class APIServer {
|
|||||||
let lease = cacheKey.map { TokenPrefixCache.shared.lookup(cacheKey: $0, modelId: currentModelId) }
|
let lease = cacheKey.map { TokenPrefixCache.shared.lookup(cacheKey: $0, modelId: currentModelId) }
|
||||||
?? TokenPrefixCache.CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false)
|
?? TokenPrefixCache.CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false)
|
||||||
|
|
||||||
|
Self.debugLookupEventHandler?(
|
||||||
|
DebugLookupEvent(
|
||||||
|
requestId: requestId,
|
||||||
|
modelId: currentModelId,
|
||||||
|
promptTokenCount: preparedInference.tokens.count,
|
||||||
|
isHit: lease.isHit,
|
||||||
|
matchedTokenCount: lease.matchedTokenCount
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
LiveCounters.shared.recordPrefillReuse(
|
LiveCounters.shared.recordPrefillReuse(
|
||||||
requestId: requestId,
|
requestId: requestId,
|
||||||
matchedPromptTokens: lease.matchedTokenCount,
|
matchedPromptTokens: lease.matchedTokenCount,
|
||||||
@@ -595,9 +615,7 @@ final class APIServer {
|
|||||||
cacheKey: [Int],
|
cacheKey: [Int],
|
||||||
modelId: String
|
modelId: String
|
||||||
) {
|
) {
|
||||||
guard trimGeneratedTokens(cache, promptTokenCount: promptTokenCount) else {
|
_ = trimGeneratedTokens(cache, promptTokenCount: promptTokenCount)
|
||||||
return
|
|
||||||
}
|
|
||||||
TokenPrefixCache.shared.store(
|
TokenPrefixCache.shared.store(
|
||||||
entryId: entryId,
|
entryId: entryId,
|
||||||
kvCache: cache,
|
kvCache: cache,
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
let totalMisses: Int
|
let totalMisses: Int
|
||||||
let totalEvictions: Int
|
let totalEvictions: Int
|
||||||
let hitRate: Double
|
let hitRate: Double
|
||||||
|
let prefixHits: Int
|
||||||
|
let supersequenceHits: Int
|
||||||
|
let lcpHits: Int
|
||||||
let entries: [EntrySummary]
|
let entries: [EntrySummary]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,6 +60,9 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
var totalHits: Int = 0
|
var totalHits: Int = 0
|
||||||
var totalMisses: Int = 0
|
var totalMisses: Int = 0
|
||||||
var totalEvictions: Int = 0
|
var totalEvictions: Int = 0
|
||||||
|
var totalPrefixHits: Int = 0
|
||||||
|
var totalSupersequenceHits: Int = 0
|
||||||
|
var totalLCPHits: Int = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
private let lock = OSAllocatedUnfairLock()
|
private let lock = OSAllocatedUnfairLock()
|
||||||
@@ -92,13 +98,22 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
lock.lock()
|
lock.lock()
|
||||||
let now = nowProvider()
|
let now = nowProvider()
|
||||||
pruneExpiredLocked(now: now)
|
pruneExpiredLocked(now: now)
|
||||||
|
let queryRealTokenCount = cacheKey.reduce(into: 0) { partialResult, token in
|
||||||
|
if token >= 0 {
|
||||||
|
partialResult += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var node = root
|
var node = root
|
||||||
var bestMatch: (entryId: UUID, realTokenCount: Int)?
|
var bestMatch: (entryId: UUID, realTokenCount: Int)?
|
||||||
var realTokenCount = 0
|
var realTokenCount = 0
|
||||||
|
var walkedFullKey = true
|
||||||
|
|
||||||
for key in cacheKey {
|
for key in cacheKey {
|
||||||
guard let child = node.children[key] else { break }
|
guard let child = node.children[key] else {
|
||||||
|
walkedFullKey = false
|
||||||
|
break
|
||||||
|
}
|
||||||
node = child
|
node = child
|
||||||
if key >= 0 { realTokenCount += 1 }
|
if key >= 0 { realTokenCount += 1 }
|
||||||
if let entryId = node.entryId,
|
if let entryId = node.entryId,
|
||||||
@@ -108,27 +123,50 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
guard let match = bestMatch,
|
if let match = bestMatch,
|
||||||
var entry = entries[match.entryId]
|
var entry = entries[match.entryId] {
|
||||||
else {
|
entry.lastAccessAt = now
|
||||||
stats.totalMisses += 1
|
entry.hitCount += 1
|
||||||
|
entries[match.entryId] = entry
|
||||||
|
removeEntryLocked(entry)
|
||||||
|
stats.totalHits += 1
|
||||||
|
stats.totalPrefixHits += 1
|
||||||
lock.unlock()
|
lock.unlock()
|
||||||
return CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false)
|
|
||||||
|
return CacheLease(
|
||||||
|
entryId: match.entryId,
|
||||||
|
kvCache: entry.kvCache,
|
||||||
|
matchedTokenCount: match.realTokenCount,
|
||||||
|
isHit: true
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
entry.lastAccessAt = now
|
if walkedFullKey,
|
||||||
entry.hitCount += 1
|
let superLease = findSupersequenceMatchLocked(
|
||||||
entries[match.entryId] = entry
|
below: node,
|
||||||
removeEntryLocked(entry)
|
queryRealTokenCount: realTokenCount,
|
||||||
stats.totalHits += 1
|
modelId: modelId,
|
||||||
lock.unlock()
|
now: now
|
||||||
|
) {
|
||||||
|
lock.unlock()
|
||||||
|
return superLease
|
||||||
|
}
|
||||||
|
|
||||||
return CacheLease(
|
if realTokenCount > 0,
|
||||||
entryId: match.entryId,
|
let lcpLease = findLCPMatchLocked(
|
||||||
kvCache: entry.kvCache,
|
below: node,
|
||||||
matchedTokenCount: match.realTokenCount,
|
sharedRealTokenCount: realTokenCount,
|
||||||
isHit: true
|
queryRealTokenCount: queryRealTokenCount,
|
||||||
)
|
modelId: modelId,
|
||||||
|
now: now
|
||||||
|
) {
|
||||||
|
lock.unlock()
|
||||||
|
return lcpLease
|
||||||
|
}
|
||||||
|
|
||||||
|
stats.totalMisses += 1
|
||||||
|
lock.unlock()
|
||||||
|
return CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func store(
|
func store(
|
||||||
@@ -216,6 +254,9 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
totalMisses: misses,
|
totalMisses: misses,
|
||||||
totalEvictions: stats.totalEvictions,
|
totalEvictions: stats.totalEvictions,
|
||||||
hitRate: totalOps > 0 ? (Double(hits) / Double(totalOps)) * 100 : 0,
|
hitRate: totalOps > 0 ? (Double(hits) / Double(totalOps)) * 100 : 0,
|
||||||
|
prefixHits: stats.totalPrefixHits,
|
||||||
|
supersequenceHits: stats.totalSupersequenceHits,
|
||||||
|
lcpHits: stats.totalLCPHits,
|
||||||
entries: orderedEntries.map {
|
entries: orderedEntries.map {
|
||||||
EntrySummary(
|
EntrySummary(
|
||||||
id: $0.id,
|
id: $0.id,
|
||||||
@@ -297,6 +338,125 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
1 + node.children.values.reduce(0) { $0 + countNodes($1) }
|
1 + node.children.values.reduce(0) { $0 + countNodes($1) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Finds a cached entry whose key extends the fully matched query and leases it,
/// trimmed back to the query length.
///
/// Intended to be called while `lock` is held. Every trie node at or below `node` is
/// visited breadth-first; among same-model entries that are strictly longer than the
/// query and whose layers are all trimmable, the one with the fewest tokens wins
/// (the first one encountered wins ties).
///
/// - Parameters:
///   - node: Trie node to start the search from (included in the search).
///   - queryRealTokenCount: Number of real tokens in the query; the lease reports
///     this as its matched token count.
///   - modelId: Only entries stored for this model are considered.
///   - now: Timestamp recorded as the chosen entry's last access time.
/// - Returns: A hit lease carrying the trimmed cache, or `nil` when there is no
///   candidate or the candidate could not be trimmed.
private func findSupersequenceMatchLocked(
    below node: TrieNode,
    queryRealTokenCount: Int,
    modelId: String,
    now: Date
) -> CacheLease? {
    // Breadth-first walk using a read cursor instead of `removeFirst()`, which
    // shifts the whole array on every pop. Visit order is unchanged.
    var queue: [TrieNode] = [node]
    var cursor = 0
    var bestEntry: CacheEntry?

    while cursor < queue.count {
        let current = queue[cursor]
        cursor += 1

        if let entryId = current.entryId,
           let entry = entries[entryId],
           entry.modelId == modelId,
           entry.tokenCount > queryRealTokenCount,
           entry.kvCache.allSatisfy({ $0.isTrimmable }),
           bestEntry.map({ entry.tokenCount < $0.tokenCount }) ?? true {
            bestEntry = entry
        }

        queue.append(contentsOf: current.children.values)
    }

    guard let entry = bestEntry else {
        return nil
    }

    guard let trimmedCache = Self.trimCacheByOffset(
        entry.kvCache,
        trimBy: entry.tokenCount - queryRealTokenCount
    ) else {
        // A failed trim may already have shortened some layers, so the stored
        // entry can no longer be trusted to match its key. Drop it rather than
        // leave it available to later lookups.
        // NOTE(review): assumes `removeEntryLocked` fully detaches the entry — confirm.
        removeEntryLocked(entry)
        return nil
    }

    var updatedEntry = entry
    updatedEntry.lastAccessAt = now
    updatedEntry.hitCount += 1
    entries[entry.id] = updatedEntry
    removeEntryLocked(updatedEntry)
    stats.totalHits += 1
    stats.totalSupersequenceHits += 1

    return CacheLease(
        entryId: updatedEntry.id,
        kvCache: trimmedCache,
        matchedTokenCount: queryRealTokenCount,
        isHit: true
    )
}
|
||||||
|
|
||||||
|
/// Finds a cached entry that shares a long-enough common prefix with the query and
/// leases it, trimmed back to the shared prefix length.
///
/// Intended to be called while `lock` is held. The search is skipped unless
/// `sharedRealTokenCount` reaches `minimumLCPMatchTokens(for:)`. Descendants of
/// `node` (not `node` itself) are visited breadth-first; among same-model entries
/// longer than the shared prefix whose layers are all trimmable, the one with the
/// fewest tokens wins (the first one encountered wins ties).
///
/// - Parameters:
///   - node: Trie node whose descendants are searched.
///   - sharedRealTokenCount: Number of real tokens shared with the query; the lease
///     reports this as its matched token count.
///   - queryRealTokenCount: Number of real tokens in the whole query, used only for
///     the minimum-match threshold.
///   - modelId: Only entries stored for this model are considered.
///   - now: Timestamp recorded as the chosen entry's last access time.
/// - Returns: A hit lease carrying the trimmed cache, or `nil` when the threshold is
///   not met, there is no candidate, or the candidate could not be trimmed.
private func findLCPMatchLocked(
    below node: TrieNode,
    sharedRealTokenCount: Int,
    queryRealTokenCount: Int,
    modelId: String,
    now: Date
) -> CacheLease? {
    guard sharedRealTokenCount >= Self.minimumLCPMatchTokens(for: queryRealTokenCount) else {
        return nil
    }

    // Breadth-first walk using a read cursor instead of `removeFirst()`, which
    // shifts the whole array on every pop. Visit order is unchanged.
    var queue = Array(node.children.values)
    var cursor = 0
    var bestEntry: CacheEntry?

    while cursor < queue.count {
        let current = queue[cursor]
        cursor += 1

        if let entryId = current.entryId,
           let entry = entries[entryId],
           entry.modelId == modelId,
           entry.tokenCount > sharedRealTokenCount,
           entry.kvCache.allSatisfy({ $0.isTrimmable }),
           bestEntry.map({ entry.tokenCount < $0.tokenCount }) ?? true {
            bestEntry = entry
        }

        queue.append(contentsOf: current.children.values)
    }

    guard let entry = bestEntry else {
        return nil
    }

    guard let trimmedCache = Self.trimCacheByOffset(
        entry.kvCache,
        trimBy: entry.tokenCount - sharedRealTokenCount
    ) else {
        // A failed trim may already have shortened some layers, so the stored
        // entry can no longer be trusted to match its key. Drop it rather than
        // leave it available to later lookups.
        // NOTE(review): assumes `removeEntryLocked` fully detaches the entry — confirm.
        removeEntryLocked(entry)
        return nil
    }

    var updatedEntry = entry
    updatedEntry.lastAccessAt = now
    updatedEntry.hitCount += 1
    entries[entry.id] = updatedEntry
    removeEntryLocked(updatedEntry)
    stats.totalHits += 1
    stats.totalLCPHits += 1

    return CacheLease(
        entryId: updatedEntry.id,
        kvCache: trimmedCache,
        matchedTokenCount: sharedRealTokenCount,
        isHit: true
    )
}
|
||||||
|
|
||||||
|
/// Trims `trimBy` tokens from every layer of `cache` and returns the same array.
///
/// Trimmability of all layers is verified before any layer is touched, so an
/// untrimmable layer can no longer cause earlier layers to be trimmed before the
/// function gives up.
///
/// - Parameters:
///   - cache: Per-layer caches to trim. `trim` is invoked on each element directly.
///     NOTE(review): presumably the layers are reference types and are shortened in
///     place — confirm against the `KVCache` declaration.
///   - trimBy: Number of tokens to remove from each layer.
/// - Returns: `cache` when `trimBy` is zero or every layer reported trimming exactly
///   `trimBy` tokens; `nil` when `trimBy` is negative, a layer is not trimmable, or
///   a layer trimmed a different amount.
private static func trimCacheByOffset(_ cache: [KVCache], trimBy: Int) -> [KVCache]? {
    guard trimBy >= 0 else { return nil }
    guard trimBy > 0 else { return cache }

    // Validate first so a late untrimmable layer cannot leave the cache half-trimmed.
    guard cache.allSatisfy({ $0.isTrimmable }) else { return nil }

    for layer in cache {
        let trimmed = layer.trim(trimBy)
        guard trimmed == trimBy else { return nil }
    }

    return cache
}
|
||||||
|
|
||||||
|
/// Smallest shared-prefix length that qualifies for a longest-common-prefix hit.
///
/// - Parameter queryRealTokenCount: Number of real tokens in the query.
/// - Returns: `Int.max` for a non-positive count (so no prefix can qualify);
///   otherwise half the query rounded up, but never less than 2.
private static func minimumLCPMatchTokens(for queryRealTokenCount: Int) -> Int {
    if queryRealTokenCount <= 0 {
        return Int.max
    }

    let roundedUpHalf = (queryRealTokenCount + 1) / 2
    return roundedUpHalf < 2 ? 2 : roundedUpHalf
}
|
||||||
|
|
||||||
private static func computeMemoryBudget() -> Int {
|
private static func computeMemoryBudget() -> Int {
|
||||||
guard let device = MTLCreateSystemDefaultDevice() else {
|
guard let device = MTLCreateSystemDefaultDevice() else {
|
||||||
return 512 * 1024 * 1024
|
return 512 * 1024 * 1024
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ struct MonitorView: View {
|
|||||||
metricCard(
|
metricCard(
|
||||||
title: "Cache Hit Rate",
|
title: "Cache Hit Rate",
|
||||||
value: String(format: "%.0f%%", stats.cacheHitRatePercent),
|
value: String(format: "%.0f%%", stats.cacheHitRatePercent),
|
||||||
detail: "\(stats.totalCacheHits) hits / \(stats.totalCacheMisses) misses",
|
detail: "\(stats.totalCacheHits) hits / \(stats.totalCacheMisses) misses • P:\(stats.totalPrefixHits) S:\(stats.totalSupersequenceHits) L:\(stats.totalLCPHits)",
|
||||||
color: .blue
|
color: .blue
|
||||||
)
|
)
|
||||||
metricCard(
|
metricCard(
|
||||||
|
|||||||
@@ -3,6 +3,20 @@ import XCTest
|
|||||||
@testable import MLX_Server
|
@testable import MLX_Server
|
||||||
|
|
||||||
final class APIServerRewriteTests: XCTestCase {
|
final class APIServerRewriteTests: XCTestCase {
|
||||||
|
/// Checks that `/health` answers 200 with the fixed OK payload and that the models
/// listing is non-empty, contains the default model, and reports a context window
/// for every model.
func testHealthAndModelsEndpointsReturnExpectedPayloads() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    // Health endpoint.
    let healthResponse = try await sendRawRequest(path: "/health", port: harness.port)
    XCTAssertEqual(healthResponse.statusCode, 200)
    XCTAssertEqual(healthResponse.body, #"{"status":"ok"}"#)

    // Models endpoint.
    let listedModels = try await sendModelsRequest(port: harness.port).data
    XCTAssertFalse(listedModels.isEmpty)
    XCTAssertTrue(listedModels.contains(where: { model in model.id == ModelConfig.default.repoId }))
    XCTAssertTrue(listedModels.allSatisfy({ model in model.context_window != nil }))
}
|
||||||
|
|
||||||
func testNonStreamingChatCompletionUsesStatelessServerPathAndCachesPrompt() async throws {
|
func testNonStreamingChatCompletionUsesStatelessServerPathAndCachesPrompt() async throws {
|
||||||
let harness = try await makeHarness()
|
let harness = try await makeHarness()
|
||||||
defer { harness.stop() }
|
defer { harness.stop() }
|
||||||
@@ -51,6 +65,306 @@ final class APIServerRewriteTests: XCTestCase {
|
|||||||
XCTAssertGreaterThan(secondLiveSnapshot.cacheMatchDepth, 0)
|
XCTAssertGreaterThan(secondLiveSnapshot.cacheMatchDepth, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sends the same non-streaming request twice and checks the live counters report
/// the whole prompt as cache-matched with nothing rebuilt.
func testSecondIdenticalRequestIsFullCacheHitWithZeroRebuiltPromptTokens() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    let userTurn = APIChatMessage(
        role: "user",
        content: .text("Answer with one word: ocean."),
        name: nil,
        tool_calls: nil,
        tool_call_id: nil
    )
    let request = APIChatCompletionRequest(
        model: "gemma",
        messages: [userTurn],
        temperature: 0,
        top_p: 1,
        max_tokens: 2,
        stream: false,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )

    // First send populates the cache, second send should reuse it.
    for _ in 1...2 {
        _ = try await sendChatCompletion(request, port: harness.port)
    }

    let counters = LiveCounters.shared.snapshot()
    XCTAssertGreaterThan(counters.currentCacheMatchedPromptTokens, 0)
    XCTAssertEqual(counters.currentCacheMatchedPromptTokens, counters.promptTokens)
    XCTAssertEqual(counters.currentCacheRebuiltPromptTokens, 0)
}
|
||||||
|
|
||||||
|
/// Streams a first turn, then sends a follow-up conversation that extends it, and
/// checks the live counters report both matched and rebuilt prompt tokens.
func testSingleTurnContinuationProducesPartialCacheHit() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    func turn(_ role: String, _ text: String) -> APIChatMessage {
        APIChatMessage(role: role, content: .text(text), name: nil, tool_calls: nil, tool_call_id: nil)
    }

    func completionRequest(_ messages: [APIChatMessage], stream: Bool) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: messages,
            temperature: 0,
            top_p: 1,
            max_tokens: 2,
            stream: stream,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let openingRequest = completionRequest([turn("user", "Answer in one word: sun.")], stream: true)
    let openingStream = try await sendStreamingChatCompletion(openingRequest, port: harness.port)
    XCTAssertFalse(openingStream.content.isEmpty)

    let followUpRequest = completionRequest(
        [
            turn("user", "Answer in one word: sun."),
            turn("assistant", openingStream.content),
            turn("user", "Answer in one word: moon.")
        ],
        stream: false
    )
    _ = try await sendChatCompletion(followUpRequest, port: harness.port)

    let counters = LiveCounters.shared.snapshot()
    XCTAssertGreaterThan(counters.currentCacheMatchedPromptTokens, 0)
    XCTAssertGreaterThan(counters.currentCacheRebuiltPromptTokens, 0)
}
|
||||||
|
|
||||||
|
/// Sends two requests that share a system prompt but differ in the user message and
/// checks that the second lookup is a partial cache hit.
///
/// Lookup events reach the collector through unstructured `Task`s, which carry no
/// ordering guarantee. The test therefore waits for the first event to be recorded
/// before issuing the second request, so that `events.last` is the second lookup.
func testSameSystemPromptDifferentUserMessageReusesSystemPrefix() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    let lookups = LookupEventCollector()
    APIServer.debugLookupEventHandler = { event in
        Task {
            await lookups.record(event)
        }
    }
    defer {
        APIServer.debugLookupEventHandler = nil
    }

    let firstRequest = APIChatCompletionRequest(
        model: "gemma",
        messages: [
            APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
            APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil)
        ],
        temperature: 0,
        top_p: 1,
        max_tokens: 2,
        stream: false,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )

    let secondRequest = APIChatCompletionRequest(
        model: "gemma",
        messages: [
            APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
            APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil)
        ],
        temperature: 0,
        top_p: 1,
        max_tokens: 2,
        stream: false,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )

    _ = try await sendChatCompletion(firstRequest, port: harness.port)

    // Make sure the first lookup has been recorded before the second request can
    // emit its own event; otherwise the two recording tasks could land out of order.
    try await waitUntil(timeoutSeconds: 5) {
        await lookups.events().count >= 1
    }

    _ = try await sendChatCompletion(secondRequest, port: harness.port)

    try await waitUntil(timeoutSeconds: 5) {
        let events = await lookups.events()
        return events.count >= 2
    }

    let events = await lookups.events()
    let secondLookup = try XCTUnwrap(events.last)
    XCTAssertEqual(secondLookup.modelId, "gemma")
    XCTAssertGreaterThan(secondLookup.promptTokenCount, 0)
    XCTAssertTrue(secondLookup.isHit)
    XCTAssertGreaterThan(secondLookup.matchedTokenCount, 0)
    XCTAssertLessThan(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
}
|
||||||
|
|
||||||
|
/// Serves one request through the server, then tokenizes a sibling prompt (same
/// system message, different user message) locally and checks that a direct lookup
/// in the shared prefix cache is a hit with a positive match length.
func testServerStoredCacheIsDirectlyReusableForSameSystemDifferentUserPrompt() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    func request(askingAbout userText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text(userText), name: nil, tool_calls: nil, tool_call_id: nil)
            ],
            temperature: 0,
            top_p: 1,
            max_tokens: 2,
            stream: false,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let servedRequest = request(askingAbout: "Respond with one word for cat.")
    _ = try await sendChatCompletion(servedRequest, port: harness.port)

    let siblingRequest = request(askingAbout: "Respond with one word for dog.")

    // Tokenize the sibling prompt with the harness's loaded model.
    let loadedContainer = await MainActor.run { harness.modelManager.modelContainer }
    let container = try XCTUnwrap(loadedContainer)
    let engine = InferenceEngine(container: container)
    let siblingPrompt = PromptBuilder.build(
        from: siblingRequest,
        modelId: ModelConfig.default.repoId,
        thinkingEnabled: Preferences.enableThinking
    )
    let siblingInference = try await engine.prepare(siblingPrompt.userInput)

    let lease = TokenPrefixCache.shared.lookup(cacheKey: siblingInference.tokens, modelId: "gemma")

    XCTAssertTrue(lease.isHit)
    XCTAssertGreaterThan(lease.matchedTokenCount, 0)
}
|
||||||
|
|
||||||
|
/// Sends two requests with different system prompts and the same user message and
/// checks that the second one neither increases the cache hit total nor reports
/// any matched prompt tokens.
func testDifferentSystemPromptDoesNotProduceFalseCacheHit() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    func request(systemPrompt: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text(systemPrompt), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil)
            ],
            temperature: 0,
            top_p: 1,
            max_tokens: 2,
            stream: false,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let alphaRequest = request(systemPrompt: "System Alpha Unique Tokens")
    let betaRequest = request(systemPrompt: "Completely Different Beta Markers")

    _ = try await sendChatCompletion(alphaRequest, port: harness.port)
    let hitsBefore = TokenPrefixCache.shared.snapshot().totalHits
    _ = try await sendChatCompletion(betaRequest, port: harness.port)

    let hitsAfter = TokenPrefixCache.shared.snapshot().totalHits
    let counters = LiveCounters.shared.snapshot()
    XCTAssertEqual(hitsAfter, hitsBefore)
    XCTAssertEqual(counters.currentCacheMatchedPromptTokens, 0)
}
|
||||||
|
|
||||||
|
/// Serves a request, unloads the model, then repeats the request and checks that
/// the server answers, the model is ready again, and the repeat did not count as a
/// cache hit.
func testIdleUnloadReloadInvalidatesCacheAndServesFreshRequest() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    Preferences.lastModelId = "gemma"
    let userTurn = APIChatMessage(
        role: "user",
        content: .text("Answer in one word: cloud."),
        name: nil,
        tool_calls: nil,
        tool_call_id: nil
    )
    let request = APIChatCompletionRequest(
        model: nil,
        messages: [userTurn],
        temperature: 0,
        top_p: 1,
        max_tokens: 2,
        stream: false,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )

    // Warm the cache and wait until an entry is visible.
    _ = try await sendChatCompletion(request, port: harness.port)
    try await waitUntil(timeoutSeconds: 5) {
        TokenPrefixCache.shared.snapshot().totalEntries > 0
    }

    // Unload and confirm the manager reports not-ready.
    await MainActor.run {
        harness.modelManager.unloadModel()
    }
    let readyAfterUnload = await MainActor.run { harness.modelManager.isReady }
    XCTAssertFalse(readyAfterUnload)

    // The repeat request must be served and bring the model back.
    let hitsBefore = TokenPrefixCache.shared.snapshot().totalHits
    let response = try await sendChatCompletion(request, port: harness.port)
    XCTAssertEqual(response.choices.count, 1)
    let readyAfterReload = await MainActor.run { harness.modelManager.isReady }
    XCTAssertTrue(readyAfterReload)

    let hitsAfter = TokenPrefixCache.shared.snapshot().totalHits
    let counters = LiveCounters.shared.snapshot()
    XCTAssertEqual(hitsAfter, hitsBefore)
    XCTAssertEqual(counters.currentCacheMatchedPromptTokens, 0)
}
|
||||||
|
|
||||||
func testStreamingChatCompletionReusesCacheAcrossThreeProgressivelyLongerTurns() async throws {
|
func testStreamingChatCompletionReusesCacheAcrossThreeProgressivelyLongerTurns() async throws {
|
||||||
let harness = try await makeHarness()
|
let harness = try await makeHarness()
|
||||||
defer { harness.stop() }
|
defer { harness.stop() }
|
||||||
@@ -568,6 +882,19 @@ final class APIServerRewriteTests: XCTestCase {
|
|||||||
return try JSONDecoder().decode(APIChatCompletionResponse.self, from: data)
|
return try JSONDecoder().decode(APIChatCompletionResponse.self, from: data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fetches `/v1/models` from the local server, asserts a 200 status, and decodes
/// the body as an `APIModelListResponse`.
///
/// - Parameter port: Port of the server under test.
/// - Throws: Any transport error from `sendRawRequest` or a JSON decoding error.
private func sendModelsRequest(port: UInt16) async throws -> APIModelListResponse {
    let (status, _, payload) = try await sendRawRequest(path: "/v1/models", port: port)
    XCTAssertEqual(status, 200)

    let decoder = JSONDecoder()
    return try decoder.decode(APIModelListResponse.self, from: payload)
}
|
||||||
|
|
||||||
|
/// Performs a GET against `http://127.0.0.1:<port><path>` and returns the status
/// code together with the body as both text and raw bytes.
///
/// - Parameters:
///   - path: Request path appended verbatim after the port (for example `/health`).
///   - port: Port of the server under test.
/// - Returns: The HTTP status code, the body decoded as UTF-8 (empty string when it
///   is not valid UTF-8), and the raw body data.
/// - Throws: An `XCTUnwrap` failure when the URL cannot be formed or the response is
///   not an `HTTPURLResponse`, or any error raised by `URLSession`.
private func sendRawRequest(path: String, port: UInt16) async throws -> (statusCode: Int, body: String, bodyData: Data) {
    // Unwrap through XCTest so a malformed path fails this test instead of
    // crashing the whole test process.
    let url = try XCTUnwrap(URL(string: "http://127.0.0.1:\(port)\(path)"))
    let (data, response) = try await URLSession.shared.data(from: url)
    let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
    return (httpResponse.statusCode, String(data: data, encoding: .utf8) ?? "", data)
}
|
||||||
|
|
||||||
private func sendStreamingChatCompletion(_ request: APIChatCompletionRequest, port: UInt16) async throws -> StreamingResult {
|
private func sendStreamingChatCompletion(_ request: APIChatCompletionRequest, port: UInt16) async throws -> StreamingResult {
|
||||||
let detailed = try await sendStreamingChatCompletionDetailed(request, port: port)
|
let detailed = try await sendStreamingChatCompletionDetailed(request, port: port)
|
||||||
return StreamingResult(
|
return StreamingResult(
|
||||||
@@ -695,6 +1022,18 @@ private actor StreamCancellationObserver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Actor that accumulates `APIServer.DebugLookupEvent` values in arrival order.
private actor LookupEventCollector {
    /// Events recorded so far, oldest first.
    private var log = [APIServer.DebugLookupEvent]()

    /// Appends `event` to the log.
    func record(_ event: APIServer.DebugLookupEvent) {
        log += [event]
    }

    /// Returns a copy of everything recorded so far.
    func events() -> [APIServer.DebugLookupEvent] {
        return log
    }
}
|
||||||
|
|
||||||
private struct DetailedStreamingResult {
|
private struct DetailedStreamingResult {
|
||||||
let events: [StreamingEvent]
|
let events: [StreamingEvent]
|
||||||
let sawDone: Bool
|
let sawDone: Bool
|
||||||
|
|||||||
@@ -92,10 +92,211 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
|
|||||||
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
|
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Tokenizes two prompts that share a system message but differ in the user message,
/// stores the first token sequence in a private `TokenPrefixCache`, and checks that
/// looking up the second yields a hit shorter than the first sequence.
func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
    let container = try await localGemmaContainer()
    let engine = InferenceEngine(container: container)

    func prompt(askingAbout userText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text(userText), name: nil, tool_calls: nil, tool_call_id: nil)
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let catPrompt = PromptBuilder.build(
        from: prompt(askingAbout: "Respond with one word for cat."),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )
    let dogPrompt = PromptBuilder.build(
        from: prompt(askingAbout: "Respond with one word for dog."),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )

    let catInference = try await engine.prepare(catPrompt.userInput)
    let dogInference = try await engine.prepare(dogPrompt.userInput)

    // Private cache instance with a fixed per-entry size estimate.
    let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
    cache.store(entryId: UUID(), kvCache: [], cacheKey: catInference.tokens, modelId: "gemma")

    let lease = cache.lookup(cacheKey: dogInference.tokens, modelId: "gemma")

    XCTAssertTrue(lease.isHit)
    XCTAssertGreaterThan(lease.matchedTokenCount, 0)
    XCTAssertLessThan(lease.matchedTokenCount, catInference.tokens.count)
}
|
||||||
|
|
||||||
|
/// Same scenario as the token-only LCP test, but the stored entry carries the
/// working KV cache produced by an actual generation on the local Gemma model.
///
/// The first prompt is streamed for two tokens with no cached KV, the resulting
/// working cache is passed through `trimCacheToPrompt` with the prompt's token
/// count, and that cache is stored under the first prompt's tokens. Looking up
/// the second prompt's tokens must then yield a partial hit.
func testStoredLiveGemmaCacheSupportsSameSystemDifferentUserLCPReuse() async throws {
    let gemma = try await localGemmaContainer()
    let engine = InferenceEngine(container: gemma)
    let gemmaModelId = "mlx-community/gemma-3-4b-it-4bit"

    // Both requests are identical apart from the user message text.
    func makeRequest(userText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text(userText), name: nil, tool_calls: nil, tool_call_id: nil),
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let first = PromptBuilder.build(
        from: makeRequest(userText: "Respond with one word for cat."),
        modelId: gemmaModelId,
        thinkingEnabled: true
    )
    let second = PromptBuilder.build(
        from: makeRequest(userText: "Respond with one word for dog."),
        modelId: gemmaModelId,
        thinkingEnabled: true
    )

    let firstPrepared = try await engine.prepare(first.userInput)
    let secondPrepared = try await engine.prepare(second.userInput)

    // Run a short generation for the first prompt to obtain a populated working cache.
    let seeded = try await engine.stream(
        InferenceEngine.InferenceRequest(
            input: firstPrepared.lmInput,
            tokens: firstPrepared.tokens,
            parameters: GenerateParameters(maxTokens: 2, temperature: 0),
            cachedKV: nil,
            cachedTokenCount: 0
        ),
        cancellation: CancellationToken()
    )

    // Drain the stream before touching the working cache; the output itself is not inspected.
    _ = await collectEngineOutput(seeded.stream)
    trimCacheToPrompt(seeded.workingCache, promptTokenCount: firstPrepared.tokens.count)

    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: UUID(), kvCache: seeded.workingCache, cacheKey: firstPrepared.tokens, modelId: "gemma")

    let lease = prefixCache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")

    // A hit that covers some, but not all, of the stored key.
    XCTAssertTrue(lease.isHit)
    XCTAssertGreaterThan(lease.matchedTokenCount, 0)
    XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
}
|
||||||
|
|
||||||
|
/// Two Gemma prompts with different system messages and the same user message
/// must not be served from each other's cache entry.
///
/// Note: despite the "CanFalseHit" wording in the method name, the assertion
/// below requires that the lookup is NOT a hit.
func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws {
    let gemma = try await localGemmaContainer()
    let engine = InferenceEngine(container: gemma)
    let gemmaModelId = "mlx-community/gemma-3-4b-it-4bit"

    // Both requests are identical apart from the system message text.
    func makeRequest(systemText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text(systemText), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil),
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let first = PromptBuilder.build(
        from: makeRequest(systemText: "System Alpha Unique Tokens"),
        modelId: gemmaModelId,
        thinkingEnabled: true
    )
    let second = PromptBuilder.build(
        from: makeRequest(systemText: "Completely Different Beta Markers"),
        modelId: gemmaModelId,
        thinkingEnabled: true
    )

    let firstPrepared = try await engine.prepare(first.userInput)
    let secondPrepared = try await engine.prepare(second.userInput)

    // Seed the cache with the first prompt's tokens, then query with the second's.
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma")

    let lease = prefixCache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")

    XCTAssertFalse(lease.isHit)
}
|
||||||
|
|
||||||
/// Returns the model container obtained from `LocalGemmaFixture.shared.container()`,
/// propagating any error that call throws.
private func localGemmaContainer() async throws -> ModelContainer {
|
private func localGemmaContainer() async throws -> ModelContainer {
|
||||||
try await LocalGemmaFixture.shared.container()
|
try await LocalGemmaFixture.shared.container()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Asks every layer whose `offset` exceeds `promptTokenCount` to trim the excess.
///
/// Layers whose offset is at or below `promptTokenCount` are left untouched.
/// For each remaining layer this helper asserts that the layer reports itself
/// as trimmable and that `trim` returns exactly the number of excess tokens.
///
/// - Parameters:
///   - cache: The per-layer caches to inspect.
///   - promptTokenCount: The token count each layer's offset is compared against.
private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) {
    for layer in cache where layer.offset > promptTokenCount {
        let overshoot = layer.offset - promptTokenCount
        XCTAssertTrue(layer.isTrimmable)
        XCTAssertEqual(layer.trim(overshoot), overshoot)
    }
}
|
||||||
|
|
||||||
private func legacyBuild(
|
private func legacyBuild(
|
||||||
from request: APIChatCompletionRequest,
|
from request: APIChatCompletionRequest,
|
||||||
modelId: String,
|
modelId: String,
|
||||||
|
|||||||
@@ -127,4 +127,86 @@ final class TokenPrefixCacheTests: XCTestCase {
|
|||||||
XCTAssertEqual(snapshot.totalCachedTokens, 0)
|
XCTAssertEqual(snapshot.totalCachedTokens, 0)
|
||||||
XCTAssertEqual(snapshot.estimatedBytes, 0)
|
XCTAssertEqual(snapshot.estimatedBytes, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A query that is a strict prefix of a stored key is served by the longer entry
/// and is counted as a supersequence hit only.
func testSupersequenceLookupReusesLongerEntryForShorterQuery() {
    let storedId = UUID()
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: storedId, kvCache: [], cacheKey: [1, 2, 3, 4], modelId: "model")

    let lease = prefixCache.lookup(cacheKey: [1, 2, 3], modelId: "model")
    let stats = prefixCache.snapshot()

    // The lease points at the stored entry and covers the whole three-token query.
    XCTAssertTrue(lease.isHit)
    XCTAssertEqual(lease.entryId, storedId)
    XCTAssertEqual(lease.matchedTokenCount, 3)

    // Exactly one hit, attributed to the supersequence counter.
    XCTAssertEqual(stats.totalHits, 1)
    XCTAssertEqual(stats.supersequenceHits, 1)
    XCTAssertEqual(stats.prefixHits, 0)
    XCTAssertEqual(stats.lcpHits, 0)
}
|
||||||
|
|
||||||
|
/// A query that shares its first two tokens with a stored key and then diverges
/// is served by that entry for the shared part and is counted as an LCP hit only.
func testLCPLookupReusesSharedPrefixAcrossDivergentSuffixes() {
    let storedId = UUID()
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: storedId, kvCache: [], cacheKey: [10, 20, 90], modelId: "model")

    let lease = prefixCache.lookup(cacheKey: [10, 20, 30], modelId: "model")
    let stats = prefixCache.snapshot()

    // The lease points at the stored entry and covers the two shared tokens.
    XCTAssertTrue(lease.isHit)
    XCTAssertEqual(lease.entryId, storedId)
    XCTAssertEqual(lease.matchedTokenCount, 2)

    // Exactly one hit, attributed to the LCP counter.
    XCTAssertEqual(stats.totalHits, 1)
    XCTAssertEqual(stats.lcpHits, 1)
    XCTAssertEqual(stats.prefixHits, 0)
    XCTAssertEqual(stats.supersequenceHits, 0)
}
|
||||||
|
|
||||||
|
/// A query that shares only its first token with a four-token stored key is
/// reported as a miss rather than an LCP hit.
func testLCPLookupRejectsShallowSharedPrefix() {
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: [10, 20, 30, 40], modelId: "model")

    let lease = prefixCache.lookup(cacheKey: [10, 99, 98, 97], modelId: "model")
    let stats = prefixCache.snapshot()

    // The lease reports no match at all.
    XCTAssertFalse(lease.isHit)
    XCTAssertEqual(lease.matchedTokenCount, 0)

    // The lookup is recorded as a single miss and no hit of any kind.
    XCTAssertEqual(stats.totalHits, 0)
    XCTAssertEqual(stats.totalMisses, 1)
    XCTAssertEqual(stats.lcpHits, 0)
}
|
||||||
|
|
||||||
|
/// When a stored key is an exact prefix of the query, that entry wins even though
/// longer stored keys share the same leading tokens.
func testLookupPrefersPrefixMatchOverSupersequenceAndLCP() {
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })

    // `[7, 8]` is a true prefix of the query; the other two keys merely share it.
    let prefixId = UUID()
    prefixCache.store(entryId: prefixId, kvCache: [], cacheKey: [7, 8], modelId: "model")
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: [7, 8, 9, 10], modelId: "model")
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: [7, 8, 11], modelId: "model")

    let lease = prefixCache.lookup(cacheKey: [7, 8, 12], modelId: "model")
    let stats = prefixCache.snapshot()

    // The two-token prefix entry is the one returned.
    XCTAssertTrue(lease.isHit)
    XCTAssertEqual(lease.entryId, prefixId)
    XCTAssertEqual(lease.matchedTokenCount, 2)

    // The hit is attributed to the prefix counter only.
    XCTAssertEqual(stats.prefixHits, 1)
    XCTAssertEqual(stats.supersequenceHits, 0)
    XCTAssertEqual(stats.lcpHits, 0)
}
|
||||||
}
|
}
|
||||||
@@ -2609,13 +2609,13 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly
|
|||||||
### Cache Correctness
|
### Cache Correctness
|
||||||
|
|
||||||
- [x] Cold start: no cache entries → fresh generation works
|
- [x] Cold start: no cache entries → fresh generation works
|
||||||
- [ ] Second identical request → full cache hit, zero prefill tokens
|
- [x] Second identical request → full cache hit, zero prefill tokens
|
||||||
- [ ] Conversation continuation (add 1 message) → partial cache hit
|
- [x] Conversation continuation (add 1 message) → partial cache hit
|
||||||
- [x] Conversation continuation (add 2+ messages, e.g. tool-use flow) → partial cache hit (not a miss!)
|
- [x] Conversation continuation (add 2+ messages, e.g. tool-use flow) → partial cache hit (not a miss!)
|
||||||
- [ ] Same system prompt, different user message → system prompt prefix cached and reused
|
- [x] Same system prompt, different user message → system prompt prefix cached and reused
|
||||||
- [ ] Different system prompt → no false cache hit
|
- [x] Different system prompt → no false cache hit
|
||||||
- [ ] Model swap → cache invalidated, fresh generation works
|
- [ ] Model swap → cache invalidated, fresh generation works
|
||||||
- [ ] Idle unload + reload → cache invalidated, fresh generation works
|
- [x] Idle unload + reload → cache invalidated, fresh generation works
|
||||||
|
|
||||||
### Memory Management
|
### Memory Management
|
||||||
|
|
||||||
@@ -2623,7 +2623,7 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly
|
|||||||
- [x] Entries evicted under memory pressure (oldest first)
|
- [x] Entries evicted under memory pressure (oldest first)
|
||||||
- [x] Expired entries pruned after 30 min idle
|
- [x] Expired entries pruned after 30 min idle
|
||||||
- [x] Trie nodes cleaned up when entries are evicted (no memory leak)
|
- [x] Trie nodes cleaned up when entries are evicted (no memory leak)
|
||||||
- [ ] `snapshot()` reports accurate memory usage and hit rates
|
- [x] `snapshot()` reports accurate memory usage and hit rates
|
||||||
|
|
||||||
### Disconnect Handling
|
### Disconnect Handling
|
||||||
|
|
||||||
@@ -2666,16 +2666,16 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly
|
|||||||
|
|
||||||
### Advanced Cache Matching (Section 12)
|
### Advanced Cache Matching (Section 12)
|
||||||
|
|
||||||
- [ ] Supersequence: cached `[A,B,C,D,E]`, query `[A,B,C]` → cache hit, KV trimmed to 3 tokens
|
- [x] Supersequence: cached `[A,B,C,D,E]`, query `[A,B,C]` → cache hit, KV trimmed to 3 tokens
|
||||||
- [ ] Supersequence: cached entry has non-trimmable layers (hybrid model) → graceful skip, falls through to miss
|
- [ ] Supersequence: cached entry has non-trimmable layers (hybrid model) → graceful skip, falls through to miss
|
||||||
- [ ] Supersequence: multiple candidates in subtree → shallowest (least excess) is chosen
|
- [ ] Supersequence: multiple candidates in subtree → shallowest (least excess) is chosen
|
||||||
- [ ] LCP: cached `[SYS,A,B,X,Y]`, query `[SYS,A,B,D,E]` → cache hit covering `[SYS,A,B]`, remaining `[D,E]`
|
- [x] LCP: cached `[SYS,A,B,X,Y]`, query `[SYS,A,B,D,E]` → cache hit covering `[SYS,A,B]`, remaining `[D,E]`
|
||||||
- [ ] LCP: divergence at depth 0 (no shared prefix at all) → no LCP match, clean miss
|
- [ ] LCP: divergence at depth 0 (no shared prefix at all) → no LCP match, clean miss
|
||||||
- [ ] LCP: multiple sibling entries at divergence → best (shallowest) is chosen
|
- [ ] LCP: multiple sibling entries at divergence → best (shallowest) is chosen
|
||||||
- [ ] LCP agentic pattern: same system prompt (500 tokens) + different user message → system prompt cached and reused
|
- [ ] LCP agentic pattern: same system prompt (500 tokens) + different user message → system prompt cached and reused
|
||||||
- [ ] Match priority: prefix match takes priority over supersequence and LCP
|
- [x] Match priority: prefix match takes priority over supersequence and LCP
|
||||||
- [ ] Match priority: supersequence takes priority over LCP
|
- [ ] Match priority: supersequence takes priority over LCP
|
||||||
- [ ] Stats: prefix, supersequence, and LCP hits counted separately in snapshot
|
- [x] Stats: prefix, supersequence, and LCP hits counted separately in snapshot
|
||||||
- [ ] Trim correctness: KVCache.trim() called with correct excess count, offset reduced accordingly
|
- [ ] Trim correctness: KVCache.trim() called with correct excess count, offset reduced accordingly
|
||||||
- [ ] Trim + generate: trimmed cache produces valid generation (no garbled output from stale K/V)
|
- [ ] Trim + generate: trimmed cache produces valid generation (no garbled output from stale K/V)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user