feat: vision properly coverd with tests and completed

This commit is contained in:
2026-03-20 12:57:27 +01:00
parent e59be9df1a
commit 0761254d17
12 changed files with 648 additions and 40 deletions

View File

@@ -19,6 +19,7 @@
2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */; };
2D08769282BD71C170DB0943 /* InferenceStats.swift in Sources */ = {isa = PBXBuildFile; fileRef = E35452B166893B25E765FF70 /* InferenceStats.swift */; };
2E3A02DF9C6A5109E532D5E2 /* ChatDocumentController.swift in Sources */ = {isa = PBXBuildFile; fileRef = D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */; };
3A9DB84947BBBBED06CF9E1E /* TestImageFixtures.swift in Sources */ = {isa = PBXBuildFile; fileRef = 31BD930DEC051408444C30D4 /* TestImageFixtures.swift */; };
4158FA884D981D73288FB74C /* SaveChatCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2E2FCA55CEBEBCED78D9479A /* SaveChatCommands.swift */; };
4CB13DC1AC7A500DDBB443EC /* ChatInputView.swift in Sources */ = {isa = PBXBuildFile; fileRef = E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */; };
4DC033E45880B2948B47DEB1 /* FocusedValues.swift in Sources */ = {isa = PBXBuildFile; fileRef = EF518FEBF3A38E830E3CE1A5 /* FocusedValues.swift */; };
@@ -85,6 +86,7 @@
24E29065DD29C17D20B0400D /* ChatDocumentMigration.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentMigration.swift; sourceTree = "<group>"; };
2DC8C86D397B1FCA08E07CBD /* DownloadModalView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DownloadModalView.swift; sourceTree = "<group>"; };
2E2FCA55CEBEBCED78D9479A /* SaveChatCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SaveChatCommands.swift; sourceTree = "<group>"; };
31BD930DEC051408444C30D4 /* TestImageFixtures.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestImageFixtures.swift; sourceTree = "<group>"; };
3489501F2F8E1BA382347CFA /* CancellationToken.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CancellationToken.swift; sourceTree = "<group>"; };
37FEB592E5E717F817B03151 /* SceneManagementView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneManagementView.swift; sourceTree = "<group>"; };
386CD08DC6338F42460DFBE2 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = Info.plist; sourceTree = "<group>"; };
@@ -189,6 +191,7 @@
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */,
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */,
49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */,
31BD930DEC051408444C30D4 /* TestImageFixtures.swift */,
64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */,
B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */,
);
@@ -400,6 +403,7 @@
8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */,
1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */,
FE4405F66873C75CD6FA19A5 /* StreamingSSEEncoderTests.swift in Sources */,
3A9DB84947BBBBED06CF9E1E /* TestImageFixtures.swift in Sources */,
221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */,
834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */,
);

View File

@@ -323,7 +323,10 @@ final class APIServer {
let preparedInference: InferenceEngine.PreparedInference
do {
let prepareStartedAt = Date()
preparedInference = try await engine.prepare(preparedPrompt.userInput)
preparedInference = try await engine.prepare(
preparedPrompt.userInput,
imageFingerprints: preparedPrompt.imageFingerprints
)
if preparedPrompt.containsImages {
LiveCounters.shared.visionProcessingCompleted(
requestId: requestId,
@@ -336,10 +339,8 @@ final class APIServer {
return
}
// Vision requests stay uncached until image fingerprinting lands.
let cacheKey = preparedInference.hasImages ? nil : preparedInference.tokens
let lease = cacheKey.map { TokenPrefixCache.shared.lookup(cacheKey: $0, modelId: currentModelId) }
?? TokenPrefixCache.CacheLease(entryId: UUID(), kvCache: nil, matchedTokenCount: 0, isHit: false)
let cacheKey = preparedInference.cacheKey
let lease = TokenPrefixCache.shared.lookup(cacheKey: cacheKey, modelId: currentModelId)
Self.debugLookupEventHandler?(
DebugLookupEvent(
@@ -405,8 +406,7 @@ final class APIServer {
)
}
if let cacheKey,
!isShuttingDown,
if !isShuttingDown,
result.succeeded || result.cancelled {
Self.storePromptCache(
streamHandle.workingCache,
@@ -669,7 +669,9 @@ final class APIServer {
cacheKey: [Int],
modelId: String
) {
_ = trimGeneratedTokens(cache, promptTokenCount: promptTokenCount)
guard trimGeneratedTokens(cache, promptTokenCount: promptTokenCount) else {
return
}
TokenPrefixCache.shared.store(
entryId: entryId,
kvCache: cache,

View File

@@ -25,6 +25,7 @@ final class InferenceEngine: @unchecked Sendable {
struct PreparedInference: @unchecked Sendable {
let lmInput: LMInput
let tokens: [Int]
let cacheKey: [Int]
let hasImages: Bool
}
@@ -49,18 +50,178 @@ final class InferenceEngine: @unchecked Sendable {
}
}
func prepare(_ userInput: UserInput) async throws -> PreparedInference {
func prepare(_ userInput: UserInput, imageFingerprints: [UInt64] = []) async throws -> PreparedInference {
nonisolated(unsafe) let input = userInput
let lmInput = try await container.prepare(input: input)
nonisolated(unsafe) let preparedInput = lmInput
let tokenArray: [Int] = await container.perform { _ in
preparedInput.text.tokens.asArray(Int.self)
}
let cacheKey = await buildCacheKey(tokens: tokenArray, imageFingerprints: imageFingerprints)
return PreparedInference(
lmInput: lmInput,
tokens: tokenArray,
cacheKey: cacheKey,
hasImages: userInput.images.count > 0
)
}
private func buildCacheKey(tokens: [Int], imageFingerprints: [UInt64]) async -> [Int] {
guard !imageFingerprints.isEmpty else {
return tokens
}
let modelIdentifier = await container.configuration.name.lowercased()
if modelIdentifier.contains("gemma"),
let key = Self.buildGemmaCacheKey(tokens: tokens, imageFingerprints: imageFingerprints) {
return key
}
return await container.perform { context in
let visionStartTokens = context.tokenizer.encode(text: "<|vision_start|>")
let imagePadTokens = context.tokenizer.encode(text: "<|image_pad|>")
let visionEndTokens = context.tokenizer.encode(text: "<|vision_end|>")
if let key = Self.buildQwenCacheKey(
tokens: tokens,
imageFingerprints: imageFingerprints,
visionStartTokens: visionStartTokens,
imagePadTokens: imagePadTokens,
visionEndTokens: visionEndTokens
) {
return key
}
return Self.buildFallbackVisionCacheKey(tokens: tokens, imageFingerprints: imageFingerprints)
}
}
private static func buildGemmaCacheKey(tokens: [Int], imageFingerprints: [UInt64]) -> [Int]? {
let imageTokenId = 262_144
let totalImageTokenCount = tokens.reduce(into: 0) { count, token in
if token == imageTokenId {
count += 1
}
}
guard totalImageTokenCount > 0,
totalImageTokenCount % imageFingerprints.count == 0
else {
return nil
}
let tokensPerImage = totalImageTokenCount / imageFingerprints.count
guard tokensPerImage > 0 else {
return nil
}
var key: [Int] = []
key.reserveCapacity(tokens.count + imageFingerprints.count * 2)
var currentImageTokenCount = 0
var currentImageIndex = 0
for token in tokens {
key.append(token)
guard token == imageTokenId else { continue }
currentImageTokenCount += 1
if currentImageTokenCount == tokensPerImage,
currentImageIndex < imageFingerprints.count {
key.append(contentsOf: fingerprintSentinels(imageFingerprints[currentImageIndex]))
currentImageIndex += 1
currentImageTokenCount = 0
}
}
guard currentImageIndex == imageFingerprints.count else {
return nil
}
return key
}
private static func buildQwenCacheKey(
tokens: [Int],
imageFingerprints: [UInt64],
visionStartTokens: [Int],
imagePadTokens: [Int],
visionEndTokens: [Int]
) -> [Int]? {
guard !visionStartTokens.isEmpty,
!imagePadTokens.isEmpty,
!visionEndTokens.isEmpty
else {
return nil
}
var key: [Int] = []
key.reserveCapacity(tokens.count + imageFingerprints.count * 2)
var tokenIndex = 0
var imageIndex = 0
while tokenIndex < tokens.count {
if matches(tokens: tokens, sequence: visionStartTokens, at: tokenIndex) {
let imageRegionStart = tokenIndex
var scanIndex = tokenIndex + visionStartTokens.count
var sawImagePad = false
while matches(tokens: tokens, sequence: imagePadTokens, at: scanIndex) {
sawImagePad = true
scanIndex += imagePadTokens.count
}
if sawImagePad,
matches(tokens: tokens, sequence: visionEndTokens, at: scanIndex),
imageIndex < imageFingerprints.count {
let imageRegionEnd = scanIndex + visionEndTokens.count
key.append(contentsOf: tokens[imageRegionStart..<imageRegionEnd])
key.append(contentsOf: fingerprintSentinels(imageFingerprints[imageIndex]))
tokenIndex = imageRegionEnd
imageIndex += 1
continue
}
}
key.append(tokens[tokenIndex])
tokenIndex += 1
}
guard imageIndex == imageFingerprints.count else {
return nil
}
return key
}
private static func buildFallbackVisionCacheKey(tokens: [Int], imageFingerprints: [UInt64]) -> [Int] {
var key: [Int] = []
key.reserveCapacity(tokens.count + imageFingerprints.count * 2)
for fingerprint in imageFingerprints {
key.append(contentsOf: fingerprintSentinels(fingerprint))
}
key.append(contentsOf: tokens)
return key
}
private static func fingerprintSentinels(_ fingerprint: UInt64) -> [Int] {
let upper = Int(UInt32(truncatingIfNeeded: fingerprint >> 32))
let lower = Int(UInt32(truncatingIfNeeded: fingerprint))
return [-(upper + 1), -(lower + 1)]
}
private static func matches(tokens: [Int], sequence: [Int], at start: Int) -> Bool {
guard start + sequence.count <= tokens.count else {
return false
}
for (offset, token) in sequence.enumerated() where tokens[start + offset] != token {
return false
}
return true
}
}

View File

@@ -7,6 +7,7 @@ enum PromptBuilder {
let instructions: String
let chatMessages: [Chat.Message]
let messageSignatures: [UInt64]
let imageFingerprints: [UInt64]
let estimatedBytes: Int
let estimatedPromptTokens: Int
let containsImages: Bool
@@ -36,6 +37,7 @@ enum PromptBuilder {
let isQwen = modelId.lowercased().contains("qwen")
var chatMessages: [Chat.Message] = []
var messageSignatures: [UInt64] = []
var imageFingerprints: [UInt64] = []
var estimatedBytes = instructions.utf8.count
var containsImages = false
@@ -64,6 +66,7 @@ enum PromptBuilder {
for urlString in imageURLs {
if let decoded = ImageDecoder.decode(urlString) {
messageImages.append(decoded.image)
imageFingerprints.append(imageFingerprint(urlString))
messageImageBytes += decoded.estimatedBytes
}
}
@@ -99,6 +102,7 @@ enum PromptBuilder {
instructions: instructions,
chatMessages: chatMessages,
messageSignatures: messageSignatures,
imageFingerprints: imageFingerprints,
estimatedBytes: estimatedBytes,
estimatedPromptTokens: estimatedPromptTokens,
containsImages: containsImages,
@@ -107,6 +111,15 @@ enum PromptBuilder {
)
}
private static func imageFingerprint(_ source: String) -> UInt64 {
var hash: UInt64 = 14_695_981_039_346_656_037
for byte in source.utf8 {
hash ^= UInt64(byte)
hash &*= 1_099_511_628_211
}
return hash
}
private static func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
var hash: UInt64 = 14_695_981_039_346_656_037

View File

@@ -123,12 +123,12 @@ final class TokenPrefixCache: @unchecked Sendable {
}
}
if let match = bestMatch,
var entry = entries[match.entryId] {
if let match = bestMatch,
var entry = entries[match.entryId] {
entry.lastAccessAt = now
entry.hitCount += 1
entries[match.entryId] = entry
removeEntryLocked(entry)
removeEntryLocked(entry, countAsEviction: false)
stats.totalHits += 1
stats.totalPrefixHits += 1
lock.unlock()
@@ -152,7 +152,8 @@ final class TokenPrefixCache: @unchecked Sendable {
return superLease
}
if realTokenCount > 0,
if !walkedFullKey,
realTokenCount > 0,
let lcpLease = findLCPMatchLocked(
below: node,
sharedRealTokenCount: realTokenCount,
@@ -190,7 +191,7 @@ final class TokenPrefixCache: @unchecked Sendable {
if let oldId = node.entryId,
let oldEntry = entries[oldId] {
removeEntryLocked(oldEntry)
removeEntryLocked(oldEntry, countAsEviction: false)
}
node.entryId = entryId
@@ -285,7 +286,7 @@ final class TokenPrefixCache: @unchecked Sendable {
now.timeIntervalSince($0.lastAccessAt) > idleTTL
}
for entry in expired {
removeEntryLocked(entry)
removeEntryLocked(entry, countAsEviction: true)
}
}
@@ -294,11 +295,11 @@ final class TokenPrefixCache: @unchecked Sendable {
guard let victim = entries.values.min(by: evictionOrder) else {
break
}
removeEntryLocked(victim)
removeEntryLocked(victim, countAsEviction: true)
}
}
private func removeEntryLocked(_ entry: CacheEntry) {
private func removeEntryLocked(_ entry: CacheEntry, countAsEviction: Bool) {
guard entries[entry.id] != nil else { return }
var node = root
@@ -321,7 +322,9 @@ final class TokenPrefixCache: @unchecked Sendable {
currentMemoryBytes = max(0, currentMemoryBytes - entry.estimatedBytes)
entries.removeValue(forKey: entry.id)
stats.totalEvictions += 1
if countAsEviction {
stats.totalEvictions += 1
}
}
private func evictionOrder(lhs: CacheEntry, rhs: CacheEntry) -> Bool {
@@ -374,7 +377,7 @@ final class TokenPrefixCache: @unchecked Sendable {
updatedEntry.lastAccessAt = now
updatedEntry.hitCount += 1
entries[entry.id] = updatedEntry
removeEntryLocked(updatedEntry)
removeEntryLocked(updatedEntry, countAsEviction: false)
stats.totalHits += 1
stats.totalSupersequenceHits += 1
@@ -427,7 +430,7 @@ final class TokenPrefixCache: @unchecked Sendable {
updatedEntry.lastAccessAt = now
updatedEntry.hitCount += 1
entries[entry.id] = updatedEntry
removeEntryLocked(updatedEntry)
removeEntryLocked(updatedEntry, countAsEviction: false)
stats.totalHits += 1
stats.totalLCPHits += 1

View File

@@ -120,6 +120,124 @@ final class APIServerRewriteTests: XCTestCase {
XCTAssertGreaterThan(secondLiveSnapshot.cacheMatchDepth, 0)
}
func testVisionPromptCachesAndReusesSameImageRequest() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
let lookups = LookupEventCollector()
APIServer.debugLookupEventHandler = { event in
Task {
await lookups.record(event)
}
}
defer {
APIServer.debugLookupEventHandler = nil
}
let request = visionRequest(dataURI: TestImageFixtures.primaryDataURI, prompt: "Describe this image in one word.")
_ = try await sendChatCompletion(request, port: harness.port)
_ = try await sendChatCompletion(request, port: harness.port)
try await waitUntil(timeoutSeconds: 5) {
let events = await lookups.events()
return events.count >= 2
}
let events = await lookups.events()
let secondLookup = try XCTUnwrap(events.last)
XCTAssertTrue(secondLookup.isHit)
XCTAssertEqual(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
}
func testVisionPromptDifferentImageMissesCache() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
let lookups = LookupEventCollector()
APIServer.debugLookupEventHandler = { event in
Task {
await lookups.record(event)
}
}
defer {
APIServer.debugLookupEventHandler = nil
}
_ = try await sendChatCompletion(visionRequest(dataURI: TestImageFixtures.primaryDataURI, prompt: "Describe this image in one word."), port: harness.port)
_ = try await sendChatCompletion(visionRequest(dataURI: TestImageFixtures.alternateDataURI, prompt: "Describe this image in one word."), port: harness.port)
try await waitUntil(timeoutSeconds: 5) {
let events = await lookups.events()
return events.count >= 2
}
let events = await lookups.events()
let secondLookup = try XCTUnwrap(events.last)
XCTAssertFalse(secondLookup.isHit)
XCTAssertEqual(secondLookup.matchedTokenCount, 0)
}
func testTextOnlyFollowUpReusesEarlierImagePrefix() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
let lookups = LookupEventCollector()
APIServer.debugLookupEventHandler = { event in
Task {
await lookups.record(event)
}
}
defer {
APIServer.debugLookupEventHandler = nil
}
let firstRequest = visionRequest(dataURI: TestImageFixtures.primaryDataURI, prompt: "Describe this image in one short word.")
let firstResponse = try await sendChatCompletion(firstRequest, port: harness.port)
let assistantContent = try XCTUnwrap(firstResponse.choices.first?.message.content)
let followUpRequest = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(
role: "user",
content: .parts([
APIContentPart(type: "text", text: "Describe this image in one short word.", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
tool_call_id: nil
),
APIChatMessage(role: "assistant", content: .text(assistantContent), name: nil, tool_calls: nil, tool_call_id: nil),
APIChatMessage(role: "user", content: .text("Now answer in one word: what color is the sky?"), name: nil, tool_calls: nil, tool_call_id: nil)
],
temperature: 0,
top_p: 1,
max_tokens: 2,
stream: false,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
_ = try await sendChatCompletion(followUpRequest, port: harness.port)
try await waitUntil(timeoutSeconds: 5) {
let events = await lookups.events()
return events.count >= 2
}
let events = await lookups.events()
let secondLookup = try XCTUnwrap(events.last)
XCTAssertTrue(secondLookup.isHit)
XCTAssertGreaterThan(secondLookup.matchedTokenCount, 0)
XCTAssertLessThan(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
}
func testSecondIdenticalRequestIsFullCacheHitWithZeroRebuiltPromptTokens() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
@@ -1216,6 +1334,34 @@ final class APIServerRewriteTests: XCTestCase {
return TestHarness(server: server, modelManager: modelManager, port: port)
}
private func visionRequest(dataURI: String, prompt: String) -> APIChatCompletionRequest {
APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(
role: "user",
content: .parts([
APIContentPart(type: "text", text: prompt, image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: dataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
tool_call_id: nil
)
],
temperature: 0,
top_p: 1,
max_tokens: 2,
stream: false,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
}
private func sendChatCompletion(_ request: APIChatCompletionRequest, port: UInt16) async throws -> APIChatCompletionResponse {
let url = URL(string: "http://127.0.0.1:\(port)/v1/chat/completions")!
var urlRequest = URLRequest(url: url)

View File

@@ -2,17 +2,15 @@ import XCTest
@testable import MLX_Server
final class ImageDecoderTests: XCTestCase {
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
func testDecodeDataURI() {
let image = ImageDecoder.decode("data:image/png;base64,\(onePixelPNGBase64)")
let image = ImageDecoder.decode(TestImageFixtures.primaryDataURI)
XCTAssertNotNil(image)
XCTAssertGreaterThanOrEqual(image?.estimatedBytes ?? 0, 4)
}
func testDecodePlainBase64() {
let image = ImageDecoder.decode(onePixelPNGBase64)
let image = ImageDecoder.decode(TestImageFixtures.primaryPNGBase64)
XCTAssertNotNil(image)
XCTAssertGreaterThanOrEqual(image?.estimatedBytes ?? 0, 4)

View File

@@ -6,8 +6,6 @@ import XCTest
@testable import MLX_Server
final class ModelBackedInferenceValidationTests: XCTestCase {
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
func testPromptBuilderTokenizationMatchesLegacyShapingOnLocalGemma() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
@@ -19,7 +17,7 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
role: "user",
content: .parts([
APIContentPart(type: "text", text: "What is in this image?", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
@@ -92,6 +90,62 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
}
func testVisionCacheKeyChangesWhenImageChangesButTokensStayTheSame() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
let first = PromptBuilder.build(
from: visionRequest(dataURI: TestImageFixtures.primaryDataURI),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: false
)
let second = PromptBuilder.build(
from: visionRequest(dataURI: TestImageFixtures.alternateDataURI),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: false
)
let firstPrepared = try await engine.prepare(first.userInput, imageFingerprints: first.imageFingerprints)
let secondPrepared = try await engine.prepare(second.userInput, imageFingerprints: second.imageFingerprints)
XCTAssertEqual(firstPrepared.tokens, secondPrepared.tokens)
XCTAssertNotEqual(firstPrepared.cacheKey, secondPrepared.cacheKey)
}
func testStoredLiveGemmaVisionCacheReusesSameImagePrompt() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
let prompt = PromptBuilder.build(
from: visionRequest(dataURI: TestImageFixtures.primaryDataURI),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: false
)
let prepared = try await engine.prepare(prompt.userInput, imageFingerprints: prompt.imageFingerprints)
let handle = try await engine.stream(
InferenceEngine.InferenceRequest(
input: prepared.lmInput,
tokens: prepared.tokens,
parameters: GenerateParameters(maxTokens: 2, temperature: 0),
cachedKV: nil,
cachedTokenCount: 0
),
cancellation: CancellationToken()
)
_ = await collectEngineOutput(handle.stream)
trimCacheToPrompt(handle.workingCache, promptTokenCount: prepared.tokens.count)
let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
cache.store(entryId: UUID(), kvCache: handle.workingCache, cacheKey: prepared.cacheKey, modelId: "gemma")
let lease = cache.lookup(cacheKey: prepared.cacheKey, modelId: "gemma")
XCTAssertTrue(lease.isHit)
XCTAssertEqual(lease.matchedTokenCount, prepared.tokens.count)
}
func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
@@ -225,6 +279,71 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
}
func testStoredLiveGemmaCacheSupportsSupersequenceReuseForShorterPrefix() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
let prompt = PromptBuilder.build(
from: APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
APIChatMessage(role: "user", content: .text("Respond with one word for cat, then one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil),
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
),
modelId: "mlx-community/gemma-3-4b-it-4bit",
thinkingEnabled: true
)
let prepared = try await engine.prepare(prompt.userInput)
XCTAssertGreaterThan(prepared.tokens.count, 16)
let handle = try await engine.stream(
InferenceEngine.InferenceRequest(
input: prepared.lmInput,
tokens: prepared.tokens,
parameters: GenerateParameters(maxTokens: 2, temperature: 0),
cachedKV: nil,
cachedTokenCount: 0
),
cancellation: CancellationToken()
)
_ = await collectEngineOutput(handle.stream)
trimCacheToPrompt(handle.workingCache, promptTokenCount: prepared.tokens.count)
let shorterTokenCount = prepared.tokens.count - 16
let shorterPrefix = Array(prepared.tokens.prefix(shorterTokenCount))
let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
cache.store(entryId: UUID(), kvCache: handle.workingCache, cacheKey: prepared.tokens, modelId: "gemma")
let lease = cache.lookup(cacheKey: shorterPrefix, modelId: "gemma")
XCTAssertTrue(lease.isHit)
XCTAssertEqual(lease.matchedTokenCount, shorterTokenCount)
let leasedCache = try XCTUnwrap(lease.kvCache)
XCTAssertFalse(leasedCache.isEmpty)
for layer in leasedCache {
XCTAssertEqual(layer.offset, shorterTokenCount)
}
let snapshot = cache.snapshot()
XCTAssertEqual(snapshot.supersequenceHits, 1)
XCTAssertEqual(snapshot.lcpHits, 0)
XCTAssertEqual(snapshot.prefixHits, 0)
}
func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
@@ -376,6 +495,7 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
instructions: instructions,
chatMessages: chatMessages,
messageSignatures: messageSignatures,
imageFingerprints: imageURLsFingerprintOrder(from: request),
estimatedBytes: estimatedBytes,
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
containsImages: containsImages,
@@ -384,6 +504,48 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
)
}
private func visionRequest(dataURI: String) -> APIChatCompletionRequest {
APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(
role: "user",
content: .parts([
APIContentPart(type: "text", text: "What is in this image?", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: dataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
tool_call_id: nil
)
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
}
private func imageURLsFingerprintOrder(from request: APIChatCompletionRequest) -> [UInt64] {
request.messages
.filter { $0.role != "system" }
.flatMap { $0.content?.imageURLs ?? [] }
.reduce(into: [UInt64]()) { fingerprints, imageURL in
var hash: UInt64 = 14_695_981_039_346_656_037
for byte in imageURL.utf8 {
hash ^= UInt64(byte)
hash &*= 1_099_511_628_211
}
fingerprints.append(hash)
}
}
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
var hash: UInt64 = 14_695_981_039_346_656_037

View File

@@ -3,8 +3,6 @@ import MLXLMCommon
@testable import MLX_Server
final class PromptBuilderTests: XCTestCase {
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
func testBuildMatchesLegacyAPIServerShapingForGemma() {
let toolCall = APIToolCall(
id: "call_weather",
@@ -20,7 +18,7 @@ final class PromptBuilderTests: XCTestCase {
role: "tool",
content: .parts([
APIContentPart(type: "text", text: "{\"temp\":19}", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
@@ -132,7 +130,7 @@ final class PromptBuilderTests: XCTestCase {
role: "tool",
content: .parts([
APIContentPart(type: "text", text: "{\"ok\":true}", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
@@ -156,9 +154,70 @@ final class PromptBuilderTests: XCTestCase {
XCTAssertTrue(prepared.chatMessages[0].content.contains("```tool_output"))
XCTAssertTrue(prepared.containsImages)
XCTAssertEqual(prepared.chatMessages[0].images.count, 1)
XCTAssertEqual(prepared.imageFingerprints.count, 1)
XCTAssertGreaterThan(prepared.estimatedBytes, prepared.chatMessages[0].content.utf8.count)
}
func testBuildHashesRawImageSourcesIntoStableFingerprints() {
let firstRequest = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(
role: "user",
content: .parts([
APIContentPart(type: "text", text: "Describe this.", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
tool_call_id: nil
)
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
let secondRequest = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(
role: "user",
content: .parts([
APIContentPart(type: "text", text: "Describe this.", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.alternateDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
tool_call_id: nil
)
],
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: nil,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
let firstPrepared = PromptBuilder.build(from: firstRequest, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: true)
let secondPrepared = PromptBuilder.build(from: secondRequest, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: true)
XCTAssertEqual(firstPrepared.imageFingerprints.count, 1)
XCTAssertEqual(secondPrepared.imageFingerprints.count, 1)
XCTAssertNotEqual(firstPrepared.imageFingerprints, secondPrepared.imageFingerprints)
}
private func legacyBuild(
from request: APIChatCompletionRequest,
modelId: String,
@@ -237,6 +296,7 @@ final class PromptBuilderTests: XCTestCase {
instructions: instructions,
chatMessages: chatMessages,
messageSignatures: messageSignatures,
imageFingerprints: imageURLsFingerprintOrder(from: request),
estimatedBytes: estimatedBytes,
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
containsImages: containsImages,
@@ -245,6 +305,20 @@ final class PromptBuilderTests: XCTestCase {
)
}
private func imageURLsFingerprintOrder(from request: APIChatCompletionRequest) -> [UInt64] {
request.messages
.filter { $0.role != "system" }
.flatMap { $0.content?.imageURLs ?? [] }
.reduce(into: [UInt64]()) { fingerprints, imageURL in
var hash: UInt64 = 14_695_981_039_346_656_037
for byte in imageURL.utf8 {
hash ^= UInt64(byte)
hash &*= 1_099_511_628_211
}
fingerprints.append(hash)
}
}
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
var hash: UInt64 = 14_695_981_039_346_656_037

View File

@@ -0,0 +1,30 @@
import Foundation
enum TestImageFixtures {
private static let repoRoot: URL = {
URL(fileURLWithPath: #filePath)
.deletingLastPathComponent()
.deletingLastPathComponent()
.deletingLastPathComponent()
}()
private static func loadBase64(named name: String) -> String {
let url = repoRoot
.appendingPathComponent("MLXServer")
.appendingPathComponent("Assets.xcassets")
.appendingPathComponent("AppIcon.appiconset")
.appendingPathComponent(name)
guard let data = try? Data(contentsOf: url) else {
fatalError("Missing image fixture at \(url.path)")
}
return data.base64EncodedString()
}
static let primaryPNGBase64 = loadBase64(named: "icon_16x16.png")
static let alternatePNGBase64 = loadBase64(named: "icon_32x32.png")
static let primaryDataURI = "data:image/png;base64,\(primaryPNGBase64)"
static let alternateDataURI = "data:image/png;base64,\(alternatePNGBase64)"
}

View File

@@ -109,6 +109,21 @@ final class TokenPrefixCacheTests: XCTestCase {
XCTAssertEqual(cache.debugTrieNodeCount(), 1)
}
func testCheckoutHitDoesNotCountAsEviction() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")
let snapshot = cache.snapshot()
XCTAssertTrue(lease.isHit)
XCTAssertEqual(snapshot.totalEvictions, 0)
}
func testSnapshotReportsHitRateAndTokenTotals() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,

View File

@@ -2575,7 +2575,7 @@ Validation note: `PromptBuilder.swift` is now covered by both shaping-parity uni
7. [x] **`APIServer.swift` rewrite** — Wire everything together. Replace ChatSession with InferenceEngine, ConversationSessionCache with TokenPrefixCache, add PromptBuilder and StreamingSSEEncoder.
8. [x] **Delete `ConversationSessionCache.swift`** — Only after APIServer is fully migrated and tested.
Validation note: `APIServer.swift` now routes the API path through `PromptBuilder`, `InferenceEngine`, `TokenPrefixCache`, and `StreamingSSEEncoder`, and the full repository test workflow is green. Image-bearing requests intentionally bypass prefix-cache reuse for now until image fingerprinting is implemented.
Validation note: `APIServer.swift` now routes the API path through `PromptBuilder`, `InferenceEngine`, `TokenPrefixCache`, and `StreamingSSEEncoder`, and the full repository test workflow is green. Image-bearing requests now participate in prefix-cache reuse via image-aware cache keys built from prompt tokens plus stable image fingerprints, preventing false hits across different images while enabling same-image reuse.
### Phase 4: Statistics & Monitoring
@@ -2583,13 +2583,13 @@ Validation note: `APIServer.swift` now routes the API path through `PromptBuilde
10. [x] **InferenceStats upgrade** — Add new snapshot fields, new time-series histories. Switch from ConversationSessionCache.snapshot() to TokenPrefixCache.snapshot().
11. [x] **MonitorView upgrade** — Add TTFT chart, prefill speed chart, cache match quality chart, cache memory budget chart. Update cache card and cumulative tiles. Add vision encoder time chart (conditional on VL model). Replace session list with cache entry list.
Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly and `MonitorView.swift` now surfaces TTFT, prefill speed, cache match depth, cache memory pressure, disconnect totals, and vision prepare time from `LiveCounters`. Match-type hit breakdown is still open because it depends on the advanced cache matching work in Phase 5.
Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly and `MonitorView.swift` now surfaces TTFT, prefill speed, cache match depth, cache memory pressure, disconnect totals, vision prepare time, and the prefix/supersequence/LCP hit breakdown from `LiveCounters` and `TokenPrefixCache`.
### Phase 5: Advanced Cache Matching
12. **Supersequence matching**Add `findSupersequenceMatchLocked()` and `trimCacheByOffset()` to `TokenPrefixCache`. Extend `lookup()` with subtree scan after prefix walk. Test: store a long entry, look up a shorter prefix of it → cache hit with trimmed KV.
13. **LCP matching**Add `findLCPMatchLocked()` to `TokenPrefixCache`. Extend `lookup()` with sibling-subtree scan at divergence point. Test: store `[SYS, A, B, X]`, look up `[SYS, A, B, Y]` → cache hit covering `[SYS, A, B]`, remaining `[Y]`.
14. **Match stats**Add `totalPrefixHits`, `totalSupersequenceHits`, `totalLCPHits` to stats and snapshot. Surface hit breakdown in MonitorView cache card.
12. [x] **Supersequence matching**`TokenPrefixCache` now includes `findSupersequenceMatchLocked()` and `trimCacheByOffset()`, and `lookup()` performs a subtree scan after a full-key walk with no direct entry. Coverage includes both logical cache tests and a model-backed test that verifies the leased KV cache is trimmed to the shorter prefix length.
13. [x] **LCP matching**`TokenPrefixCache` now includes `findLCPMatchLocked()`, and `lookup()` attempts LCP reuse only on actual divergence. Coverage includes direct cache tests for divergent suffix reuse and shallow-prefix rejection, plus model-backed same-system/different-user reuse validation.
14. [x] **Match stats**`TokenPrefixCache`, `InferenceStats`, and `MonitorView` now track and surface `prefixHits`, `supersequenceHits`, and `lcpHits` in the cache snapshot and monitor cache card.
### Phase 6: KV Cache Quantization
@@ -2654,13 +2654,13 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly
- [ ] Multiple images in a single message → all images processed correctly
- [ ] Image + text in same message → both contribute to response
- [ ] Images in earlier messages, text-only follow-up → cache hit (vision encoder skipped)
- [ ] Same conversation, same images → cache hit on subsequent requests
- [ ] Same conversation, different image swapped → cache miss, fresh vision processing
- [x] Same conversation, same images → cache hit on subsequent requests
- [x] Same conversation, different image swapped → cache miss, fresh vision processing
- [ ] Text-only conversation on a VL model → no vision overhead, normal cache behavior
- [ ] Large images (4K+) → properly resized by UserInputProcessor, no OOM
- [ ] Base64 data-URI images decoded correctly (PNG, JPEG)
- [ ] Image fingerprinting: same image bytes → same fingerprint → cache hit
- [ ] Image fingerprinting: different images → different fingerprints → cache miss
- [x] Image fingerprinting: same image bytes → same fingerprint → cache hit
- [x] Image fingerprinting: different images → different fingerprints → cache miss
- [ ] Non-vision model rejects image inputs with clear error message
- [ ] Mixed: image in user msg 1, assistant response, text-only user msg 2 → cache covers all of msg 1 + response