Files
MLXServer/MLXServerTests/Server/LocalModelResolverTests.swift

184 lines
7.2 KiB
Swift

import Foundation
import XCTest
@testable import MLX_Server
/// Tests for `LocalModelResolver` discovery of snapshots in a Hugging Face-style cache directory,
/// and for how discovered local models and saved metadata overrides feed into `ModelConfig`.
final class LocalModelResolverTests: XCTestCase {

    // MARK: - Discovery

    /// A snapshot whose config has no vision section is reported as text-only: its context
    /// length comes from `max_position_embeddings`, the loader order is `[.llm, .vlm]`, and its
    /// size equals the total bytes of the files in the snapshot.
    func testDiscoverSystemHFModelsInfersTextOnlyMetadata() throws {
        let base = try makeTempHFCache()
        let snapshotDir = try makeHFSnapshot(base: base, repoId: "example/text-only")
        let configURL = snapshotDir.appendingPathComponent("config.json")
        let weightsURL = snapshotDir.appendingPathComponent("model.safetensors")
        let tokenizerURL = snapshotDir.appendingPathComponent("tokenizer.json")

        try writeJSON(
            [
                "architectures": ["LlamaForCausalLM"],
                "max_position_embeddings": 32768,
            ],
            to: configURL
        )
        try Data(repeating: 0x11, count: 64).write(to: weightsURL)
        try Data(repeating: 0x22, count: 19).write(to: tokenizerURL)

        // Measure what actually landed on disk so the expectation does not depend on how
        // JSONSerialization happens to format config.json.
        let expectedSize = try [configURL, weightsURL, tokenizerURL]
            .map { try Int64(Data(contentsOf: $0).count) }
            .reduce(0, +)

        let discovered = LocalModelResolver.discoverSystemHFModels(in: base)

        // The cache holds exactly one snapshot; any extra entry would make `first` arbitrary.
        XCTAssertEqual(discovered.count, 1)
        let model = try XCTUnwrap(discovered.first)
        XCTAssertEqual(model.repoId, "example/text-only")
        XCTAssertEqual(model.contextLength, 32768)
        XCTAssertFalse(model.supportsImages)
        XCTAssertEqual(model.loaderKinds, [.llm, .vlm])
        XCTAssertEqual(model.sizeBytes, expectedSize)
    }

    /// A snapshot with a `vision_config` section and a processor config is reported as
    /// image-capable: its context length comes from `text_config.max_position_embeddings`
    /// and the loader order is `[.vlm, .llm]`.
    func testDiscoverSystemHFModelsInfersVisionMetadata() throws {
        let base = try makeTempHFCache()
        let snapshotDir = try makeHFSnapshot(base: base, repoId: "example/vision-model")

        try writeJSON(
            [
                "text_config": ["max_position_embeddings": 262144],
                "vision_config": ["hidden_size": 768],
            ],
            to: snapshotDir.appendingPathComponent("config.json")
        )
        try writeJSON(
            ["processor_class": "Qwen3VLProcessor"],
            to: snapshotDir.appendingPathComponent("tokenizer_config.json")
        )
        try Data(repeating: 0x33, count: 12).write(to: snapshotDir.appendingPathComponent("processor_config.json"))
        try Data(repeating: 0x44, count: 8).write(to: snapshotDir.appendingPathComponent("model.safetensors.index.json"))

        let discovered = LocalModelResolver.discoverSystemHFModels(in: base)

        // The cache holds exactly one snapshot; any extra entry would make `first` arbitrary.
        XCTAssertEqual(discovered.count, 1)
        let model = try XCTUnwrap(discovered.first)
        XCTAssertEqual(model.repoId, "example/vision-model")
        XCTAssertEqual(model.contextLength, 262144)
        XCTAssertTrue(model.supportsImages)
        XCTAssertEqual(model.loaderKinds, [.vlm, .llm])
    }

    // MARK: - Catalog merging

    /// Merging keeps the curated "gemma" entry (annotated with its local size) and appends a
    /// non-curated entry for a local repo that is not in the curated list.
    func testMergedCatalogKeepsCuratedModelsAndAddsCustomLocalModels() throws {
        let localModels = [
            LocalModelResolver.LocalModelInfo(
                repoId: "mlx-community/gemma-3-4b-it-4bit",
                directory: URL(fileURLWithPath: "/tmp/gemma"),
                sizeBytes: 1024,
                contextLength: 128000,
                loaderKinds: [.vlm, .llm],
                supportsImages: true
            ),
            LocalModelResolver.LocalModelInfo(
                repoId: "custom-org/custom-model",
                directory: URL(fileURLWithPath: "/tmp/custom"),
                sizeBytes: 2048,
                contextLength: 65536,
                loaderKinds: [.llm, .vlm],
                supportsImages: false
            ),
        ]

        let merged = ModelConfig.mergedModels(localModels: localModels)

        // Unwrap first so a missing entry produces one clear failure instead of a cascade of
        // nil-vs-value mismatches.
        let gemma = try XCTUnwrap(
            merged.first(where: { $0.id == "gemma" }),
            "curated gemma entry missing from merged catalog"
        )
        let custom = try XCTUnwrap(
            merged.first(where: { $0.repoId == "custom-org/custom-model" }),
            "custom local model missing from merged catalog"
        )
        XCTAssertEqual(gemma.localSizeBytes, 1024)
        XCTAssertEqual(custom.id, "custom-org/custom-model")
        XCTAssertEqual(custom.contextLength, 65536)
        XCTAssertFalse(custom.isCurated)
    }

    /// Resolving a repo id that is not curated yields a non-curated config keyed by the repo id.
    func testResolveUnknownRepoIdCreatesRemoteCustomConfig() throws {
        let config = try XCTUnwrap(ModelConfig.resolve("custom-owner/custom-repo"))
        XCTAssertEqual(config.id, "custom-owner/custom-repo")
        XCTAssertEqual(config.repoId, "custom-owner/custom-repo")
        XCTAssertFalse(config.isCurated)
    }

    // MARK: - Metadata overrides

    /// A saved metadata override takes precedence over the metadata inferred for a local model.
    func testMergedCatalogAppliesSavedMetadataOverride() throws {
        let repoId = "custom-org/override-model"
        Preferences.setModelMetadataOverride(
            ModelMetadataOverride(
                contextLength: 123456,
                primaryLoaderKind: .vlm,
                supportsImages: true,
                supportsTools: true
            ),
            forRepoId: repoId
        )
        // The override is written to shared preferences; always remove it so other tests
        // are unaffected.
        defer {
            Preferences.removeModelMetadataOverride(forRepoId: repoId)
        }

        let localModels = [
            LocalModelResolver.LocalModelInfo(
                repoId: repoId,
                directory: URL(fileURLWithPath: "/tmp/custom-override"),
                sizeBytes: 2048,
                contextLength: 65536,
                loaderKinds: [.llm, .vlm],
                supportsImages: false
            ),
        ]

        let merged = ModelConfig.mergedModels(localModels: localModels)

        let overridden = try XCTUnwrap(
            merged.first(where: { $0.repoId == repoId }),
            "overridden local model missing from merged catalog"
        )
        XCTAssertEqual(overridden.contextLength, 123456)
        XCTAssertEqual(overridden.primaryLoaderKind, .vlm)
        XCTAssertTrue(overridden.supportsImages)
        XCTAssertTrue(overridden.supportsTools)
    }

    /// A saved metadata override is also applied when resolving an unknown remote repo id.
    func testResolveUnknownRepoIdUsesSavedMetadataOverride() throws {
        let repoId = "custom-owner/custom-repo-with-override"
        Preferences.setModelMetadataOverride(
            ModelMetadataOverride(
                contextLength: 8192,
                primaryLoaderKind: .llm,
                supportsImages: false,
                supportsTools: true
            ),
            forRepoId: repoId
        )
        // The override is written to shared preferences; always remove it so other tests
        // are unaffected.
        defer {
            Preferences.removeModelMetadataOverride(forRepoId: repoId)
        }

        let config = try XCTUnwrap(ModelConfig.resolve(repoId))
        XCTAssertEqual(config.contextLength, 8192)
        XCTAssertEqual(config.primaryLoaderKind, .llm)
        XCTAssertFalse(config.supportsImages)
        XCTAssertTrue(config.supportsTools)
    }

    // MARK: - Helpers

    /// Creates a fresh, uniquely named directory under the system temp directory to act as a
    /// cache root, and registers a teardown block that deletes it after the test.
    private func makeTempHFCache() throws -> URL {
        let root = FileManager.default.temporaryDirectory
            .appendingPathComponent(UUID().uuidString, isDirectory: true)
        try FileManager.default.createDirectory(at: root, withIntermediateDirectories: true)
        addTeardownBlock {
            try? FileManager.default.removeItem(at: root)
        }
        return root
    }

    /// Creates `<base>/models--<owner>--<name>/snapshots/<hash>` for `repoId` (every `/`
    /// becomes `--`) and returns the snapshot directory URL.
    private func makeHFSnapshot(base: URL, repoId: String, hash: String = "abc123") throws -> URL {
        let slug = repoId.replacingOccurrences(of: "/", with: "--")
        let snapshotDir = base
            .appendingPathComponent("models--\(slug)", isDirectory: true)
            .appendingPathComponent("snapshots", isDirectory: true)
            .appendingPathComponent(hash, isDirectory: true)
        try FileManager.default.createDirectory(at: snapshotDir, withIntermediateDirectories: true)
        return snapshotDir
    }

    /// Serializes `object` as pretty-printed JSON with sorted keys and writes it to `url`.
    private func writeJSON(_ object: Any, to url: URL) throws {
        let data = try JSONSerialization.data(withJSONObject: object, options: [.prettyPrinted, .sortedKeys])
        try data.write(to: url)
    }
}