fix: more hardening of cache behaviour and some fixes
This commit is contained in:
@@ -92,10 +92,211 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
|
||||
}
|
||||
|
||||
func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
|
||||
let first = PromptBuilder.build(
|
||||
from: APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
),
|
||||
modelId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
thinkingEnabled: true
|
||||
)
|
||||
let second = PromptBuilder.build(
|
||||
from: APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
),
|
||||
modelId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
thinkingEnabled: true
|
||||
)
|
||||
|
||||
let firstPrepared = try await engine.prepare(first.userInput)
|
||||
let secondPrepared = try await engine.prepare(second.userInput)
|
||||
|
||||
let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma")
|
||||
|
||||
let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertGreaterThan(lease.matchedTokenCount, 0)
|
||||
XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
|
||||
}
|
||||
|
||||
func testStoredLiveGemmaCacheSupportsSameSystemDifferentUserLCPReuse() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
|
||||
let first = PromptBuilder.build(
|
||||
from: APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "user", content: .text("Respond with one word for cat."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
),
|
||||
modelId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
thinkingEnabled: true
|
||||
)
|
||||
let second = PromptBuilder.build(
|
||||
from: APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "user", content: .text("Respond with one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
),
|
||||
modelId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
thinkingEnabled: true
|
||||
)
|
||||
|
||||
let firstPrepared = try await engine.prepare(first.userInput)
|
||||
let secondPrepared = try await engine.prepare(second.userInput)
|
||||
let handle = try await engine.stream(
|
||||
InferenceEngine.InferenceRequest(
|
||||
input: firstPrepared.lmInput,
|
||||
tokens: firstPrepared.tokens,
|
||||
parameters: GenerateParameters(maxTokens: 2, temperature: 0),
|
||||
cachedKV: nil,
|
||||
cachedTokenCount: 0
|
||||
),
|
||||
cancellation: CancellationToken()
|
||||
)
|
||||
|
||||
_ = await collectEngineOutput(handle.stream)
|
||||
trimCacheToPrompt(handle.workingCache, promptTokenCount: firstPrepared.tokens.count)
|
||||
|
||||
let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
|
||||
cache.store(entryId: UUID(), kvCache: handle.workingCache, cacheKey: firstPrepared.tokens, modelId: "gemma")
|
||||
|
||||
let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertGreaterThan(lease.matchedTokenCount, 0)
|
||||
XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
|
||||
}
|
||||
|
||||
func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
|
||||
let first = PromptBuilder.build(
|
||||
from: APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("System Alpha Unique Tokens"), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
),
|
||||
modelId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
thinkingEnabled: true
|
||||
)
|
||||
let second = PromptBuilder.build(
|
||||
from: APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("Completely Different Beta Markers"), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
),
|
||||
modelId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
thinkingEnabled: true
|
||||
)
|
||||
|
||||
let firstPrepared = try await engine.prepare(first.userInput)
|
||||
let secondPrepared = try await engine.prepare(second.userInput)
|
||||
|
||||
let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: firstPrepared.tokens, modelId: "gemma")
|
||||
|
||||
let lease = cache.lookup(cacheKey: secondPrepared.tokens, modelId: "gemma")
|
||||
|
||||
XCTAssertFalse(lease.isHit)
|
||||
}
|
||||
|
||||
private func localGemmaContainer() async throws -> ModelContainer {
|
||||
try await LocalGemmaFixture.shared.container()
|
||||
}
|
||||
|
||||
private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) {
|
||||
for layer in cache {
|
||||
let excess = layer.offset - promptTokenCount
|
||||
if excess > 0 {
|
||||
XCTAssertTrue(layer.isTrimmable)
|
||||
XCTAssertEqual(layer.trim(excess), excess)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func legacyBuild(
|
||||
from request: APIChatCompletionRequest,
|
||||
modelId: String,
|
||||
|
||||
Reference in New Issue
Block a user