feat: finally nailed down phases 1-4

This commit is contained in:
2026-03-20 12:05:24 +01:00
parent 5aed0107c6
commit 1f12fac5e2
9 changed files with 492 additions and 19 deletions

View File

@@ -3,6 +3,61 @@ import XCTest
@testable import MLX_Server
final class APIServerRewriteTests: XCTestCase {
func testQwenNonStreamingChatCompletionCachesAndReusesPrompt() async throws {
let harness = try await makeHarness(initialModelId: "qwen")
defer { harness.stop() }
let lookups = LookupEventCollector()
APIServer.debugLookupEventHandler = { event in
Task {
await lookups.record(event)
}
}
defer {
APIServer.debugLookupEventHandler = nil
}
let request = APIChatCompletionRequest(
model: "qwen",
messages: [
APIChatMessage(role: "user", content: .text("Reply with exactly one short word."), name: nil, tool_calls: nil, tool_call_id: nil)
],
temperature: 0,
top_p: 1,
max_tokens: 1,
stream: false,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
let firstResponse = try await sendChatCompletion(request, port: harness.port)
XCTAssertEqual(firstResponse.choices.count, 1)
try await waitUntil(timeoutSeconds: 5) {
let snapshot = TokenPrefixCache.shared.snapshot()
return snapshot.totalEntries > 0 && snapshot.entries.allSatisfy { $0.modelId == "qwen" }
}
let firstSnapshot = TokenPrefixCache.shared.snapshot()
_ = try await sendChatCompletion(request, port: harness.port)
try await waitUntil(timeoutSeconds: 5) {
let events = await lookups.events()
return events.count >= 2 && TokenPrefixCache.shared.snapshot().totalHits > firstSnapshot.totalHits
}
let secondSnapshot = TokenPrefixCache.shared.snapshot()
let events = await lookups.events()
let secondLookup = try XCTUnwrap(events.last)
XCTAssertGreaterThan(secondSnapshot.totalHits, firstSnapshot.totalHits)
XCTAssertTrue(secondLookup.isHit)
XCTAssertGreaterThan(secondLookup.matchedTokenCount, 0)
}
func testHealthAndModelsEndpointsReturnExpectedPayloads() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
@@ -69,6 +124,16 @@ final class APIServerRewriteTests: XCTestCase {
let harness = try await makeHarness()
defer { harness.stop() }
let lookups = LookupEventCollector()
APIServer.debugLookupEventHandler = { event in
Task {
await lookups.record(event)
}
}
defer {
APIServer.debugLookupEventHandler = nil
}
let request = APIChatCompletionRequest(
model: "gemma",
messages: [
@@ -89,10 +154,15 @@ final class APIServerRewriteTests: XCTestCase {
_ = try await sendChatCompletion(request, port: harness.port)
_ = try await sendChatCompletion(request, port: harness.port)
let live = LiveCounters.shared.snapshot()
XCTAssertGreaterThan(live.currentCacheMatchedPromptTokens, 0)
XCTAssertEqual(live.currentCacheMatchedPromptTokens, live.promptTokens)
XCTAssertEqual(live.currentCacheRebuiltPromptTokens, 0)
try await waitUntil(timeoutSeconds: 5) {
let events = await lookups.events()
return events.count >= 2
}
let events = await lookups.events()
let secondLookup = try XCTUnwrap(events.last)
XCTAssertTrue(secondLookup.isHit)
XCTAssertEqual(secondLookup.matchedTokenCount, secondLookup.promptTokenCount)
}
func testSingleTurnContinuationProducesPartialCacheHit() async throws {
@@ -365,6 +435,91 @@ final class APIServerRewriteTests: XCTestCase {
XCTAssertEqual(live.currentCacheMatchedPromptTokens, 0)
}
func testRequestModelFieldSwapsFromGemmaToQwenAndInvalidatesGemmaCache() async throws {
let harness = try await makeHarness(initialModelId: "gemma")
defer { harness.stop() }
let lookups = LookupEventCollector()
APIServer.debugLookupEventHandler = { event in
Task {
await lookups.record(event)
}
}
defer {
APIServer.debugLookupEventHandler = nil
}
let gemmaRequest = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "user", content: .text("Answer with one word: river."), name: nil, tool_calls: nil, tool_call_id: nil)
],
temperature: 0,
top_p: 1,
max_tokens: 2,
stream: false,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
_ = try await sendChatCompletion(gemmaRequest, port: harness.port)
try await waitUntil(timeoutSeconds: 5) {
TokenPrefixCache.shared.snapshot().entries.contains(where: { $0.modelId == "gemma" })
}
let qwenRequest = APIChatCompletionRequest(
model: "qwen",
messages: [
APIChatMessage(role: "user", content: .text("Answer with one word: river."), name: nil, tool_calls: nil, tool_call_id: nil)
],
temperature: 0,
top_p: 1,
max_tokens: 2,
stream: false,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
_ = try await sendChatCompletion(qwenRequest, port: harness.port)
try await waitUntil(timeoutSeconds: 5) {
let snapshot = TokenPrefixCache.shared.snapshot()
let modelId = await MainActor.run { harness.modelManager.currentModel?.id }
return modelId == "qwen"
&& !snapshot.entries.isEmpty
&& snapshot.entries.allSatisfy { $0.modelId == "qwen" }
}
let afterSwapSnapshot = TokenPrefixCache.shared.snapshot()
let afterSwapEvents = await lookups.events()
let firstQwenLookup = try XCTUnwrap(afterSwapEvents.last)
XCTAssertTrue(afterSwapSnapshot.entries.allSatisfy { $0.modelId == "qwen" })
XCTAssertFalse(firstQwenLookup.isHit)
XCTAssertEqual(firstQwenLookup.matchedTokenCount, 0)
_ = try await sendChatCompletion(qwenRequest, port: harness.port)
try await waitUntil(timeoutSeconds: 5) {
let events = await lookups.events()
return events.count >= 3 && TokenPrefixCache.shared.snapshot().totalHits > afterSwapSnapshot.totalHits
}
let finalSnapshot = TokenPrefixCache.shared.snapshot()
let finalEvents = await lookups.events()
let secondQwenLookup = try XCTUnwrap(finalEvents.last)
XCTAssertGreaterThan(finalSnapshot.totalHits, afterSwapSnapshot.totalHits)
XCTAssertTrue(secondQwenLookup.isHit)
XCTAssertGreaterThan(secondQwenLookup.matchedTokenCount, 0)
}
func testStreamingChatCompletionReusesCacheAcrossThreeProgressivelyLongerTurns() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
@@ -775,6 +930,130 @@ final class APIServerRewriteTests: XCTestCase {
XCTAssertGreaterThan(finalLiveSnapshot.totalCacheReusePromptTokens, afterDisconnectLiveSnapshot.totalCacheReusePromptTokens)
}
func testStreamingDisconnectStopsServerWorkWithinTwoHundredMilliseconds() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
let request = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "user", content: .text("Count from one to fifty with commas, using many tokens."), name: nil, tool_calls: nil, tool_call_id: nil)
],
temperature: 0,
top_p: 1,
max_tokens: 128,
stream: true,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
let url = URL(string: "http://127.0.0.1:\(harness.port)/v1/chat/completions")!
var urlRequest = URLRequest(url: url)
urlRequest.httpMethod = "POST"
urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
urlRequest.httpBody = try JSONEncoder().encode(request)
let observer = StreamCancellationObserver()
let session = URLSession(configuration: .ephemeral)
let baselineDisconnects = LiveCounters.shared.snapshot().totalDisconnects
let task = Task {
let (bytes, response) = try await session.bytes(for: urlRequest)
let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
XCTAssertEqual(httpResponse.statusCode, 200)
for try await line in bytes.lines {
guard line.hasPrefix("data: ") else { continue }
let payload = String(line.dropFirst(6))
if payload == "[DONE]" {
break
}
guard let data = payload.data(using: .utf8) else { continue }
let chunk = try JSONDecoder().decode(APIChatCompletionChunk.self, from: data)
if let deltaContent = chunk.choices.first?.delta.content, !deltaContent.isEmpty {
await observer.markFirstContentSeen()
try await Task.sleep(nanoseconds: 30_000_000_000)
}
}
}
try await waitUntil(timeoutSeconds: 10) {
await observer.hasSeenFirstContent
}
let disconnectStartedAt = Date()
session.invalidateAndCancel()
task.cancel()
try await waitUntil(timeoutSeconds: 5, intervalNanoseconds: 10_000_000) {
let snapshot = LiveCounters.shared.snapshot()
return snapshot.totalDisconnects > baselineDisconnects && snapshot.activeRequests == 0
}
_ = try? await task.value
let elapsed = Date().timeIntervalSince(disconnectStartedAt)
XCTAssertLessThan(elapsed, 0.2)
}
func testRepeatedStreamingDisconnectsDoNotBreakSubsequentGeneration() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
let request = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "user", content: .text("Count from one to forty with commas, using many tokens."), name: nil, tool_calls: nil, tool_call_id: nil)
],
temperature: 0,
top_p: 1,
max_tokens: 96,
stream: true,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
for expectedDisconnectCount in 1...3 {
try await cancelStreamingChatCompletionAfterFirstContentAndWaitForServerDisconnect(
request,
port: harness.port,
expectedDisconnectCount: expectedDisconnectCount
)
let liveSnapshot = LiveCounters.shared.snapshot()
XCTAssertEqual(liveSnapshot.totalDisconnects, expectedDisconnectCount)
XCTAssertEqual(liveSnapshot.activeRequests, 0)
}
let recoveryRequest = APIChatCompletionRequest(
model: "gemma",
messages: [
APIChatMessage(role: "user", content: .text("Reply with exactly one short word."), name: nil, tool_calls: nil, tool_call_id: nil)
],
temperature: 0,
top_p: 1,
max_tokens: 2,
stream: false,
stop: nil,
tools: nil,
tool_choice: nil,
frequency_penalty: nil,
presence_penalty: nil,
n: nil
)
let response = try await sendChatCompletion(recoveryRequest, port: harness.port)
XCTAssertEqual(response.choices.count, 1)
XCTAssertEqual(response.choices[0].message.role, "assistant")
XCTAssertFalse((response.choices[0].message.content ?? "").trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
}
func testStreamingToolCallChunksArriveInOpenAICompatibleOrder() async throws {
let harness = try await makeHarness()
defer { harness.stop() }
@@ -846,9 +1125,9 @@ final class APIServerRewriteTests: XCTestCase {
)
}
private func makeHarness() async throws -> TestHarness {
private func makeHarness(initialModelId: String = "gemma") async throws -> TestHarness {
let modelManager = await MainActor.run { ModelManager() }
let config = try XCTUnwrap(ModelConfig.resolve("gemma"))
let config = try XCTUnwrap(ModelConfig.resolve(initialModelId))
LiveCounters.shared.reset()
TokenPrefixCache.shared.reset()
@@ -994,6 +1273,19 @@ final class APIServerRewriteTests: XCTestCase {
_ = try? await task.value
}
private func cancelStreamingChatCompletionAfterFirstContentAndWaitForServerDisconnect(
_ request: APIChatCompletionRequest,
port: UInt16,
expectedDisconnectCount: Int
) async throws {
try await cancelStreamingChatCompletionAfterFirstContent(request, port: port)
try await waitUntil(timeoutSeconds: 5, intervalNanoseconds: 10_000_000) {
let snapshot = LiveCounters.shared.snapshot()
return snapshot.totalDisconnects >= expectedDisconnectCount && snapshot.activeRequests == 0
}
}
private func waitUntil(
timeoutSeconds: TimeInterval,
intervalNanoseconds: UInt64 = 100_000_000,