fix: more hardening of cache behaviour and some fixes

This commit is contained in:
2026-03-20 11:43:58 +01:00
parent ee34fd5e84
commit 5aed0107c6
8 changed files with 841 additions and 32 deletions

View File

@@ -92,10 +92,211 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
}
func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
let first = PromptBuilder.build(
from: APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil),
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: true
)
let second = PromptBuilder.build(
from: APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil),
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: true
)
let firstPrepared = try await engine.prepare(first.userInput)
let secondPrepared = try await engine.prepare(second.userInput)
let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
cache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma")
let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")
XCTAssertTrue(lease.isHit)
XCTAssertGreaterThan(lease.matchedTokenCount, 0)
XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
}
func testStoredLiveGemmaCacheSupportsSameSystemDifferentUserLCPReuse() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
let first = PromptBuilder.build(
from: APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil),
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: true
)
let second = PromptBuilder.build(
from: APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil),
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: true
)
let firstPrepared = try await engine.prepare(first.userInput)
let secondPrepared = try await engine.prepare(second.userInput)
let handle = try await engine.stream(
InferenceEngine.InferenceRequest(
input: firstPrepared.lmInput,
tokens: firstPrepared.tokens,
parameters: GenerateParameters(maxTokens: 2, temperature: 0),
cachedKV: nil,
cachedTokenCount: 0
),
cancellation: CancellationToken()
)
_ = await collectEngineOutput(handle.stream)
trimCacheToPrompt(handle.workingCache, promptTokenCount: firstPrepared.tokens.count)
let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
cache.store(entryId: UUID(), kvCache: handle.workingCache, cacheKey: firstPrepared.tokens, modelId: "gemma")
let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")
XCTAssertTrue(lease.isHit)
XCTAssertGreaterThan(lease.matchedTokenCount, 0)
XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
}
func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
let first = PromptBuilder.build(
from: APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "system", content: .text("System Alpha Unique Tokens"), name: nil, tool_calls: nil, tool_call_id: nil),
APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil),
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: true
)
let second = PromptBuilder.build(
from: APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "system", content: .text("Completely Different Beta Markers"), name: nil, tool_calls: nil, tool_call_id: nil),
APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil),
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: true
)
let firstPrepared = try await engine.prepare(first.userInput)
let secondPrepared = try await engine.prepare(second.userInput)
let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
cache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma")
let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")
XCTAssertFalse(lease.isHit)
}
private func localGemmaContainer() async throws -> ModelContainer {
try await LocalGemmaFixture.shared.container()
}
private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) {
for layer in cache {
let excess = layer.offset - promptTokenCount
if excess > 0 {
XCTAssertTrue(layer.isTrimmable)
XCTAssertEqual(layer.trim(excess), excess)
}
}
}
private func legacyBuild(
from request: APIChatCompletionRequest,
modelId: String,