feat: vision properly covered with tests and completed
This commit is contained in:
@@ -120,6 +120,124 @@ final class APIServerRewriteTests: XCTestCase {
|
||||
XCTAssertGreaterThan(secondLiveSnapshot.cacheMatchDepth, 0)
|
||||
}
|
||||
|
||||
/// Sends the same vision request twice and expects the second prompt-cache
/// lookup to be a full hit, i.e. every prompt token is matched.
func testVisionPromptCachesAndReusesSameImageRequest() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    // Recording is awaited, so each lookup event is forwarded to the collector
    // from its own unstructured Task.
    let lookups = LookupEventCollector()
    APIServer.debugLookupEventHandler = { event in
        Task {
            await lookups.record(event)
        }
    }
    defer {
        APIServer.debugLookupEventHandler = nil
    }

    let request = visionRequest(dataURI: TestImageFixtures.primaryDataURI, prompt: "Describe this image in one word.")

    _ = try await sendChatCompletion(request, port: harness.port)

    // Unstructured Tasks carry no ordering guarantee. Let the first event land
    // before issuing the second request so `events.last` below is reliably the
    // lookup that belongs to the second request.
    try await waitUntil(timeoutSeconds: 5) {
        let events = await lookups.events()
        return events.count >= 1
    }

    _ = try await sendChatCompletion(request, port: harness.port)

    try await waitUntil(timeoutSeconds: 5) {
        let events = await lookups.events()
        return events.count >= 2
    }

    let events = await lookups.events()
    let secondLookup = try XCTUnwrap(events.last)
    XCTAssertTrue(secondLookup.isHit)
    XCTAssertEqual(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
}
|
||||
|
||||
/// Sends the same text prompt with two different images and expects the second
/// prompt-cache lookup to miss with zero matched tokens.
func testVisionPromptDifferentImageMissesCache() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    // Recording is awaited, so each lookup event is forwarded to the collector
    // from its own unstructured Task.
    let lookups = LookupEventCollector()
    APIServer.debugLookupEventHandler = { event in
        Task {
            await lookups.record(event)
        }
    }
    defer {
        APIServer.debugLookupEventHandler = nil
    }

    let prompt = "Describe this image in one word."

    _ = try await sendChatCompletion(
        visionRequest(dataURI: TestImageFixtures.primaryDataURI, prompt: prompt),
        port: harness.port
    )

    // Unstructured Tasks carry no ordering guarantee. Let the first event land
    // before issuing the second request so `events.last` below is reliably the
    // lookup that belongs to the alternate-image request.
    try await waitUntil(timeoutSeconds: 5) {
        let events = await lookups.events()
        return events.count >= 1
    }

    _ = try await sendChatCompletion(
        visionRequest(dataURI: TestImageFixtures.alternateDataURI, prompt: prompt),
        port: harness.port
    )

    try await waitUntil(timeoutSeconds: 5) {
        let events = await lookups.events()
        return events.count >= 2
    }

    let events = await lookups.events()
    let secondLookup = try XCTUnwrap(events.last)
    XCTAssertFalse(secondLookup.isHit)
    XCTAssertEqual(secondLookup.matchedTokenCount, 0)
}
|
||||
|
||||
/// Replays an image turn plus the assistant's reply as history, appends a
/// text-only user turn, and expects the second lookup to be a partial hit:
/// some tokens matched, but fewer than the whole prompt.
func testTextOnlyFollowUpReusesEarlierImagePrefix() async throws {
    let harness = try await makeHarness()
    defer { harness.stop() }

    // Recording is awaited, so each lookup event is forwarded to the collector
    // from its own unstructured Task.
    let lookups = LookupEventCollector()
    APIServer.debugLookupEventHandler = { event in
        Task {
            await lookups.record(event)
        }
    }
    defer {
        APIServer.debugLookupEventHandler = nil
    }

    // Single source of truth for the first-turn text: the follow-up must repeat
    // it verbatim for the earlier prompt to be a prefix of the new one.
    let imagePrompt = "Describe this image in one short word."

    let firstRequest = visionRequest(dataURI: TestImageFixtures.primaryDataURI, prompt: imagePrompt)
    let firstResponse = try await sendChatCompletion(firstRequest, port: harness.port)
    let assistantContent = try XCTUnwrap(firstResponse.choices.first?.message.content)

    // Unstructured Tasks carry no ordering guarantee. Let the first event land
    // before issuing the follow-up so `events.last` below is reliably the
    // follow-up's lookup.
    try await waitUntil(timeoutSeconds: 5) {
        let events = await lookups.events()
        return events.count >= 1
    }

    let followUpRequest = APIChatCompletionRequest(
        model: "gemma",
        messages: [
            APIChatMessage(
                role: "user",
                content: .parts([
                    APIContentPart(type: "text", text: imagePrompt, image_url: nil),
                    APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
                ]),
                name: nil,
                tool_calls: nil,
                tool_call_id: nil
            ),
            APIChatMessage(role: "assistant", content: .text(assistantContent), name: nil, tool_calls: nil, tool_call_id: nil),
            APIChatMessage(role: "user", content: .text("Now answer in one word: what color is the sky?"), name: nil, tool_calls: nil, tool_call_id: nil)
        ],
        temperature: 0,
        top_p: 1,
        max_tokens: 2,
        stream: false,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )

    _ = try await sendChatCompletion(followUpRequest, port: harness.port)

    try await waitUntil(timeoutSeconds: 5) {
        let events = await lookups.events()
        return events.count >= 2
    }

    let events = await lookups.events()
    let secondLookup = try XCTUnwrap(events.last)
    XCTAssertTrue(secondLookup.isHit)
    XCTAssertGreaterThan(secondLookup.matchedTokenCount, 0)
    XCTAssertLessThan(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
}
|
||||
|
||||
func testSecondIdenticalRequestIsFullCacheHitWithZeroRebuiltPromptTokens() async throws {
|
||||
let harness = try await makeHarness()
|
||||
defer { harness.stop() }
|
||||
@@ -1216,6 +1334,34 @@ final class APIServerRewriteTests: XCTestCase {
|
||||
return TestHarness(server: server, modelManager: modelManager, port: port)
|
||||
}
|
||||
|
||||
/// Builds a non-streaming, single-turn request for the "gemma" model whose
/// only message is a user turn carrying `prompt` as text followed by one image.
///
/// - Parameters:
///   - dataURI: Image source placed in the `image_url` content part.
///   - prompt: Text placed in the `text` content part.
/// - Returns: A request with temperature 0, top_p 1 and a 2-token output cap.
private func visionRequest(dataURI: String, prompt: String) -> APIChatCompletionRequest {
    let textPart = APIContentPart(type: "text", text: prompt, image_url: nil)
    let imagePart = APIContentPart(
        type: "image_url",
        text: nil,
        image_url: APIImageURL(url: dataURI, detail: nil)
    )
    let userTurn = APIChatMessage(
        role: "user",
        content: .parts([textPart, imagePart]),
        name: nil,
        tool_calls: nil,
        tool_call_id: nil
    )

    return APIChatCompletionRequest(
        model: "gemma",
        messages: [userTurn],
        temperature: 0,
        top_p: 1,
        max_tokens: 2,
        stream: false,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )
}
|
||||
|
||||
private func sendChatCompletion(_ request: APIChatCompletionRequest, port: UInt16) async throws -> APIChatCompletionResponse {
|
||||
let url = URL(string: "http://127.0.0.1:\(port)/v1/chat/completions")!
|
||||
var urlRequest = URLRequest(url: url)
|
||||
|
||||
@@ -2,17 +2,15 @@ import XCTest
|
||||
@testable import MLX_Server
|
||||
|
||||
final class ImageDecoderTests: XCTestCase {
|
||||
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
|
||||
|
||||
func testDecodeDataURI() {
|
||||
let image = ImageDecoder.decode("data:image/png;base64,\(onePixelPNGBase64)")
|
||||
let image = ImageDecoder.decode(TestImageFixtures.primaryDataURI)
|
||||
|
||||
XCTAssertNotNil(image)
|
||||
XCTAssertGreaterThanOrEqual(image?.estimatedBytes ?? 0, 4)
|
||||
}
|
||||
|
||||
func testDecodePlainBase64() {
|
||||
let image = ImageDecoder.decode(onePixelPNGBase64)
|
||||
let image = ImageDecoder.decode(TestImageFixtures.primaryPNGBase64)
|
||||
|
||||
XCTAssertNotNil(image)
|
||||
XCTAssertGreaterThanOrEqual(image?.estimatedBytes ?? 0, 4)
|
||||
|
||||
@@ -6,8 +6,6 @@ import XCTest
|
||||
@testable import MLX_Server
|
||||
|
||||
final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
|
||||
|
||||
func testPromptBuilderTokenizationMatchesLegacyShapingOnLocalGemma() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
@@ -19,7 +17,7 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
role: "user",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "What is in this image?", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
@@ -92,6 +90,62 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
|
||||
}
|
||||
|
||||
/// Prepares the same text prompt with two different images and checks that the
/// token sequences are equal while the cache keys differ.
func testVisionCacheKeyChangesWhenImageChangesButTokensStayTheSame() async throws {
    let engine = try await InferenceEngine(container: localGemmaContainer())
    let modelId = "mlx-community/gemma-3-4b-it-4bit"

    let primaryPrompt = PromptBuilder.build(
        from: visionRequest(dataURI: TestImageFixtures.primaryDataURI),
        modelId: modelId,
        thinkingEnabled: false
    )
    let alternatePrompt = PromptBuilder.build(
        from: visionRequest(dataURI: TestImageFixtures.alternateDataURI),
        modelId: modelId,
        thinkingEnabled: false
    )

    let primaryPrepared = try await engine.prepare(
        primaryPrompt.userInput,
        imageFingerprints: primaryPrompt.imageFingerprints
    )
    let alternatePrepared = try await engine.prepare(
        alternatePrompt.userInput,
        imageFingerprints: alternatePrompt.imageFingerprints
    )

    // Same tokens, different key: the image identity must be what separates the two.
    XCTAssertEqual(primaryPrepared.tokens, alternatePrepared.tokens)
    XCTAssertNotEqual(primaryPrepared.cacheKey, alternatePrepared.cacheKey)
}
|
||||
|
||||
/// Runs a short generation for a vision prompt, stores the resulting working
/// cache under the prepared cache key, and checks that looking up the same key
/// is a hit covering every prompt token.
func testStoredLiveGemmaVisionCacheReusesSameImagePrompt() async throws {
    let engine = try await InferenceEngine(container: localGemmaContainer())

    let builtPrompt = PromptBuilder.build(
        from: visionRequest(dataURI: TestImageFixtures.primaryDataURI),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: false
    )
    let prepared = try await engine.prepare(
        builtPrompt.userInput,
        imageFingerprints: builtPrompt.imageFingerprints
    )

    // Cold run: no cached KV is supplied.
    let coldRequest = InferenceEngine.InferenceRequest(
        input: prepared.lmInput,
        tokens: prepared.tokens,
        parameters: GenerateParameters(maxTokens: 2, temperature: 0),
        cachedKV: nil,
        cachedTokenCount: 0
    )
    let handle = try await engine.stream(coldRequest, cancellation: CancellationToken())

    _ = await collectEngineOutput(handle.stream)
    trimCacheToPrompt(handle.workingCache, promptTokenCount: prepared.tokens.count)

    let prefixCache = TokenPrefixCache(
        memoryBudgetBytes: 1_000_000_000,
        estimateBytesProvider: { _ in 1_024 }
    )
    prefixCache.store(
        entryId: UUID(),
        kvCache: handle.workingCache,
        cacheKey: prepared.cacheKey,
        modelId: "gemma"
    )

    let lease = prefixCache.lookup(cacheKey: prepared.cacheKey, modelId: "gemma")

    XCTAssertTrue(lease.isHit)
    XCTAssertEqual(lease.matchedTokenCount, prepared.tokens.count)
}
|
||||
|
||||
func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
@@ -225,6 +279,71 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
XCTAssertLessThan(lease.matchedTokenCount, firstPrepared.tokens.count)
|
||||
}
|
||||
|
||||
/// Stores a cache for a full prompt, then looks up a strictly shorter prefix of
/// it. Expects a hit matching exactly the shorter length, every leased layer
/// positioned at that length, and the hit counted as a supersequence hit only.
func testStoredLiveGemmaCacheSupportsSupersequenceReuseForShorterPrefix() async throws {
    // Number of trailing tokens removed to form the shorter lookup key.
    let droppedTokenCount = 16

    let container = try await localGemmaContainer()
    let engine = InferenceEngine(container: container)

    let prompt = PromptBuilder.build(
        from: APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(role: "system", content: .text("You are terse and literal."), name: nil, tool_calls: nil, tool_call_id: nil),
                APIChatMessage(role: "user", content: .text("Respond with one word for cat, then one word for dog."), name: nil, tool_calls: nil, tool_call_id: nil),
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        ),
        modelId: "mlx-community/gemma-3-4b-it-4bit",
        thinkingEnabled: true
    )

    let prepared = try await engine.prepare(prompt.userInput)

    // An XCTAssert does not stop the test, and `prefix(_:)` traps on a negative
    // length, so bail out explicitly when the prompt is too short to shorten.
    guard prepared.tokens.count > droppedTokenCount else {
        XCTFail("Prompt has \(prepared.tokens.count) tokens; need more than \(droppedTokenCount) to build a shorter prefix")
        return
    }

    let handle = try await engine.stream(
        InferenceEngine.InferenceRequest(
            input: prepared.lmInput,
            tokens: prepared.tokens,
            parameters: GenerateParameters(maxTokens: 2, temperature: 0),
            cachedKV: nil,
            cachedTokenCount: 0
        ),
        cancellation: CancellationToken()
    )

    _ = await collectEngineOutput(handle.stream)
    trimCacheToPrompt(handle.workingCache, promptTokenCount: prepared.tokens.count)

    let shorterTokenCount = prepared.tokens.count - droppedTokenCount
    let shorterPrefix = Array(prepared.tokens.prefix(shorterTokenCount))

    let cache = TokenPrefixCache(memoryBudgetBytes: 1_000_000_000, estimateBytesProvider: { _ in 1_024 })
    cache.store(entryId: UUID(), kvCache: handle.workingCache, cacheKey: prepared.tokens, modelId: "gemma")

    let lease = cache.lookup(cacheKey: shorterPrefix, modelId: "gemma")

    XCTAssertTrue(lease.isHit)
    XCTAssertEqual(lease.matchedTokenCount, shorterTokenCount)
    let leasedCache = try XCTUnwrap(lease.kvCache)
    XCTAssertFalse(leasedCache.isEmpty)
    for layer in leasedCache {
        XCTAssertEqual(layer.offset, shorterTokenCount)
    }

    // The hit must be attributed to the supersequence path and to no other.
    let snapshot = cache.snapshot()
    XCTAssertEqual(snapshot.supersequenceHits, 1)
    XCTAssertEqual(snapshot.lcpHits, 0)
    XCTAssertEqual(snapshot.prefixHits, 0)
}
|
||||
|
||||
func testTokenPrefixCacheCanFalseHitDifferentSystemPromptsOnRawGemmaTokens() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
@@ -376,6 +495,7 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
instructions: instructions,
|
||||
chatMessages: chatMessages,
|
||||
messageSignatures: messageSignatures,
|
||||
imageFingerprints: imageURLsFingerprintOrder(from: request),
|
||||
estimatedBytes: estimatedBytes,
|
||||
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
|
||||
containsImages: containsImages,
|
||||
@@ -384,6 +504,48 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
)
|
||||
}
|
||||
|
||||
/// Builds a single-turn request for the "gemma" model asking
/// "What is in this image?" about the image at `dataURI`, leaving every
/// sampling and tooling option unset.
private func visionRequest(dataURI: String) -> APIChatCompletionRequest {
    let question = APIContentPart(type: "text", text: "What is in this image?", image_url: nil)
    let picture = APIContentPart(
        type: "image_url",
        text: nil,
        image_url: APIImageURL(url: dataURI, detail: nil)
    )
    let userTurn = APIChatMessage(
        role: "user",
        content: .parts([question, picture]),
        name: nil,
        tool_calls: nil,
        tool_call_id: nil
    )

    return APIChatCompletionRequest(
        model: "gemma",
        messages: [userTurn],
        temperature: nil,
        top_p: nil,
        max_tokens: nil,
        stream: nil,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )
}
|
||||
|
||||
/// Returns one 64-bit FNV-1a hash per image URL found in the request's
/// non-system messages, in the order the URLs appear.
///
/// Each hash is computed over the UTF-8 bytes of the URL string.
private func imageURLsFingerprintOrder(from request: APIChatCompletionRequest) -> [UInt64] {
    let offsetBasis: UInt64 = 14_695_981_039_346_656_037
    let prime: UInt64 = 1_099_511_628_211

    var fingerprints: [UInt64] = []
    for message in request.messages where message.role != "system" {
        for imageURL in message.content?.imageURLs ?? [] {
            var digest = offsetBasis
            for byte in imageURL.utf8 {
                digest ^= UInt64(byte)
                digest &*= prime  // wrapping multiply: overflow is part of the hash
            }
            fingerprints.append(digest)
        }
    }
    return fingerprints
}
|
||||
|
||||
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
|
||||
var hash: UInt64 = 14_695_981_039_346_656_037
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@ import MLXLMCommon
|
||||
@testable import MLX_Server
|
||||
|
||||
final class PromptBuilderTests: XCTestCase {
|
||||
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
|
||||
|
||||
func testBuildMatchesLegacyAPIServerShapingForGemma() {
|
||||
let toolCall = APIToolCall(
|
||||
id: "call_weather",
|
||||
@@ -20,7 +18,7 @@ final class PromptBuilderTests: XCTestCase {
|
||||
role: "tool",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "{\"temp\":19}", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
@@ -132,7 +130,7 @@ final class PromptBuilderTests: XCTestCase {
|
||||
role: "tool",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "{\"ok\":true}", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
@@ -156,9 +154,70 @@ final class PromptBuilderTests: XCTestCase {
|
||||
XCTAssertTrue(prepared.chatMessages[0].content.contains("```tool_output"))
|
||||
XCTAssertTrue(prepared.containsImages)
|
||||
XCTAssertEqual(prepared.chatMessages[0].images.count, 1)
|
||||
XCTAssertEqual(prepared.imageFingerprints.count, 1)
|
||||
XCTAssertGreaterThan(prepared.estimatedBytes, prepared.chatMessages[0].content.utf8.count)
|
||||
}
|
||||
|
||||
/// Checks that `PromptBuilder.build` yields exactly one fingerprint per image,
/// that building the same image twice yields equal fingerprints, and that a
/// different image yields different fingerprints.
func testBuildHashesRawImageSourcesIntoStableFingerprints() {
    // The requests under test differ only in their image source.
    func describeRequest(imageURL: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(
                    role: "user",
                    content: .parts([
                        APIContentPart(type: "text", text: "Describe this.", image_url: nil),
                        APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: imageURL, detail: nil))
                    ]),
                    name: nil,
                    tool_calls: nil,
                    tool_call_id: nil
                )
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let modelId = "mlx-community/gemma-3-4b-it-4bit"

    let firstPrepared = PromptBuilder.build(
        from: describeRequest(imageURL: TestImageFixtures.primaryDataURI),
        modelId: modelId,
        thinkingEnabled: true
    )
    let repeatedPrepared = PromptBuilder.build(
        from: describeRequest(imageURL: TestImageFixtures.primaryDataURI),
        modelId: modelId,
        thinkingEnabled: true
    )
    let secondPrepared = PromptBuilder.build(
        from: describeRequest(imageURL: TestImageFixtures.alternateDataURI),
        modelId: modelId,
        thinkingEnabled: true
    )

    XCTAssertEqual(firstPrepared.imageFingerprints.count, 1)
    XCTAssertEqual(secondPrepared.imageFingerprints.count, 1)
    // Stability: the same image source must hash to the same fingerprints on every build.
    XCTAssertEqual(firstPrepared.imageFingerprints, repeatedPrepared.imageFingerprints)
    XCTAssertNotEqual(firstPrepared.imageFingerprints, secondPrepared.imageFingerprints)
}
|
||||
|
||||
private func legacyBuild(
|
||||
from request: APIChatCompletionRequest,
|
||||
modelId: String,
|
||||
@@ -237,6 +296,7 @@ final class PromptBuilderTests: XCTestCase {
|
||||
instructions: instructions,
|
||||
chatMessages: chatMessages,
|
||||
messageSignatures: messageSignatures,
|
||||
imageFingerprints: imageURLsFingerprintOrder(from: request),
|
||||
estimatedBytes: estimatedBytes,
|
||||
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
|
||||
containsImages: containsImages,
|
||||
@@ -245,6 +305,20 @@ final class PromptBuilderTests: XCTestCase {
|
||||
)
|
||||
}
|
||||
|
||||
/// Hashes every image URL carried by the request's non-system messages with
/// 64-bit FNV-1a over the URL's UTF-8 bytes, preserving appearance order.
private func imageURLsFingerprintOrder(from request: APIChatCompletionRequest) -> [UInt64] {
    let imageURLs = request.messages
        .filter { $0.role != "system" }
        .flatMap { $0.content?.imageURLs ?? [] }

    return imageURLs.map { imageURL in
        // Seed with the FNV offset basis; fold each byte in as xor-then-multiply.
        imageURL.utf8.reduce(14_695_981_039_346_656_037 as UInt64) { digest, byte in
            (digest ^ UInt64(byte)) &* 1_099_511_628_211
        }
    }
}
|
||||
|
||||
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
|
||||
var hash: UInt64 = 14_695_981_039_346_656_037
|
||||
|
||||
|
||||
30
MLXServerTests/Server/TestImageFixtures.swift
Normal file
30
MLXServerTests/Server/TestImageFixtures.swift
Normal file
@@ -0,0 +1,30 @@
|
||||
import Foundation
|
||||
|
||||
/// Image payloads shared by the test suites, read from the app-icon PNGs under
/// `MLXServer/Assets.xcassets/AppIcon.appiconset`.
enum TestImageFixtures {
    /// Directory three path components above this source file; the fixture
    /// paths below are resolved relative to it.
    private static let repoRoot: URL = {
        URL(fileURLWithPath: #filePath)
            .deletingLastPathComponent()
            .deletingLastPathComponent()
            .deletingLastPathComponent()
    }()

    /// Reads the named app-icon file and returns its bytes as base64.
    ///
    /// Stops the test process with a message naming the path and the underlying
    /// error when the file cannot be read, so a moved or renamed asset is
    /// distinguishable from a permissions or I/O problem.
    private static func loadBase64(named name: String) -> String {
        let url = repoRoot
            .appendingPathComponent("MLXServer")
            .appendingPathComponent("Assets.xcassets")
            .appendingPathComponent("AppIcon.appiconset")
            .appendingPathComponent(name)

        do {
            return try Data(contentsOf: url).base64EncodedString()
        } catch {
            fatalError("Unable to load image fixture at \(url.path): \(error)")
        }
    }

    /// Base64 of `icon_16x16.png`.
    static let primaryPNGBase64 = loadBase64(named: "icon_16x16.png")
    /// Base64 of `icon_32x32.png`; a second, different image.
    static let alternatePNGBase64 = loadBase64(named: "icon_32x32.png")

    /// `primaryPNGBase64` wrapped as a PNG data URI.
    static let primaryDataURI = "data:image/png;base64,\(primaryPNGBase64)"
    /// `alternatePNGBase64` wrapped as a PNG data URI.
    static let alternateDataURI = "data:image/png;base64,\(alternatePNGBase64)"
}
|
||||
@@ -109,6 +109,21 @@ final class TokenPrefixCacheTests: XCTestCase {
|
||||
XCTAssertEqual(cache.debugTrieNodeCount(), 1)
|
||||
}
|
||||
|
||||
/// Stores one entry, looks it up with a key that extends the stored key by one
/// token, and checks that the lookup hits without the snapshot reporting any
/// eviction.
func testCheckoutHitDoesNotCountAsEviction() {
    let prefixCache = TokenPrefixCache(memoryBudgetBytes: 10_000, estimateBytesProvider: { _ in 1_024 })
    prefixCache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")

    let checkout = prefixCache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")
    let stats = prefixCache.snapshot()

    XCTAssertTrue(checkout.isHit)
    XCTAssertEqual(stats.totalEvictions, 0)
}
|
||||
|
||||
func testSnapshotReportsHitRateAndTokenTotals() {
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
|
||||
Reference in New Issue
Block a user