feat: vision properly covered with tests and completed

This commit is contained in:
2026-03-20 12:57:27 +01:00
parent e59be9df1a
commit 0761254d17
12 changed files with 648 additions and 40 deletions

View File

@@ -6,8 +6,6 @@ import XCTest
@testable import MLX_Server
final class ModelBackedInferenceValidationTests: XCTestCase {
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
func testPromptBuilderTokenizationMatchesLegacyShapingOnLocalGemma() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
@@ -19,7 +17,7 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
role: "user",
content: .parts([
APIContentPart(type: "text", text: "What is in this image?", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
@@ -92,6 +90,62 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
}
/// Prepares two vision prompts that differ only in the image data URI and checks that
/// their token sequences are equal while their cache keys are not.
func testVisionCacheKeyChangesWhenImageChangesButTokensStayTheSame() async throws {
    let gemmaContainer = try await localGemmaContainer()
    let inferenceEngine = InferenceEngine(container: gemmaContainer)

    // Both prompts share the same text part, model id and thinking flag; only the
    // image payload passed in here varies.
    let makePrompt = { (dataURI: String) in
        PromptBuilder.build(
            from: self.visionRequest(dataURI: dataURI),
            modelId: "mlx-community/gemma-3-4b-it-4bit",
            thinkingEnabled: false
        )
    }
    let primaryPrompt = makePrompt(TestImageFixtures.primaryDataURI)
    let alternatePrompt = makePrompt(TestImageFixtures.alternateDataURI)

    let primaryPrepared = try await inferenceEngine.prepare(
        primaryPrompt.userInput,
        imageFingerprints: primaryPrompt.imageFingerprints
    )
    let alternatePrepared = try await inferenceEngine.prepare(
        alternatePrompt.userInput,
        imageFingerprints: alternatePrompt.imageFingerprints
    )

    // Same tokens, different images: the cache key must still tell them apart.
    XCTAssertEqual(primaryPrepared.tokens, alternatePrepared.tokens)
    XCTAssertNotEqual(primaryPrepared.cacheKey, alternatePrepared.cacheKey)
}
/// Runs a short generation for a vision prompt, stores the resulting working cache under
/// the prompt's cache key, and checks that looking up the same key is a hit that matches
/// the full prompt token count.
func testStoredLiveGemmaVisionCacheReusesSameImagePrompt() async throws {
    let gemmaContainer = try await localGemmaContainer()
    let inferenceEngine = InferenceEngine(container: gemmaContainer)

    let visionPrompt = PromptBuilder.build(
        from: visionRequest(dataURI: TestImageFixtures.primaryDataURI),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: false
    )
    let preparedPrompt = try await inferenceEngine.prepare(
        visionPrompt.userInput,
        imageFingerprints: visionPrompt.imageFingerprints
    )

    // Cold start: no cached KV state is handed to the engine.
    let generationParameters = GenerateParameters(maxTokens: 2, temperature: 0)
    let inferenceRequest = InferenceEngine.InferenceRequest(
        input: preparedPrompt.lmInput,
        tokens: preparedPrompt.tokens,
        parameters: generationParameters,
        cachedKV: nil,
        cachedTokenCount: 0
    )
    let streamHandle = try await inferenceEngine.stream(inferenceRequest, cancellation: CancellationToken())

    // Drain the stream before touching the working cache; the output itself is not inspected.
    _ = await collectEngineOutput(streamHandle.stream)
    // NOTE(review): presumably trims the cache back to the prompt length so generated
    // tokens are not stored — confirm against trimCacheToPrompt.
    trimCacheToPrompt(streamHandle.workingCache, promptTokenCount: preparedPrompt.tokens.count)

    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(
        entryId: UUID(),
        kvCache: streamHandle.workingCache,
        cacheKey: preparedPrompt.cacheKey,
        modelId: "gemma"
    )

    let lease = prefixCache.lookup(cacheKey: preparedPrompt.cacheKey, modelId: "gemma")
    XCTAssertTrue(lease.isHit)
    XCTAssertEqual(lease.matchedTokenCount, preparedPrompt.tokens.count)
}
func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
@@ -225,6 +279,71 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
}
/// Stores a working cache keyed by the full prompt tokens, then looks it up with a key
/// that drops the trailing tokens of that prompt. Expects a hit whose matched length and
/// per-layer offsets equal the shorter length, recorded as exactly one supersequence hit.
func testStoredLiveGemmaCacheSupportsSupersequenceReuseForShorterPrefix() async throws {
    // Number of trailing prompt tokens removed to form the shorter lookup key.
    let droppedSuffixLength = 16

    let container = try await localGemmaContainer()
    let engine = InferenceEngine(container: container)

    let prompt = PromptBuilder.build(
        from: APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text("Respond with one word for cat, then one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil),
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        ),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )

    let prepared = try await engine.prepare(prompt.userInput)

    // A failing XCTAssert does not halt the test by default, and `prefix(_:)` traps on a
    // negative length. Stop here explicitly so a too-short prompt is reported as a test
    // failure instead of crashing the test process.
    guard prepared.tokens.count > droppedSuffixLength else {
        XCTFail("Expected more than \(droppedSuffixLength) prompt tokens, got \(prepared.tokens.count)")
        return
    }

    // Cold start: no cached KV state is handed to the engine.
    let handle = try await engine.stream(
        InferenceEngine.InferenceRequest(
            input: prepared.lmInput,
            tokens: prepared.tokens,
            parameters: GenerateParameters(maxTokens: 2, temperature: 0),
            cachedKV: nil,
            cachedTokenCount: 0
        ),
        cancellation: CancellationToken()
    )

    // Drain the stream before touching the working cache; the output itself is not inspected.
    _ = await collectEngineOutput(handle.stream)
    // NOTE(review): presumably trims the cache back to the prompt length so generated
    // tokens are not stored — confirm against trimCacheToPrompt.
    trimCacheToPrompt(handle.workingCache, promptTokenCount: prepared.tokens.count)

    let shorterTokenCount = prepared.tokens.count - droppedSuffixLength
    let shorterPrefix = Array(prepared.tokens.prefix(shorterTokenCount))

    // The stored key is the full token sequence; the lookup key is a strict prefix of it.
    let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
    cache.store(entryId: UUID(), kvCache: handle.workingCache, cacheKey: prepared.tokens, modelId: "gemma")

    let lease = cache.lookup(cacheKey: shorterPrefix, modelId: "gemma")
    XCTAssertTrue(lease.isHit)
    XCTAssertEqual(lease.matchedTokenCount, shorterTokenCount)

    // Every leased layer must report the shorter length as its offset.
    let leasedCache = try XCTUnwrap(lease.kvCache)
    XCTAssertFalse(leasedCache.isEmpty)
    for layer in leasedCache {
        XCTAssertEqual(layer.offset, shorterTokenCount)
    }

    // The hit must be counted as a supersequence hit and nothing else.
    let snapshot = cache.snapshot()
    XCTAssertEqual(snapshot.supersequenceHits, 1)
    XCTAssertEqual(snapshot.lcpHits, 0)
    XCTAssertEqual(snapshot.prefixHits, 0)
}
func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
@@ -376,6 +495,7 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
instructions: instructions,
chatMessages: chatMessages,
messageSignatures: messageSignatures,
imageFingerprints: imageURLsFingerprintOrder(from: request),
estimatedBytes: estimatedBytes,
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
containsImages: containsImages,
@@ -384,6 +504,48 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
)
}
/// Builds a chat completion request for model "gemma" containing a single user message
/// made of a fixed text question followed by one image part.
///
/// - Parameter dataURI: The URL string placed in the image part of the message.
/// - Returns: A request whose optional fields are all `nil`.
private func visionRequest(dataURI: String) -> APIChatCompletionRequest {
    let questionPart = APIContentPart(type: "text", text: "What is in this image?", image_url: nil)
    let imagePart = APIContentPart(
        type: "image_url",
        text: nil,
        image_url: APIImageURL(url: dataURI, detail: nil)
    )
    let userMessage = APIChatMessage(
        role: "user",
        content: .parts([questionPart, imagePart]),
        name: nil,
        tool_calls: nil,
        tool_call_id: nil
    )

    return APIChatCompletionRequest(
        model: "gemma",
        messages: [userMessage],
        temperature: nil,
        top_p: nil,
        max_tokens: nil,
        stream: nil,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )
}
/// Computes one 64-bit hash per image URL found in the request's non-system messages,
/// in message order and then in per-message image order.
///
/// Each hash starts from a fixed basis and, for every UTF-8 byte of the URL string,
/// XORs the byte in and multiplies by a fixed prime with wrapping arithmetic (the
/// FNV-1a 64-bit scheme).
///
/// - Parameter request: The request whose messages are scanned for image URLs.
/// - Returns: The hashes in encounter order; empty when no non-system message has images.
private func imageURLsFingerprintOrder(from request: APIChatCompletionRequest) -> [UInt64] {
    let offsetBasis: UInt64 = 14_695_981_039_346_656_037
    let prime: UInt64 = 1_099_511_628_211

    var fingerprints: [UInt64] = []
    for message in request.messages where message.role != "system" {
        // Messages without content contribute no images.
        for imageURL in message.content?.imageURLs ?? [] {
            let digest = imageURL.utf8.reduce(offsetBasis) { partial, byte in
                (partial ^ UInt64(byte)) &* prime
            }
            fingerprints.append(digest)
        }
    }
    return fingerprints
}
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
var hash: UInt64 = 14_695_981_039_346_656_037