feat: finished all open things up to and including phase 6

This commit is contained in:
2026-03-21 08:41:13 +01:00
parent 0325fa8964
commit 107ac0524b
9 changed files with 457 additions and 33 deletions

View File

@@ -153,6 +153,61 @@ final class APIServerRewriteTests: XCTestCase {
XCTAssertEqual(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
}
func testSingleImageAndTextPromptProducesVisionResponse() async throws {
let harness = try await makeHarness(initialModelId: "gemma")
defer { harness.stop() }
let response = try await sendChatCompletion(
visionRequest(
modelId: "gemma",
dataURI: TestImageFixtures.primaryDataURI,
prompt: "Describe this image in one short word."
),
port: harness.port
)
XCTAssertEqual(response.choices.count, 1)
XCTAssertFalse((response.choices[0].message.content ?? "").trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
XCTAssertGreaterThan(LiveCounters.shared.snapshot().totalVisionEncoderDuration, 0)
}
func testMultipleImagesInSingleMessageProduceVisionResponse() async throws {
let harness = try await makeHarness(initialModelId: "gemma")
defer { harness.stop() }
let request = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(
role: "user",
content: .parts([
APIContentPart(type: "text", text: "Compare these two images in a few words.", image_url: nil),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.primaryDataURI, detail: nil)),
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: TestImageFixtures.alternateDataURI, detail: nil))
]),
name: nil,
tool_calls: nil,
tool_call_id: nil
)
],
temperature: 0,
top_p: 1,
max_tokens: 6,
stream: false,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
let response = try await sendChatCompletion(request, port: harness.port)
XCTAssertEqual(response.choices.count, 1)
XCTAssertFalse((response.choices[0].message.content ?? "").trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
}
func testVisionPromptDifferentImageMissesCache() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
@@ -241,6 +296,74 @@ final class APIServerRewriteTests: XCTestCase {
XCTAssertLessThan(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
}
func testTextOnlyRequestOnVisionModelDoesNotRecordVisionTime() async throws {
let harness = try await makeHarness(initialModelId: "gemma")
defer { harness.stop() }
let request = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "user", content: .text("Answer in one word: stone."), name: nil, tool_calls: nil, tool_call_id: nil)
],
temperature: 0,
top_p: 1,
max_tokens: 2,
stream: false,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
let response = try await sendChatCompletion(request, port: harness.port)
XCTAssertEqual(response.choices.count, 1)
XCTAssertFalse((response.choices[0].message.content ?? "").trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
XCTAssertEqual(LiveCounters.shared.snapshot().totalVisionEncoderDuration, 0)
}
func testLargeImagePromptSucceedsOnVisionModel() async throws {
let harness = try await makeHarness(initialModelId: "gemma")
defer { harness.stop() }
let response = try await sendChatCompletion(
visionRequest(
modelId: "gemma",
dataURI: TestImageFixtures.largeDataURI,
prompt: "Describe this image briefly."
),
port: harness.port
)
XCTAssertEqual(response.choices.count, 1)
XCTAssertFalse((response.choices[0].message.content ?? "").trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
XCTAssertGreaterThan(LiveCounters.shared.snapshot().totalVisionEncoderDuration, 0)
}
func testNonVisionModelRejectsImageInputsWithClearError() async throws {
guard let stheno = ModelConfig.resolve("stheno"), stheno.isLocal else {
throw XCTSkip("Local non-vision model fixture is unavailable")
}
let harness = try await makeHarness(initialModelId: "stheno")
defer { harness.stop() }
let response = try await sendChatCompletionExpectingStatus(
visionRequest(
modelId: "stheno",
dataURI: TestImageFixtures.primaryDataURI,
prompt: "Describe this image in one word."
),
port: harness.port,
expectedStatus: 400
)
XCTAssertTrue(response.body.contains("vision_not_supported"))
XCTAssertTrue(response.body.contains("does not support image inputs"))
}
func testSecondIdenticalRequestIsFullCacheHitWithZeroRebuiltPromptTokens() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
@@ -1378,6 +1501,23 @@ final class APIServerRewriteTests: XCTestCase {
return try JSONDecoder().decode(APIChatCompletionResponse.self, from: data)
}
private func sendChatCompletionExpectingStatus(
_ request: APIChatCompletionRequest,
port: UInt16,
expectedStatus: Int
) async throws -> (statusCode: Int, body: String, bodyData: Data) {
let url = URL(string: "http://127.0.0.1:\(port)/v1/chat/completions")!
var urlRequest = URLRequest(url: url)
urlRequest.httpMethod = "POST"
urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
urlRequest.httpBody = try JSONEncoder().encode(request)
let (data, response) = try await URLSession.shared.data(for: urlRequest)
let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
XCTAssertEqual(httpResponse.statusCode, expectedStatus, String(data: data, encoding: .utf8) ?? "")
return (httpResponse.statusCode, String(data: data, encoding: .utf8) ?? "", data)
}
private func sendModelsRequest(port: UInt16) async throws -> APIModelListResponse {
let response = try await sendRawRequest(path: "/v1/models", port: port)
XCTAssertEqual(response.statusCode, 200)

View File

@@ -1,3 +1,4 @@
import MLXLMCommon
import XCTest
@testable import MLX_Server
@@ -15,4 +16,24 @@ final class ImageDecoderTests: XCTestCase {
XCTAssertNotNil(image)
XCTAssertGreaterThanOrEqual(image?.estimatedBytes ?? 0, 4)
}
func testDecodeJPEGDataURI() {
let image = ImageDecoder.decode(TestImageFixtures.primaryJPEGDataURI)
XCTAssertNotNil(image)
XCTAssertGreaterThanOrEqual(image?.estimatedBytes ?? 0, 64 * 64 * 4)
}
func testDecodeLarge4KDataURI() throws {
let image = try XCTUnwrap(ImageDecoder.decode(TestImageFixtures.largeDataURI))
XCTAssertGreaterThanOrEqual(image.estimatedBytes, 4_096 * 4_096 * 4)
if case .ciImage(let ciImage) = image.image {
XCTAssertEqual(Int(ciImage.extent.width), 4_096)
XCTAssertEqual(Int(ciImage.extent.height), 4_096)
} else {
XCTFail("Expected CIImage-backed decoded image")
}
}
}

View File

@@ -5,6 +5,16 @@ import MLXVLM
import XCTest
@testable import MLX_Server
private struct GemmaPreprocessorConfig: Decodable {
let do_resize: Bool
let size: GemmaPreprocessorSize
}
private struct GemmaPreprocessorSize: Decodable {
let height: Int
let width: Int
}
final class ModelBackedInferenceValidationTests: XCTestCase {
func testPromptBuilderTokenizationMatchesLegacyShapingOnLocalGemma() async throws {
let container = try await localGemmaContainer()
@@ -146,6 +156,35 @@ final class ModelBackedInferenceValidationTests: XCTestCase {
XCTAssertEqual(lease.matchedTokenCount, prepared.tokens.count)
}
func testLarge4KImageUsesGemmaResizeConfigAndPreparesSuccessfully() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)
let preprocessorURL = try XCTUnwrap(
LocalModelResolver.resolve(repoId: "mlx-community/gemma-3-4b-it-4bit")?
.appendingPathComponent("preprocessor_config.json"),
"Local Gemma preprocessor config is unavailable"
)
let preprocessorData = try Data(contentsOf: preprocessorURL)
let preprocessor = try JSONDecoder().decode(GemmaPreprocessorConfig.self, from: preprocessorData)
let decoded = try XCTUnwrap(ImageDecoder.decode(TestImageFixtures.largeDataURI))
let userInput = UserInput(
prompt: .chat([
Chat.Message(role: .user, content: "What is in this image?", images: [decoded.image])
]),
images: [decoded.image],
videos: [],
tools: nil,
additionalContext: ["enable_thinking": false]
)
let prepared = try await engine.prepare(userInput)
XCTAssertTrue(preprocessor.do_resize)
XCTAssertEqual(preprocessor.size.height, preprocessor.size.width)
XCTAssertLessThan(preprocessor.size.height, 4_096)
XCTAssertFalse(prepared.tokens.isEmpty)
}
func testTokenPrefixCacheFindsLCPHitForSameSystemDifferentUserOnLocalGemmaTokens() async throws {
let container = try await localGemmaContainer()
let engine = InferenceEngine(container: container)

View File

@@ -1,3 +1,4 @@
import AppKit
import Foundation
enum TestImageFixtures {
@@ -22,9 +23,66 @@ enum TestImageFixtures {
return data.base64EncodedString()
}
private static func generatedBitmapData(
width: Int,
height: Int,
fileType: NSBitmapImageRep.FileType,
compressionFactor: Double? = nil
) -> Data {
let bytesPerRow = width * 4
guard let rep = NSBitmapImageRep(
bitmapDataPlanes: nil,
pixelsWide: width,
pixelsHigh: height,
bitsPerSample: 8,
samplesPerPixel: 4,
hasAlpha: true,
isPlanar: false,
colorSpaceName: .deviceRGB,
bytesPerRow: bytesPerRow,
bitsPerPixel: 32
) else {
fatalError("Failed to create bitmap fixture")
}
NSGraphicsContext.saveGraphicsState()
NSGraphicsContext.current = NSGraphicsContext(bitmapImageRep: rep)
let imageRect = NSRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))
NSColor(calibratedRed: 0.18, green: 0.45, blue: 0.87, alpha: 1).setFill()
imageRect.fill()
NSColor.white.setStroke()
let inset = CGFloat(max(8, min(width, height) / 16))
NSBezierPath(rect: imageRect.insetBy(dx: inset, dy: inset)).stroke()
NSGraphicsContext.restoreGraphicsState()
var properties: [NSBitmapImageRep.PropertyKey: Any] = [:]
if let compressionFactor {
properties[.compressionFactor] = compressionFactor
}
guard let data = rep.representation(using: fileType, properties: properties) else {
fatalError("Failed to encode bitmap fixture")
}
return data
}
static let primaryPNGBase64 = loadBase64(named: "icon_16x16.png")
static let alternatePNGBase64 = loadBase64(named: "icon_32x32.png")
static let primaryJPEGBase64 = generatedBitmapData(
width: 64,
height: 64,
fileType: .jpeg,
compressionFactor: 0.85
).base64EncodedString()
static let largePNGBase64 = generatedBitmapData(
width: 4_096,
height: 4_096,
fileType: .png
).base64EncodedString()
static let primaryDataURI = "data:image/png;base64,\(primaryPNGBase64)"
static let alternateDataURI = "data:image/png;base64,\(alternatePNGBase64)"
static let primaryJPEGDataURI = "data:image/jpeg;base64,\(primaryJPEGBase64)"
static let largeDataURI = "data:image/png;base64,\(largePNGBase64)"
}

View File

@@ -1,4 +1,5 @@
import Foundation
import MLX
import XCTest
import MLXLMCommon
@testable import MLX_Server
@@ -225,6 +226,96 @@ final class TokenPrefixCacheTests: XCTestCase {
XCTAssertEqual(snapshot.lcpHits, 0)
}
func testSupersequenceSkipsNonTrimmableLayersGracefully() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let layer = TestTrimRecordingCache(offset: 4, trimmable: false)
cache.store(entryId: UUID(), kvCache: [layer], cacheKey: [1, 2, 3, 4], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
let snapshot = cache.snapshot()
XCTAssertFalse(lease.isHit)
XCTAssertEqual(layer.offset, 4)
XCTAssertTrue(layer.trimCalls.isEmpty)
XCTAssertEqual(snapshot.supersequenceHits, 0)
XCTAssertEqual(snapshot.totalMisses, 1)
}
func testSupersequenceChoosesShallowestCandidate() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let shallowestId = UUID()
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 4, 5], modelId: "model")
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 4], modelId: "model")
cache.store(entryId: shallowestId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2], modelId: "model")
XCTAssertTrue(lease.isHit)
XCTAssertEqual(lease.entryId, shallowestId)
XCTAssertEqual(lease.matchedTokenCount, 2)
}
func testSupersequencePathWinsWhenFullQueryWalkCanAlsoSeeDivergentSibling() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let supersequenceId = UUID()
cache.store(entryId: supersequenceId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 9, 8], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2], modelId: "model")
let snapshot = cache.snapshot()
XCTAssertTrue(lease.isHit)
XCTAssertEqual(lease.entryId, supersequenceId)
XCTAssertEqual(snapshot.supersequenceHits, 1)
XCTAssertEqual(snapshot.lcpHits, 0)
}
func testLCPChoosesShallowestSiblingCandidate() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let shallowestId = UUID()
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 7], modelId: "model")
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 4, 7, 8], modelId: "model")
cache.store(entryId: shallowestId, kvCache: [], cacheKey: [1, 2, 5], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2, 9, 9], modelId: "model")
XCTAssertTrue(lease.isHit)
XCTAssertEqual(lease.entryId, shallowestId)
XCTAssertEqual(lease.matchedTokenCount, 2)
}
func testTrimUsesExactExcessAndReducesOffset() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let layer = TestTrimRecordingCache(offset: 5, trimmable: true)
cache.store(entryId: UUID(), kvCache: [layer], cacheKey: [1, 2, 3, 4, 5], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
XCTAssertTrue(lease.isHit)
XCTAssertEqual(layer.trimCalls, [2])
XCTAssertEqual(layer.offset, 3)
}
func testComputeMemoryBudgetUsesFallbackWhenDeviceUnavailable() {
let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: nil)
@@ -248,4 +339,53 @@ final class TokenPrefixCacheTests: XCTestCase {
XCTAssertEqual(budget, 8 * 1024 * 1024 * 1024)
}
}
private final class TestTrimRecordingCache: KVCache {
private var arrays: [MLXArray] = []
var offset: Int
let maxSize: Int? = nil
let trimmable: Bool
private(set) var trimCalls: [Int] = []
init(offset: Int, trimmable: Bool) {
self.offset = offset
self.trimmable = trimmable
}
func innerState() -> [MLXArray] {
arrays
}
var state: [MLXArray] {
get { arrays }
set { arrays = newValue }
}
var metaState: [String] {
get { [String(offset)] }
set { offset = Int(newValue.first ?? "0") ?? 0 }
}
var isTrimmable: Bool { trimmable }
func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
fatalError("TestTrimRecordingCache does not support update")
}
@discardableResult
func trim(_ n: Int) -> Int {
guard trimmable else { return 0 }
trimCalls.append(n)
offset = max(0, offset - n)
return n
}
func makeMask(
n: Int,
windowSize: Int?,
returnArray: Bool
) -> MLXFast.ScaledDotProductAttentionMaskMode {
.none
}
}