fix: more hardening of cache behaviour and some fixes
This commit is contained in:
@@ -3,6 +3,20 @@ import XCTest
|
||||
@testable import MLX_Server
|
||||
|
||||
final class APIServerRewriteTests: XCTestCase {
|
||||
/// Checks the two read-only endpoints of a freshly created harness:
/// `/health` must answer 200 with the exact body `{"status":"ok"}`, and the models
/// listing must be non-empty, contain the default model's repo id, and carry a
/// `context_window` on every entry.
func testHealthAndModelsEndpointsReturnExpectedPayloads() async throws {
    let server = try await makeHarness()
    defer { server.stop() }

    // Health probe: status code and raw body are both pinned.
    let healthReply = try await sendRawRequest(path: "/health", port: server.port)
    XCTAssertEqual(healthReply.statusCode, 200)
    XCTAssertEqual(healthReply.body, #"{"status":"ok"}"#)

    // Models listing: inspect the decoded entries.
    let listing = try await sendModelsRequest(port: server.port)
    XCTAssertFalse(listing.data.isEmpty)
    XCTAssertTrue(listing.data.contains(where: { entry in entry.id == ModelConfig.default.repoId }))
    XCTAssertTrue(listing.data.allSatisfy { entry in entry.context_window != nil })
}
|
||||
|
||||
func testNonStreamingChatCompletionUsesStatelessServerPathAndCachesPrompt() async throws {
|
||||
let harness = try await makeHarness()
|
||||
defer { harness.stop() }
|
||||
@@ -51,6 +65,306 @@ final class APIServerRewriteTests: XCTestCase {
|
||||
XCTAssertGreaterThan(secondLiveSnapshot.cacheMatchDepth, 0)
|
||||
}
|
||||
|
||||
/// Sends the same non-streaming chat completion twice and then inspects the live counters:
/// the matched prompt-token count must be positive and equal to `promptTokens`, and the
/// rebuilt prompt-token count must be zero.
func testSecondIdenticalRequestIsFullCacheHitWithZeroRebuiltPromptTokens() async throws {
    let server = try await makeHarness()
    defer { server.stop() }

    let question = APIChatMessage(
        role: "user",
        content: .text("Answer with one word: ocean."),
        name: nil,
        tool_calls: nil,
        tool_call_id: nil
    )
    let completionRequest = APIChatCompletionRequest(
        model: "gemma",
        messages: [question],
        temperature: 0,
        top_p: 1,
        max_tokens: 2,
        stream: false,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )

    // Two identical round trips; only the counters after the second one are inspected.
    for _ in 1...2 {
        _ = try await sendChatCompletion(completionRequest, port: server.port)
    }

    let counters = LiveCounters.shared.snapshot()
    XCTAssertGreaterThan(counters.currentCacheMatchedPromptTokens, 0)
    XCTAssertEqual(counters.currentCacheMatchedPromptTokens, counters.promptTokens)
    XCTAssertEqual(counters.currentCacheRebuiltPromptTokens, 0)
}
|
||||
|
||||
/// Streams a one-message conversation, then sends a non-streaming follow-up that repeats the
/// first user message, adds the streamed assistant reply and a new user message.
/// Afterwards the live counters must show both matched and rebuilt prompt tokens.
func testSingleTurnContinuationProducesPartialCacheHit() async throws {
    let server = try await makeHarness()
    defer { server.stop() }

    // Local builders so both requests share identical sampling settings.
    func message(_ role: String, _ text: String) -> APIChatMessage {
        APIChatMessage(role: role, content: .text(text), name: nil, tool_calls: nil, tool_call_id: nil)
    }
    func completionRequest(_ messages: [APIChatMessage], streaming: Bool) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: messages,
            temperature: 0,
            top_p: 1,
            max_tokens: 2,
            stream: streaming,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    // First turn goes through the streaming endpoint.
    let openingRequest = completionRequest(
        [message("user", "Answer in one word: sun.")],
        streaming: true
    )
    let streamed = try await sendStreamingChatCompletion(openingRequest, port: server.port)
    XCTAssertFalse(streamed.content.isEmpty)

    // Second turn replays the conversation so far and appends a new user message.
    let followUpRequest = completionRequest(
        [
            message("user", "Answer in one word: sun."),
            message("assistant", streamed.content),
            message("user", "Answer in one word: moon.")
        ],
        streaming: false
    )
    _ = try await sendChatCompletion(followUpRequest, port: server.port)

    let counters = LiveCounters.shared.snapshot()
    XCTAssertGreaterThan(counters.currentCacheMatchedPromptTokens, 0)
    XCTAssertGreaterThan(counters.currentCacheRebuiltPromptTokens, 0)
}
|
||||
|
||||
/// Sends two requests that share a system message but differ in the user message, collecting
/// the server's debug lookup events. The last recorded event must be a hit for model "gemma"
/// whose matched token count is positive but smaller than its prompt token count.
func testSameSystemPromptDifferentUserMessageReusesSystemPrefix() async throws {
    let server = try await makeHarness()
    defer { server.stop() }

    // Forward every debug lookup event into an actor-backed collector; the hook is cleared on exit.
    let collector = LookupEventCollector()
    APIServer.debugLookupEventHandler = { lookupEvent in
        Task { await collector.record(lookupEvent) }
    }
    defer { APIServer.debugLookupEventHandler = nil }

    // Both requests are identical apart from the user text.
    func request(userText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text(userText), name: nil, tool_calls: nil, tool_call_id: nil)
            ],
            temperature: 0,
            top_p: 1,
            max_tokens: 2,
            stream: false,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }
    let catRequest = request(userText: "Respond with one word for cat.")
    let dogRequest = request(userText: "Respond with one word for dog.")

    _ = try await sendChatCompletion(catRequest, port: server.port)
    _ = try await sendChatCompletion(dogRequest, port: server.port)

    // Events arrive through detached tasks, so wait until at least two have been recorded.
    try await waitUntil(timeoutSeconds: 5) {
        await collector.events().count >= 2
    }

    let recorded = await collector.events()
    let latestLookup = try XCTUnwrap(recorded.last)
    XCTAssertEqual(latestLookup.modelId, "gemma")
    XCTAssertGreaterThan(latestLookup.promptTokenCount, 0)
    XCTAssertTrue(latestLookup.isHit)
    XCTAssertGreaterThan(latestLookup.matchedTokenCount, 0)
    XCTAssertLessThan(latestLookup.matchedTokenCount, latestLookup.promptTokenCount)
}
|
||||
|
||||
/// Sends one request through the server, then tokenizes a second request (same system message,
/// different user message) locally and looks it up in the shared `TokenPrefixCache`.
/// The lookup must be a hit with a positive matched token count.
func testServerStoredCacheIsDirectlyReusableForSameSystemDifferentUserPrompt() async throws {
    let server = try await makeHarness()
    defer { server.stop() }

    // Both requests are identical apart from the user text.
    func request(userText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text(userText), name: nil, tool_calls: nil, tool_call_id: nil)
            ],
            temperature: 0,
            top_p: 1,
            max_tokens: 2,
            stream: false,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    // Only the first request actually goes over HTTP.
    _ = try await sendChatCompletion(request(userText: "Respond with one word for cat."), port: server.port)

    let probeRequest = request(userText: "Respond with one word for dog.")

    // Tokenize the probe request with an engine built on the harness's loaded model container.
    let loadedContainer = await MainActor.run { server.modelManager.modelContainer }
    let engine = InferenceEngine(container: try XCTUnwrap(loadedContainer))
    let builtPrompt = PromptBuilder.build(
        from: probeRequest,
        modelId: ModelConfig.default.repoId,
        thinkingEnabled: Preferences.enableThinking
    )
    let prepared = try await engine.prepare(builtPrompt.userInput)

    let lookupResult = TokenPrefixCache.shared.lookup(cacheKey: prepared.tokens, modelId: "gemma")

    XCTAssertTrue(lookupResult.isHit)
    XCTAssertGreaterThan(lookupResult.matchedTokenCount, 0)
}
|
||||
|
||||
/// Sends two requests with the same user message but different system messages.
/// The shared cache's total hit count must not change across the second request, and the
/// live counters must report zero matched prompt tokens afterwards.
func testDifferentSystemPromptDoesNotProduceFalseCacheHit() async throws {
    let server = try await makeHarness()
    defer { server.stop() }

    // Both requests are identical apart from the system text.
    func request(systemText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text(systemText), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil)
            ],
            temperature: 0,
            top_p: 1,
            max_tokens: 2,
            stream: false,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }
    let alphaRequest = request(systemText: "System Alpha Unique Tokens")
    let betaRequest = request(systemText: "Completely Different Beta Markers")

    _ = try await sendChatCompletion(alphaRequest, port: server.port)
    let cacheBefore = TokenPrefixCache.shared.snapshot()
    _ = try await sendChatCompletion(betaRequest, port: server.port)

    let cacheAfter = TokenPrefixCache.shared.snapshot()
    let counters = LiveCounters.shared.snapshot()
    XCTAssertEqual(cacheAfter.totalHits, cacheBefore.totalHits)
    XCTAssertEqual(counters.currentCacheMatchedPromptTokens, 0)
}
|
||||
|
||||
/// Serves a request, unloads the model, then serves the same request again.
///
/// Expectations after the second request: the model manager reports ready again, the
/// response carries exactly one choice, the shared cache's total hit count is unchanged
/// compared with the snapshot taken after the unload, and the live counters report zero
/// matched prompt tokens.
func testIdleUnloadReloadInvalidatesCacheAndServesFreshRequest() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    // `Preferences.lastModelId` is global state; put the previous value back on exit so this
    // test cannot influence whichever test runs next.
    let previousLastModelId = Preferences.lastModelId
    defer { Preferences.lastModelId = previousLastModelId }
    Preferences.lastModelId = "gemma"

    // The request deliberately omits `model`.
    // NOTE(review): presumably the server then falls back to `Preferences.lastModelId` — confirm.
    let request = APIChatCompletionRequest(
        model: nil,
        messages: [
            APIChatMessage(role: "user", content: .text("Answer in one word: cloud."), name: nil, tool_calls: nil, tool_call_id: nil)
        ],
        temperature: 0,
        top_p: 1,
        max_tokens: 2,
        stream: false,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )

    // Warm-up request; wait until the shared cache reports at least one entry.
    _ = try await sendChatCompletion(request, port: harness.port)
    try await waitUntil(timeoutSeconds: 5) {
        TokenPrefixCache.shared.snapshot().totalEntries > 0
    }

    // Unload on the main actor and confirm the manager no longer reports ready.
    await MainActor.run {
        harness.modelManager.unloadModel()
    }
    let wasReadyAfterUnload = await MainActor.run { harness.modelManager.isReady }
    XCTAssertFalse(wasReadyAfterUnload)

    // Same request again: it must succeed and leave the manager ready.
    let before = TokenPrefixCache.shared.snapshot()
    let response = try await sendChatCompletion(request, port: harness.port)
    XCTAssertEqual(response.choices.count, 1)
    let isReadyAfterReload = await MainActor.run { harness.modelManager.isReady }
    XCTAssertTrue(isReadyAfterReload)

    // No cache hit may have been recorded for the post-reload request.
    let after = TokenPrefixCache.shared.snapshot()
    let live = LiveCounters.shared.snapshot()
    XCTAssertEqual(after.totalHits, before.totalHits)
    XCTAssertEqual(live.currentCacheMatchedPromptTokens, 0)
}
|
||||
|
||||
func testStreamingChatCompletionReusesCacheAcrossThreeProgressivelyLongerTurns() async throws {
|
||||
let harness = try await makeHarness()
|
||||
defer { harness.stop() }
|
||||
@@ -568,6 +882,19 @@ final class APIServerRewriteTests: XCTestCase {
|
||||
return try JSONDecoder().decode(APIChatCompletionResponse.self, from: data)
|
||||
}
|
||||
|
||||
/// Fetches `/v1/models` from the local server, asserts a 200 status, and decodes the body.
///
/// - Parameter port: Port of the server under test.
/// - Returns: The decoded model listing.
/// - Throws: Any error from the HTTP round trip or from JSON decoding.
private func sendModelsRequest(port: UInt16) async throws -> APIModelListResponse {
    let reply = try await sendRawRequest(path: "/v1/models", port: port)
    XCTAssertEqual(reply.statusCode, 200)

    let decoder = JSONDecoder()
    let listing = try decoder.decode(APIModelListResponse.self, from: reply.bodyData)
    return listing
}
|
||||
|
||||
/// Performs a GET against `http://127.0.0.1:<port><path>` and returns the raw result.
///
/// - Parameters:
///   - path: Request path appended verbatim after the port (expected to start with `/`).
///   - port: Port of the server under test.
/// - Returns: The HTTP status code, the body decoded as UTF-8 (empty string when the bytes
///   are not valid UTF-8), and the untouched body bytes.
/// - Throws: An `XCTUnwrap` failure when the URL cannot be formed or the response is not an
///   `HTTPURLResponse`, or any transport error from `URLSession`.
private func sendRawRequest(path: String, port: UInt16) async throws -> (statusCode: Int, body: String, bodyData: Data) {
    // Fail the test with a readable message instead of crashing the process on a malformed path.
    let url = try XCTUnwrap(
        URL(string: "http://127.0.0.1:\(port)\(path)"),
        "Could not build a request URL for path \(path)"
    )
    let (data, response) = try await URLSession.shared.data(from: url)
    let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
    return (httpResponse.statusCode, String(data: data, encoding: .utf8) ?? "", data)
}
|
||||
|
||||
private func sendStreamingChatCompletion(_ request: APIChatCompletionRequest, port: UInt16) async throws -> StreamingResult {
|
||||
let detailed = try await sendStreamingChatCompletionDetailed(request, port: port)
|
||||
return StreamingResult(
|
||||
@@ -695,6 +1022,18 @@ private actor StreamCancellationObserver {
|
||||
}
|
||||
}
|
||||
|
||||
/// Actor that accumulates `APIServer.DebugLookupEvent` values in arrival order so tests can
/// read them back after the requests have completed.
private actor LookupEventCollector {
    /// Events received so far, oldest first.
    private var log = [APIServer.DebugLookupEvent]()

    /// Appends one event to the log.
    func record(_ event: APIServer.DebugLookupEvent) {
        log += [event]
    }

    /// Returns a copy of every event recorded so far.
    func events() -> [APIServer.DebugLookupEvent] {
        return log
    }
}
|
||||
|
||||
private struct DetailedStreamingResult {
|
||||
let events: [StreamingEvent]
|
||||
let sawDone: Bool
|
||||
|
||||
@@ -92,10 +92,211 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
|
||||
}
|
||||
|
||||
/// Tokenizes two prompts that share a system message but differ in the user message, stores
/// the first token sequence in a private `TokenPrefixCache` (with an empty KV cache), and looks
/// up the second. The lookup must hit with a positive match shorter than the stored sequence.
func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
    let gemma = try await localGemmaContainer()
    let engine = InferenceEngine(container: gemma)

    // Both requests are identical apart from the user text; all sampling fields are left nil.
    func request(userText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text(userText), name: nil, tool_calls: nil, tool_call_id: nil),
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let catPrompt = PromptBuilder.build(
        from: request(userText: "Respond with one word for cat."),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )
    let dogPrompt = PromptBuilder.build(
        from: request(userText: "Respond with one word for dog."),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )

    let storedPrepared = try await engine.prepare(catPrompt.userInput)
    let queryPrepared = try await engine.prepare(dogPrompt.userInput)

    // Isolated cache instance, so the shared singleton's state is not involved.
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: storedPrepared.tokens, modelId: "gemma")

    let lookupResult = prefixCache.lookup(cacheKey: queryPrepared.tokens, modelId: "gemma")

    XCTAssertTrue(lookupResult.isHit)
    XCTAssertGreaterThan(lookupResult.matchedTokenCount, 0)
    XCTAssertLessThan(lookupResult.matchedTokenCount, storedPrepared.tokens.count)
}
|
||||
|
||||
/// Like the token-only LCP test, but stores a real KV cache: the first prompt is run through
/// the engine for up to two tokens, the working cache is trimmed back to the prompt length,
/// and that cache is stored before the second prompt's tokens are looked up.
/// The lookup must hit with a positive match shorter than the stored sequence.
func testStoredLiveGemmaCacheSupportsSameSystemDifferentUserLCPReuse() async throws {
    let gemma = try await localGemmaContainer()
    let engine = InferenceEngine(container: gemma)

    // Both requests are identical apart from the user text; all sampling fields are left nil.
    func request(userText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text(userText), name: nil, tool_calls: nil, tool_call_id: nil),
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let catPrompt = PromptBuilder.build(
        from: request(userText: "Respond with one word for cat."),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )
    let dogPrompt = PromptBuilder.build(
        from: request(userText: "Respond with one word for dog."),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )

    let storedPrepared = try await engine.prepare(catPrompt.userInput)
    let queryPrepared = try await engine.prepare(dogPrompt.userInput)

    // Generate from the first prompt with no cached KV state.
    let generation = InferenceEngine.InferenceRequest(
        input: storedPrepared.lmInput,
        tokens: storedPrepared.tokens,
        parameters: GenerateParameters(maxTokens: 2, temperature: 0),
        cachedKV: nil,
        cachedTokenCount: 0
    )
    let run = try await engine.stream(generation, cancellation: CancellationToken())

    // Drain the stream, then drop anything the cache holds beyond the prompt tokens.
    _ = await collectEngineOutput(run.stream)
    trimCacheToPrompt(run.workingCache, promptTokenCount: storedPrepared.tokens.count)

    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: UUID(), kvCache: run.workingCache, cacheKey: storedPrepared.tokens, modelId: "gemma")

    let lookupResult = prefixCache.lookup(cacheKey: queryPrepared.tokens, modelId: "gemma")

    XCTAssertTrue(lookupResult.isHit)
    XCTAssertGreaterThan(lookupResult.matchedTokenCount, 0)
    XCTAssertLessThan(lookupResult.matchedTokenCount, storedPrepared.tokens.count)
}
|
||||
|
||||
/// Tokenizes two prompts with the same user message but different system messages, stores the
/// first token sequence in a private `TokenPrefixCache`, and looks up the second.
///
/// NOTE(review): despite "CanFalseHit" in the name, the assertion requires the lookup to be a
/// miss — consider renaming the test to match what it checks.
func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws {
    let gemma = try await localGemmaContainer()
    let engine = InferenceEngine(container: gemma)

    // Both requests are identical apart from the system text; all sampling fields are left nil.
    func request(systemText: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text(systemText), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text("Answer in one word: tree."), name: nil, tool_calls: nil, tool_call_id: nil),
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let alphaPrompt = PromptBuilder.build(
        from: request(systemText: "System Alpha Unique Tokens"),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )
    let betaPrompt = PromptBuilder.build(
        from: request(systemText: "Completely Different Beta Markers"),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )

    let storedPrepared = try await engine.prepare(alphaPrompt.userInput)
    let queryPrepared = try await engine.prepare(betaPrompt.userInput)

    // Isolated cache instance, so the shared singleton's state is not involved.
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 1_000_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: storedPrepared.tokens, modelId: "gemma")

    let lookupResult = prefixCache.lookup(cacheKey: queryPrepared.tokens, modelId: "gemma")

    XCTAssertFalse(lookupResult.isHit)
}
|
||||
|
||||
/// Returns the model container provided by the shared `LocalGemmaFixture`.
///
/// - Throws: Whatever `LocalGemmaFixture.container()` throws.
private func localGemmaContainer() async throws -> ModelContainer {
    let fixture = LocalGemmaFixture.shared
    return try await fixture.container()
}
|
||||
|
||||
/// Trims every cache layer whose `offset` exceeds `promptTokenCount` back down by the excess.
///
/// For each such layer the helper asserts that the layer reports itself trimmable and that
/// `trim` returns exactly the number of positions requested. Layers at or below the prompt
/// length are left untouched.
///
/// - Parameters:
///   - cache: The per-layer KV caches to trim.
///   - promptTokenCount: Target offset for every layer.
private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) {
    for layer in cache {
        let overshoot = layer.offset - promptTokenCount
        guard overshoot > 0 else { continue }

        XCTAssertTrue(layer.isTrimmable)
        XCTAssertEqual(layer.trim(overshoot), overshoot)
    }
}
|
||||
|
||||
private func legacyBuild(
|
||||
from request: APIChatCompletionRequest,
|
||||
modelId: String,
|
||||
|
||||
@@ -127,4 +127,86 @@ final class TokenPrefixCacheTests: XCTestCase {
|
||||
XCTAssertEqual(snapshot.totalCachedTokens, 0)
|
||||
XCTAssertEqual(snapshot.estimatedBytes, 0)
|
||||
}
|
||||
|
||||
/// Stores the key `[1, 2, 3, 4]` and looks up the shorter key `[1, 2, 3]`.
/// The lookup must hit the stored entry with three matched tokens and be counted as a
/// supersequence hit only (not a prefix or LCP hit).
func testSupersequenceLookupReusesLongerEntryForShorterQuery() {
    let storedId = UUID()
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: storedId, kvCache: [], cacheKey: [1, 2, 3, 4], modelId: "model")

    let result = prefixCache.lookup(cacheKey: [1, 2, 3], modelId: "model")
    let stats = prefixCache.snapshot()

    // Lease contents.
    XCTAssertTrue(result.isHit)
    XCTAssertEqual(result.entryId, storedId)
    XCTAssertEqual(result.matchedTokenCount, 3)

    // Hit classification counters.
    XCTAssertEqual(stats.totalHits, 1)
    XCTAssertEqual(stats.supersequenceHits, 1)
    XCTAssertEqual(stats.prefixHits, 0)
    XCTAssertEqual(stats.lcpHits, 0)
}
|
||||
|
||||
/// Stores the key `[10, 20, 90]` and looks up `[10, 20, 30]`, which shares the first two tokens.
/// The lookup must hit the stored entry with two matched tokens and be counted as an LCP hit
/// only (not a prefix or supersequence hit).
func testLCPLookupReusesSharedPrefixAcrossDivergentSuffixes() {
    let storedId = UUID()
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: storedId, kvCache: [], cacheKey: [10, 20, 90], modelId: "model")

    let result = prefixCache.lookup(cacheKey: [10, 20, 30], modelId: "model")
    let stats = prefixCache.snapshot()

    // Lease contents.
    XCTAssertTrue(result.isHit)
    XCTAssertEqual(result.entryId, storedId)
    XCTAssertEqual(result.matchedTokenCount, 2)

    // Hit classification counters.
    XCTAssertEqual(stats.totalHits, 1)
    XCTAssertEqual(stats.lcpHits, 1)
    XCTAssertEqual(stats.prefixHits, 0)
    XCTAssertEqual(stats.supersequenceHits, 0)
}
|
||||
|
||||
/// Stores the key `[10, 20, 30, 40]` and looks up `[10, 99, 98, 97]`, which shares only the
/// first token. The lookup must miss: zero matched tokens, no hits, exactly one miss recorded.
func testLCPLookupRejectsShallowSharedPrefix() {
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: [10, 20, 30, 40], modelId: "model")

    let result = prefixCache.lookup(cacheKey: [10, 99, 98, 97], modelId: "model")
    let stats = prefixCache.snapshot()

    // Lease contents.
    XCTAssertFalse(result.isHit)
    XCTAssertEqual(result.matchedTokenCount, 0)

    // Counters: a single miss and nothing else.
    XCTAssertEqual(stats.totalHits, 0)
    XCTAssertEqual(stats.totalMisses, 1)
    XCTAssertEqual(stats.lcpHits, 0)
}
|
||||
|
||||
/// Stores three entries — `[7, 8]`, `[7, 8, 9, 10]`, `[7, 8, 11]` — and looks up `[7, 8, 12]`.
/// The lease must point at the `[7, 8]` entry with two matched tokens, and only the prefix-hit
/// counter may be incremented.
func testLookupPrefersPrefixMatchOverSupersequenceAndLCP() {
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })

    // Only the id of the exact-prefix entry is needed for the assertions below.
    let exactPrefixId = UUID()
    prefixCache.store(entryId: exactPrefixId, kvCache: [], cacheKey: [7, 8], modelId: "model")
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: [7, 8, 9, 10], modelId: "model")
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: [7, 8, 11], modelId: "model")

    let result = prefixCache.lookup(cacheKey: [7, 8, 12], modelId: "model")
    let stats = prefixCache.snapshot()

    // Lease contents.
    XCTAssertTrue(result.isHit)
    XCTAssertEqual(result.entryId, exactPrefixId)
    XCTAssertEqual(result.matchedTokenCount, 2)

    // Hit classification counters.
    XCTAssertEqual(stats.prefixHits, 1)
    XCTAssertEqual(stats.supersequenceHits, 0)
    XCTAssertEqual(stats.lcpHits, 0)
}
|
||||
}
|
||||
Reference in New Issue
Block a user