feat: vision properly coverd with tests and completed
This commit is contained in:
@@ -3,8 +3,6 @@ import MLXLMCommon
|
||||
@testable import MLX_Server
|
||||
|
||||
final class PromptBuilderTests: XCTestCase {
|
||||
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
|
||||
|
||||
func testBuildMatchesLegacyAPIServerShapingForGemma() {
|
||||
let toolCall = APIToolCall(
|
||||
id: "call_weather",
|
||||
@@ -20,7 +18,7 @@ final class PromptBuilderTests: XCTestCase {
|
||||
role: "tool",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "{\"temp\":19}", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
@@ -132,7 +130,7 @@ final class PromptBuilderTests: XCTestCase {
|
||||
role: "tool",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "{\"ok\":true}", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
@@ -156,9 +154,70 @@ final class PromptBuilderTests: XCTestCase {
|
||||
XCTAssertTrue(prepared.chatMessages[0].content.contains("```tool_output"))
|
||||
XCTAssertTrue(prepared.containsImages)
|
||||
XCTAssertEqual(prepared.chatMessages[0].images.count, 1)
|
||||
XCTAssertEqual(prepared.imageFingerprints.count, 1)
|
||||
XCTAssertGreaterThan(prepared.estimatedBytes, prepared.chatMessages[0].content.utf8.count)
|
||||
}
|
||||
|
||||
func testBuildHashesRawImageSourcesIntoStableFingerprints() {
|
||||
let firstRequest = APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(
|
||||
role: "user",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "Describe this.", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
tool_call_id: nil
|
||||
)
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
let secondRequest = APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(
|
||||
role: "user",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "Describe this.", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.alternateDataURI, detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
tool_call_id: nil
|
||||
)
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
|
||||
let firstPrepared = PromptBuilder.build(from: firstRequest, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: true)
|
||||
let secondPrepared = PromptBuilder.build(from: secondRequest, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: true)
|
||||
|
||||
XCTAssertEqual(firstPrepared.imageFingerprints.count, 1)
|
||||
XCTAssertEqual(secondPrepared.imageFingerprints.count, 1)
|
||||
XCTAssertNotEqual(firstPrepared.imageFingerprints, secondPrepared.imageFingerprints)
|
||||
}
|
||||
|
||||
private func legacyBuild(
|
||||
from request: APIChatCompletionRequest,
|
||||
modelId: String,
|
||||
@@ -237,6 +296,7 @@ final class PromptBuilderTests: XCTestCase {
|
||||
instructions: instructions,
|
||||
chatMessages: chatMessages,
|
||||
messageSignatures: messageSignatures,
|
||||
imageFingerprints: imageURLsFingerprintOrder(from: request),
|
||||
estimatedBytes: estimatedBytes,
|
||||
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
|
||||
containsImages: containsImages,
|
||||
@@ -245,6 +305,20 @@ final class PromptBuilderTests: XCTestCase {
|
||||
)
|
||||
}
|
||||
|
||||
private func imageURLsFingerprintOrder(from request: APIChatCompletionRequest) -> [UInt64] {
|
||||
request.messages
|
||||
.filter { $0.role != "system" }
|
||||
.flatMap { $0.content?.imageURLs ?? [] }
|
||||
.reduce(into: [UInt64]()) { fingerprints, imageURL in
|
||||
var hash: UInt64 = 14_695_981_039_346_656_037
|
||||
for byte in imageURL.utf8 {
|
||||
hash ^= UInt64(byte)
|
||||
hash &*= 1_099_511_628_211
|
||||
}
|
||||
fingerprints.append(hash)
|
||||
}
|
||||
}
|
||||
|
||||
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
|
||||
var hash: UInt64 = 14_695_981_039_346_656_037
|
||||
|
||||
|
||||
Reference in New Issue
Block a user