feat: implement phase 2 of session-cache-upgrade.md
This commit is contained in:
289
MLXServerTests/Server/ModelBackedInferenceValidationTests.swift
Normal file
289
MLXServerTests/Server/ModelBackedInferenceValidationTests.swift
Normal file
@@ -0,0 +1,289 @@
|
||||
import Foundation
|
||||
import Hub
|
||||
import MLXLMCommon
|
||||
import MLXVLM
|
||||
import XCTest
|
||||
@testable import MLX_Server
|
||||
|
||||
final class ModelBackedInferenceValidationTests: XCTestCase {
|
||||
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
|
||||
|
||||
func testPromptBuilderTokenizationMatchesLegacyShapingOnLocalGemma() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
let request = APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("You are concise."), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(
|
||||
role: "user",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "What is in this image?", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
tool_call_id: nil
|
||||
)
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
|
||||
let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false)
|
||||
let legacy = legacyBuild(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false)
|
||||
|
||||
let preparedInference = try await engine.prepare(prepared.userInput)
|
||||
let legacyInference = try await engine.prepare(legacy.userInput)
|
||||
|
||||
XCTAssertEqual(preparedInference.tokens, legacyInference.tokens)
|
||||
}
|
||||
|
||||
func testInferenceEngineMatchesChatSessionOnLocalGemma() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
let parameters = GenerateParameters(maxTokens: 1, temperature: 0)
|
||||
let request = APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "user", content: .text("Say hello in one word."), name: nil, tool_calls: nil, tool_call_id: nil)
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
|
||||
let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: true)
|
||||
let preparedInference = try await engine.prepare(prepared.userInput)
|
||||
let handle = try await engine.stream(
|
||||
InferenceEngine.InferenceRequest(
|
||||
input: preparedInference.lmInput,
|
||||
tokens: preparedInference.tokens,
|
||||
parameters: parameters,
|
||||
cachedKV: nil,
|
||||
cachedTokenCount: 0
|
||||
),
|
||||
cancellation: CancellationToken()
|
||||
)
|
||||
|
||||
let engineResult = await collectEngineOutput(handle.stream)
|
||||
|
||||
let session = ChatSession(container, generateParameters: parameters)
|
||||
let sessionResult = try await collectSessionOutput(
|
||||
session.streamDetails(to: "Say hello in one word.", images: [], videos: [])
|
||||
)
|
||||
|
||||
XCTAssertEqual(engineResult.text, sessionResult.text)
|
||||
XCTAssertEqual(engineResult.promptTokenCount, sessionResult.promptTokenCount)
|
||||
}
|
||||
|
||||
private func localGemmaContainer() async throws -> ModelContainer {
|
||||
try await LocalGemmaFixture.shared.container()
|
||||
}
|
||||
|
||||
private func legacyBuild(
|
||||
from request: APIChatCompletionRequest,
|
||||
modelId: String,
|
||||
thinkingEnabled: Bool
|
||||
) -> PromptBuilder.PreparedPrompt {
|
||||
var instructions = ""
|
||||
for msg in request.messages where msg.role == "system" {
|
||||
let text = msg.content?.textContent ?? ""
|
||||
if !text.isEmpty {
|
||||
if !instructions.isEmpty { instructions += "\n\n" }
|
||||
instructions += text
|
||||
}
|
||||
}
|
||||
|
||||
if let tools = request.tools, !tools.isEmpty {
|
||||
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId)
|
||||
if !instructions.isEmpty { instructions += "\n\n" }
|
||||
instructions += toolSystemPrompt
|
||||
}
|
||||
|
||||
let isQwen = modelId.lowercased().contains("qwen")
|
||||
var chatMessages: [Chat.Message] = []
|
||||
var messageSignatures: [UInt64] = []
|
||||
var estimatedBytes = instructions.utf8.count
|
||||
var containsImages = false
|
||||
|
||||
for msg in request.messages where msg.role != "system" {
|
||||
let role: Chat.Message.Role = switch msg.role {
|
||||
case "assistant": .assistant
|
||||
case "tool": .user
|
||||
default: .user
|
||||
}
|
||||
|
||||
var text = msg.content?.textContent ?? ""
|
||||
if msg.role == "tool", !isQwen {
|
||||
text = "```tool_output\n\(text)\n```"
|
||||
}
|
||||
|
||||
if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty {
|
||||
let formattedCalls = isQwen
|
||||
? ToolPromptBuilder.formatQwenToolCalls(toolCalls)
|
||||
: ToolPromptBuilder.formatGemmaToolCalls(toolCalls)
|
||||
text = (text.isEmpty ? "" : text + "\n") + formattedCalls
|
||||
}
|
||||
|
||||
let imageURLs = msg.content?.imageURLs ?? []
|
||||
var messageImages: [UserInput.Image] = []
|
||||
var messageImageBytes = 0
|
||||
for urlString in imageURLs {
|
||||
if let decoded = ImageDecoder.decode(urlString) {
|
||||
messageImages.append(decoded.image)
|
||||
messageImageBytes += decoded.estimatedBytes
|
||||
}
|
||||
}
|
||||
|
||||
containsImages = containsImages || !messageImages.isEmpty
|
||||
chatMessages.append(Chat.Message(role: role, content: text, images: messageImages))
|
||||
messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs))
|
||||
estimatedBytes += text.utf8.count + messageImageBytes
|
||||
}
|
||||
|
||||
let additionalContext: [String: any Sendable]? = thinkingEnabled
|
||||
? nil
|
||||
: ["enable_thinking": false]
|
||||
|
||||
let allImages = chatMessages.flatMap(\.images)
|
||||
let allMessages = (instructions.isEmpty ? [] : [Chat.Message(role: .system, content: instructions)]) + chatMessages
|
||||
let userInput = UserInput(
|
||||
prompt: .chat(allMessages),
|
||||
images: allImages,
|
||||
videos: [],
|
||||
tools: nil,
|
||||
additionalContext: additionalContext
|
||||
)
|
||||
|
||||
return PromptBuilder.PreparedPrompt(
|
||||
instructions: instructions,
|
||||
chatMessages: chatMessages,
|
||||
messageSignatures: messageSignatures,
|
||||
estimatedBytes: estimatedBytes,
|
||||
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
|
||||
containsImages: containsImages,
|
||||
additionalContext: additionalContext,
|
||||
userInput: userInput
|
||||
)
|
||||
}
|
||||
|
||||
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
|
||||
var hash: UInt64 = 14_695_981_039_346_656_037
|
||||
|
||||
func mix(_ text: String) {
|
||||
for byte in text.utf8 {
|
||||
hash ^= UInt64(byte)
|
||||
hash &*= 1_099_511_628_211
|
||||
}
|
||||
}
|
||||
|
||||
switch role {
|
||||
case .assistant:
|
||||
mix("assistant")
|
||||
case .system:
|
||||
mix("system")
|
||||
case .user:
|
||||
mix("user")
|
||||
@unknown default:
|
||||
mix("unknown")
|
||||
}
|
||||
mix("|")
|
||||
mix(content)
|
||||
for imageURL in imageURLs {
|
||||
mix("|")
|
||||
mix(imageURL)
|
||||
}
|
||||
|
||||
return hash
|
||||
}
|
||||
|
||||
private func collectEngineOutput(_ stream: AsyncStream<Generation>) async -> GenerationResult {
|
||||
var text = ""
|
||||
var promptTokenCount = 0
|
||||
for await generation in stream {
|
||||
switch generation {
|
||||
case .chunk(let chunk):
|
||||
text += chunk
|
||||
case .info(let info):
|
||||
promptTokenCount = info.promptTokenCount
|
||||
case .toolCall:
|
||||
break
|
||||
}
|
||||
}
|
||||
return GenerationResult(text: text, promptTokenCount: promptTokenCount)
|
||||
}
|
||||
|
||||
private func collectSessionOutput(_ stream: AsyncThrowingStream<Generation, any Error>) async throws -> GenerationResult {
|
||||
var text = ""
|
||||
var promptTokenCount = 0
|
||||
for try await generation in stream {
|
||||
switch generation {
|
||||
case .chunk(let chunk):
|
||||
text += chunk
|
||||
case .info(let info):
|
||||
promptTokenCount = info.promptTokenCount
|
||||
case .toolCall:
|
||||
break
|
||||
}
|
||||
}
|
||||
return GenerationResult(text: text, promptTokenCount: promptTokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
private struct GenerationResult {
|
||||
let text: String
|
||||
let promptTokenCount: Int
|
||||
}
|
||||
|
||||
private actor LocalGemmaFixture {
|
||||
static let shared = LocalGemmaFixture()
|
||||
|
||||
private var task: Task<ModelContainer, Error>?
|
||||
|
||||
func container() async throws -> ModelContainer {
|
||||
if let task {
|
||||
return try await task.value
|
||||
}
|
||||
|
||||
guard let config = ModelConfig.resolve("gemma") else {
|
||||
throw XCTSkip("Gemma model config is unavailable")
|
||||
}
|
||||
guard let localDir = LocalModelResolver.resolve(repoId: config.repoId) else {
|
||||
throw XCTSkip("Local gemma cache is unavailable")
|
||||
}
|
||||
|
||||
let loadTask = Task<ModelContainer, Error> {
|
||||
let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first
|
||||
let hub = HubApi(downloadBase: cachesDir, cache: nil)
|
||||
return try await VLMModelFactory.shared.loadContainer(
|
||||
hub: hub,
|
||||
configuration: ModelConfiguration(directory: localDir),
|
||||
progressHandler: { _ in }
|
||||
)
|
||||
}
|
||||
task = loadTask
|
||||
|
||||
do {
|
||||
return try await loadTask.value
|
||||
} catch {
|
||||
task = nil
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
288
MLXServerTests/Server/PromptBuilderTests.swift
Normal file
288
MLXServerTests/Server/PromptBuilderTests.swift
Normal file
@@ -0,0 +1,288 @@
|
||||
import XCTest
|
||||
import MLXLMCommon
|
||||
@testable import MLX_Server
|
||||
|
||||
final class PromptBuilderTests: XCTestCase {
|
||||
private let onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAFgwJ/lRyXWQAAAABJRU5ErkJggg=="
|
||||
|
||||
func testBuildMatchesLegacyAPIServerShapingForGemma() {
|
||||
let toolCall = APIToolCall(
|
||||
id: "call_weather",
|
||||
function: APIFunctionCall(name: "weather", arguments: "{\"city\":\"Berlin\"}")
|
||||
)
|
||||
let request = APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("System 1"), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "system", content: .text("System 2"), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "assistant", content: .text("Let me check"), name: nil, tool_calls: [toolCall], tool_call_id: nil),
|
||||
APIChatMessage(
|
||||
role: "tool",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "{\"temp\":19}", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
tool_call_id: "call_weather"
|
||||
),
|
||||
APIChatMessage(role: "user", content: .text("Thanks"), name: nil, tool_calls: nil, tool_call_id: nil)
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: [
|
||||
APIToolDefinition(
|
||||
type: "function",
|
||||
function: APIFunctionDefinition(
|
||||
name: "weather",
|
||||
description: "Lookup weather",
|
||||
parameters: ["type": AnyCodable("object")]
|
||||
)
|
||||
)
|
||||
],
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
|
||||
let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false)
|
||||
let legacy = legacyBuild(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false)
|
||||
|
||||
XCTAssertEqual(prepared.instructions, legacy.instructions)
|
||||
XCTAssertEqual(prepared.chatMessages.map { $0.role.roleLabel }, legacy.chatMessages.map { $0.role.roleLabel })
|
||||
XCTAssertEqual(prepared.chatMessages.map(\.content), legacy.chatMessages.map(\.content))
|
||||
XCTAssertEqual(prepared.chatMessages.map { $0.images.count }, legacy.chatMessages.map { $0.images.count })
|
||||
XCTAssertEqual(prepared.messageSignatures, legacy.messageSignatures)
|
||||
XCTAssertEqual(prepared.estimatedBytes, legacy.estimatedBytes)
|
||||
XCTAssertEqual(prepared.estimatedPromptTokens, legacy.estimatedPromptTokens)
|
||||
XCTAssertEqual(prepared.containsImages, legacy.containsImages)
|
||||
XCTAssertEqual(prepared.additionalContext?["enable_thinking"] as? Bool, legacy.additionalContext?["enable_thinking"] as? Bool)
|
||||
}
|
||||
|
||||
func testBuildAggregatesInstructionsAndMessages() {
|
||||
let request = APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "system", content: .text("Base system"), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "system", content: .text("Extra system"), name: nil, tool_calls: nil, tool_call_id: nil),
|
||||
APIChatMessage(role: "user", content: .text("Hello"), name: nil, tool_calls: nil, tool_call_id: nil)
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
|
||||
let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: false)
|
||||
|
||||
XCTAssertEqual(prepared.instructions, "Base system\n\nExtra system")
|
||||
XCTAssertEqual(prepared.chatMessages.count, 1)
|
||||
XCTAssertEqual(prepared.chatMessages[0].content, "Hello")
|
||||
XCTAssertEqual(prepared.messageSignatures.count, 1)
|
||||
XCTAssertFalse(prepared.containsImages)
|
||||
XCTAssertNotNil(prepared.additionalContext)
|
||||
XCTAssertGreaterThan(prepared.estimatedPromptTokens, 0)
|
||||
}
|
||||
|
||||
func testBuildFormatsAssistantToolCallsForQwen() {
|
||||
let toolCall = APIToolCall(
|
||||
id: "call_1",
|
||||
function: APIFunctionCall(name: "weather", arguments: "{\"city\":\"Berlin\"}")
|
||||
)
|
||||
let request = APIChatCompletionRequest(
|
||||
model: "qwen",
|
||||
messages: [
|
||||
APIChatMessage(role: "assistant", content: .text("Let me check."), name: nil, tool_calls: [toolCall], tool_call_id: nil)
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
|
||||
let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/Qwen3-VL-4B-Instruct-4bit", thinkingEnabled: true)
|
||||
|
||||
XCTAssertEqual(prepared.chatMessages.count, 1)
|
||||
XCTAssertTrue(prepared.chatMessages[0].content.contains("Let me check."))
|
||||
XCTAssertTrue(prepared.chatMessages[0].content.contains("<tool_call>"))
|
||||
XCTAssertNil(prepared.additionalContext)
|
||||
}
|
||||
|
||||
func testBuildWrapsGemmaToolOutputsAndTracksImages() {
|
||||
let request = APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(
|
||||
role: "tool",
|
||||
content: .parts([
|
||||
APIContentPart(type: "text", text: "{\"ok\":true}", image_url: nil),
|
||||
APIContentPart(type: "image_url", text: nil, image_url: APIImageURL(url: "data:image/png;base64,\(onePixelPNGBase64)", detail: nil))
|
||||
]),
|
||||
name: nil,
|
||||
tool_calls: nil,
|
||||
tool_call_id: "call_1"
|
||||
)
|
||||
],
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
stream: nil,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
|
||||
let prepared = PromptBuilder.build(from: request, modelId: "mlx-community/gemma-3-4b-it-4bit", thinkingEnabled: true)
|
||||
|
||||
XCTAssertTrue(prepared.chatMessages[0].content.contains("```tool_output"))
|
||||
XCTAssertTrue(prepared.containsImages)
|
||||
XCTAssertEqual(prepared.chatMessages[0].images.count, 1)
|
||||
XCTAssertGreaterThan(prepared.estimatedBytes, prepared.chatMessages[0].content.utf8.count)
|
||||
}
|
||||
|
||||
private func legacyBuild(
|
||||
from request: APIChatCompletionRequest,
|
||||
modelId: String,
|
||||
thinkingEnabled: Bool
|
||||
) -> PromptBuilder.PreparedPrompt {
|
||||
var instructions = ""
|
||||
for msg in request.messages where msg.role == "system" {
|
||||
let text = msg.content?.textContent ?? ""
|
||||
if !text.isEmpty {
|
||||
if !instructions.isEmpty { instructions += "\n\n" }
|
||||
instructions += text
|
||||
}
|
||||
}
|
||||
|
||||
if let tools = request.tools, !tools.isEmpty {
|
||||
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: modelId)
|
||||
if !instructions.isEmpty { instructions += "\n\n" }
|
||||
instructions += toolSystemPrompt
|
||||
}
|
||||
|
||||
let isQwen = modelId.lowercased().contains("qwen")
|
||||
var chatMessages: [Chat.Message] = []
|
||||
var messageSignatures: [UInt64] = []
|
||||
var estimatedBytes = instructions.utf8.count
|
||||
var containsImages = false
|
||||
|
||||
for msg in request.messages where msg.role != "system" {
|
||||
let role: Chat.Message.Role = switch msg.role {
|
||||
case "assistant": .assistant
|
||||
case "tool": .user
|
||||
default: .user
|
||||
}
|
||||
|
||||
var text = msg.content?.textContent ?? ""
|
||||
if msg.role == "tool", !isQwen {
|
||||
text = "```tool_output\n\(text)\n```"
|
||||
}
|
||||
|
||||
if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty {
|
||||
let formattedCalls = isQwen
|
||||
? ToolPromptBuilder.formatQwenToolCalls(toolCalls)
|
||||
: ToolPromptBuilder.formatGemmaToolCalls(toolCalls)
|
||||
text = (text.isEmpty ? "" : text + "\n") + formattedCalls
|
||||
}
|
||||
|
||||
let imageURLs = msg.content?.imageURLs ?? []
|
||||
var messageImages: [UserInput.Image] = []
|
||||
var messageImageBytes = 0
|
||||
for urlString in imageURLs {
|
||||
if let decoded = ImageDecoder.decode(urlString) {
|
||||
messageImages.append(decoded.image)
|
||||
messageImageBytes += decoded.estimatedBytes
|
||||
}
|
||||
}
|
||||
|
||||
containsImages = containsImages || !messageImages.isEmpty
|
||||
chatMessages.append(Chat.Message(role: role, content: text, images: messageImages))
|
||||
messageSignatures.append(messageSignature(role: role, content: text, imageURLs: imageURLs))
|
||||
estimatedBytes += text.utf8.count + messageImageBytes
|
||||
}
|
||||
|
||||
let additionalContext: [String: any Sendable]? = thinkingEnabled
|
||||
? nil
|
||||
: ["enable_thinking": false]
|
||||
|
||||
let allImages = chatMessages.flatMap(\.images)
|
||||
let userInput = UserInput(
|
||||
prompt: .chat((instructions.isEmpty ? [] : [Chat.Message(role: .system, content: instructions)]) + chatMessages),
|
||||
images: allImages,
|
||||
videos: [],
|
||||
tools: nil,
|
||||
additionalContext: additionalContext
|
||||
)
|
||||
|
||||
return PromptBuilder.PreparedPrompt(
|
||||
instructions: instructions,
|
||||
chatMessages: chatMessages,
|
||||
messageSignatures: messageSignatures,
|
||||
estimatedBytes: estimatedBytes,
|
||||
estimatedPromptTokens: (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35,
|
||||
containsImages: containsImages,
|
||||
additionalContext: additionalContext,
|
||||
userInput: userInput
|
||||
)
|
||||
}
|
||||
|
||||
private func messageSignature(role: Chat.Message.Role, content: String, imageURLs: [String]) -> UInt64 {
|
||||
var hash: UInt64 = 14_695_981_039_346_656_037
|
||||
|
||||
func mix(_ text: String) {
|
||||
for byte in text.utf8 {
|
||||
hash ^= UInt64(byte)
|
||||
hash &*= 1_099_511_628_211
|
||||
}
|
||||
}
|
||||
|
||||
switch role {
|
||||
case .assistant:
|
||||
mix("assistant")
|
||||
case .system:
|
||||
mix("system")
|
||||
case .user:
|
||||
mix("user")
|
||||
@unknown default:
|
||||
mix("unknown")
|
||||
}
|
||||
mix("|")
|
||||
mix(content)
|
||||
for imageURL in imageURLs {
|
||||
mix("|")
|
||||
mix(imageURL)
|
||||
}
|
||||
|
||||
return hash
|
||||
}
|
||||
}
|
||||
|
||||
private extension Chat.Message.Role {
|
||||
var roleLabel: String {
|
||||
switch self {
|
||||
case .assistant: "assistant"
|
||||
case .system: "system"
|
||||
case .user: "user"
|
||||
@unknown default: "unknown"
|
||||
}
|
||||
}
|
||||
}
|
||||
130
MLXServerTests/Server/TokenPrefixCacheTests.swift
Normal file
130
MLXServerTests/Server/TokenPrefixCacheTests.swift
Normal file
@@ -0,0 +1,130 @@
|
||||
import Foundation
|
||||
import XCTest
|
||||
import MLXLMCommon
|
||||
@testable import MLX_Server
|
||||
|
||||
final class TokenPrefixCacheTests: XCTestCase {
|
||||
func testStoreAndLookupRemovesCheckedOutEntry() {
|
||||
var now = Date(timeIntervalSince1970: 100)
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 1_024 },
|
||||
nowProvider: { now }
|
||||
)
|
||||
|
||||
let entryId = UUID()
|
||||
cache.store(entryId: entryId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
||||
|
||||
XCTAssertEqual(cache.snapshot().totalEntries, 1)
|
||||
|
||||
let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertEqual(lease.entryId, entryId)
|
||||
XCTAssertEqual(lease.matchedTokenCount, 3)
|
||||
XCTAssertNotNil(lease.kvCache)
|
||||
XCTAssertEqual(cache.snapshot().totalEntries, 0)
|
||||
}
|
||||
|
||||
func testLookupPrefersDeepestPrefixMatch() {
|
||||
var now = Date(timeIntervalSince1970: 100)
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 1_024 },
|
||||
nowProvider: { now }
|
||||
)
|
||||
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2], modelId: "model")
|
||||
now.addTimeInterval(1)
|
||||
let deepId = UUID()
|
||||
cache.store(entryId: deepId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
||||
|
||||
let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertEqual(lease.entryId, deepId)
|
||||
XCTAssertEqual(lease.matchedTokenCount, 3)
|
||||
}
|
||||
|
||||
func testEvictsLeastRecentlyUsedEntryWhenOverBudget() {
|
||||
var now = Date(timeIntervalSince1970: 100)
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 2_048,
|
||||
estimateBytesProvider: { _ in 1_024 },
|
||||
nowProvider: { now }
|
||||
)
|
||||
|
||||
let firstId = UUID()
|
||||
cache.store(entryId: firstId, kvCache: [], cacheKey: [1], modelId: "model")
|
||||
now.addTimeInterval(1)
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [2], modelId: "model")
|
||||
now.addTimeInterval(1)
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [3], modelId: "model")
|
||||
|
||||
let firstLookup = cache.lookup(cacheKey: [1], modelId: "model")
|
||||
let secondLookup = cache.lookup(cacheKey: [2], modelId: "model")
|
||||
let thirdLookup = cache.lookup(cacheKey: [3], modelId: "model")
|
||||
|
||||
XCTAssertFalse(firstLookup.isHit)
|
||||
XCTAssertTrue(secondLookup.isHit)
|
||||
XCTAssertTrue(thirdLookup.isHit)
|
||||
}
|
||||
|
||||
func testSnapshotPrunesExpiredEntries() {
|
||||
var now = Date(timeIntervalSince1970: 100)
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
idleTTL: 5,
|
||||
estimateBytesProvider: { _ in 1_024 },
|
||||
nowProvider: { now }
|
||||
)
|
||||
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
||||
XCTAssertEqual(cache.snapshot().totalEntries, 1)
|
||||
|
||||
now.addTimeInterval(10)
|
||||
let snapshot = cache.snapshot()
|
||||
|
||||
XCTAssertEqual(snapshot.totalEntries, 0)
|
||||
XCTAssertGreaterThanOrEqual(snapshot.totalEvictions, 1)
|
||||
}
|
||||
|
||||
func testLookupPrunesTrieNodesForRemovedBranch() {
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 1_024 }
|
||||
)
|
||||
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 4], modelId: "model")
|
||||
|
||||
XCTAssertEqual(cache.debugTrieNodeCount(), 5)
|
||||
|
||||
_ = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
|
||||
|
||||
XCTAssertEqual(cache.debugTrieNodeCount(), 4)
|
||||
|
||||
_ = cache.lookup(cacheKey: [1, 2, 4], modelId: "model")
|
||||
|
||||
XCTAssertEqual(cache.debugTrieNodeCount(), 1)
|
||||
}
|
||||
|
||||
func testSnapshotReportsHitRateAndTokenTotals() {
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 2_048 }
|
||||
)
|
||||
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [10, 20, 30], modelId: "model")
|
||||
_ = cache.lookup(cacheKey: [10, 20, 30, 40], modelId: "model")
|
||||
_ = cache.lookup(cacheKey: [99], modelId: "model")
|
||||
|
||||
let snapshot = cache.snapshot()
|
||||
|
||||
XCTAssertEqual(snapshot.totalHits, 1)
|
||||
XCTAssertEqual(snapshot.totalMisses, 1)
|
||||
XCTAssertEqual(snapshot.hitRate, 50, accuracy: 0.001)
|
||||
XCTAssertEqual(snapshot.totalCachedTokens, 0)
|
||||
XCTAssertEqual(snapshot.estimatedBytes, 0)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user