feat: complete vision support and cover it properly with tests

This commit is contained in:
2026-03-20 12:57:27 +01:00
parent e59be9df1a
commit 0761254d17
12 changed files with 648 additions and 40 deletions

View File

@@ -3,8 +3,6 @@ import MLXLMCommon
@testable import MLX_Server
final class PromptBuilderTests: XCTestCase {
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
func testBuildMatchesLegacyAPIServerShapingForGemma() {
let toolCall = APIToolCall(
id: "call_weather",
@@ -20,7 +18,7 @@ final class PromptBuilderTests: XCTestCase {
role: "tool",
content: .parts([
APIContentPart(type: "text", text: "{\"temp\":19}", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
@@ -132,7 +130,7 @@ final class PromptBuilderTests: XCTestCase {
role: "tool",
content: .parts([
APIContentPart(type: "text", text: "{\"ok\":true}", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
@@ -156,9 +154,70 @@ final class PromptBuilderTests: XCTestCase {
XCTAssertTrue(prepared.chatMessages[0].content.contains("```tool_output"))
XCTAssertTrue(prepared.containsImages)
XCTAssertEqual(prepared.chatMessages[0].images.count, 1)
XCTAssertEqual(prepared.imageFingerprints.count, 1)
XCTAssertGreaterThan(prepared.estimatedBytes, prepared.chatMessages[0].content.utf8.count)
}
/// Checks the image fingerprints that `PromptBuilder.build` reports for a request.
///
/// Two properties are asserted:
/// - stability: building the same request twice yields identical fingerprints;
/// - sensitivity: requests that differ only in their image source yield
///   different fingerprints.
func testBuildHashesRawImageSourcesIntoStableFingerprints() {
    let modelId = "mlx-community/gemma-3-4b-it-4bit"

    // The requests are identical except for the image source, so any difference
    // between their fingerprints can only come from the image itself.
    func makeRequest(imageURL: String) -> APIChatCompletionRequest {
        APIChatCompletionRequest(
            model: "gemma",
            messages: [
                APIChatMessage(
                    role: "user",
                    content: .parts([
                        APIContentPart(type: "text", text: "Describe this.", image_url: nil),
                        APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: imageURL, detail: nil))
                    ]),
                    name: nil,
                    tool_calls: nil,
                    tool_call_id: nil
                )
            ],
            temperature: nil,
            top_p: nil,
            max_tokens: nil,
            stream: nil,
            stop: nil,
            tools: nil,
            tool_choice: nil,
            frequency_penalty: nil,
            presence_penalty: nil,
            n: nil
        )
    }

    let firstRequest = makeRequest(imageURL: TestImageFixtures.primaryDataURI)
    let secondRequest = makeRequest(imageURL: TestImageFixtures.alternateDataURI)

    let firstPrepared = PromptBuilder.build(from: firstRequest, modelId: modelId, thinkingEnabled: true)
    let repeatedPrepared = PromptBuilder.build(from: firstRequest, modelId: modelId, thinkingEnabled: true)
    let secondPrepared = PromptBuilder.build(from: secondRequest, modelId: modelId, thinkingEnabled: true)

    XCTAssertEqual(firstPrepared.imageFingerprints.count, 1)
    XCTAssertEqual(secondPrepared.imageFingerprints.count, 1)

    // Stability: the same image source must fingerprint identically on every build.
    XCTAssertEqual(firstPrepared.imageFingerprints, repeatedPrepared.imageFingerprints)

    // Sensitivity: a different image source must change the fingerprint.
    XCTAssertNotEqual(firstPrepared.imageFingerprints, secondPrepared.imageFingerprints)
}
private func legacyBuild(
from request: APIChatCompletionRequest,
modelId: String,
@@ -237,6 +296,7 @@ final class PromptBuilderTests: XCTestCase {
instructions: instructions,
chatMessages: chatMessages,
messageSignatures: messageSignatures,
imageFingerprints: imageURLsFingerprintOrder(from: request),
estimatedBytes: estimatedBytes,
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
containsImages: containsImages,
@@ -245,6 +305,20 @@ final class PromptBuilderTests: XCTestCase {
)
}
/// Computes one 64-bit fingerprint per image URL found in the request's
/// non-system messages, in message order and then in per-message image order.
///
/// Each fingerprint is an FNV-1a style fold over the UTF-8 bytes of the URL
/// string: XOR the byte in, then multiply by the prime with wrapping arithmetic.
///
/// - Parameter request: The chat completion request whose messages are scanned.
/// - Returns: The fingerprints, one per image URL; empty when no non-system
///   message carries an image URL.
private func imageURLsFingerprintOrder(from request: APIChatCompletionRequest) -> [UInt64] {
    let fnvOffsetBasis: UInt64 = 14_695_981_039_346_656_037
    let fnvPrime: UInt64 = 1_099_511_628_211

    var fingerprints: [UInt64] = []
    // Messages with role "system" are excluded from fingerprinting.
    for message in request.messages where message.role != "system" {
        for source in message.content?.imageURLs ?? [] {
            let digest = source.utf8.reduce(fnvOffsetBasis) { partial, byte in
                (partial ^ UInt64(byte)) &* fnvPrime
            }
            fingerprints.append(digest)
        }
    }
    return fingerprints
}
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
var hash: UInt64 = 14_695_981_039_346_656_037