From ed6cc5f5d10c9ea8ccc05e01b09c4a9542735c0f Mon Sep 17 00:00:00 2001 From: Chili Palmer Date: Tue, 17 Mar 2026 21:24:04 +0100 Subject: [PATCH] fix: better handling of API stuff, still not where internal chat is --- CLAUDE.md | 9 + MLXServer/Models/InferenceStats.swift | 188 ++++++++++++----- MLXServer/Server/APIServer.swift | 292 ++++++++++++++------------ MLXServer/Server/ToolCallParser.swift | 59 ++++++ 4 files changed, 358 insertions(+), 190 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 3af05c9..3f9265a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -36,6 +36,15 @@ open "build/Debug/MLX Server.app" | `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks (128k context) | | `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags (256k context) | +## Critical Performance Rule + +**Inference speed is the #1 priority.** The token generation loop must never be blocked or slowed by anything else — no MainActor hops, no SwiftUI observation, no synchronous I/O. Everything that isn't inference (stats collection, UI updates, logging) must run on separate threads via loose coupling: + +- **`LiveCounters`** (thread-safe singleton with `OSAllocatedUnfairLock`) is the bridge: generation code writes to it directly from any thread with zero actor overhead. +- **`InferenceStats`** (UI-side, `@Observable @MainActor`) polls `LiveCounters` at 1Hz via a timer — never the other way around. +- SSE streaming (`sendSSEEvent`/`sendData`) runs nonisolated off MainActor so token sends don't compete with SwiftUI rendering. +- Never gate token output on UI state, analytics, or any `@MainActor`-isolated code. 
+ ## Key Design Decisions - Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) as the inference backend — supports both text and vision in a single model load diff --git a/MLXServer/Models/InferenceStats.swift b/MLXServer/Models/InferenceStats.swift index 5efc6a5..66c08e2 100644 --- a/MLXServer/Models/InferenceStats.swift +++ b/MLXServer/Models/InferenceStats.swift @@ -1,11 +1,124 @@ import Foundation +import os + +// MARK: - Thread-safe live counters (written from any thread, no actor isolation) + +/// Lock-protected counters that the generation loop writes to directly. +/// No MainActor requirement — the UI polls these via the 1Hz timer. +final class LiveCounters: @unchecked Sendable { + static let shared = LiveCounters() + + private let lock = OSAllocatedUnfairLock() + + // Current request + private var _activeRequests: Int = 0 + private var _promptTokens: Int = 0 + private var _generationTokens: Int = 0 + private var _tokensPerSecond: Double = 0 + private var _isPrefilling: Bool = false + private var _isGenerating: Bool = false + private var _contextMax: Int = 0 + + // Cumulative + private var _totalRequests: Int = 0 + private var _totalPromptTokens: Int = 0 + private var _totalGenerationTokens: Int = 0 + + func requestStarted(contextLength: Int) { + lock.lock() + _activeRequests += 1 + _totalRequests += 1 + _isPrefilling = true + _isGenerating = false + _promptTokens = 0 + _generationTokens = 0 + _tokensPerSecond = 0 + _contextMax = contextLength + lock.unlock() + } + + func prefillCompleted(promptTokens: Int) { + lock.lock() + _isPrefilling = false + _isGenerating = true + _promptTokens = promptTokens + _totalPromptTokens += promptTokens + lock.unlock() + } + + func tokenGenerated(tokensPerSecond: Double, totalGenerated: Int) { + lock.lock() + _generationTokens = totalGenerated + _tokensPerSecond = tokensPerSecond + lock.unlock() + } + + func requestCompleted(generationTokens: Int) { + lock.lock() + _activeRequests = max(0, _activeRequests - 1) + 
_totalGenerationTokens += generationTokens + if _activeRequests == 0 { + _isGenerating = false + _isPrefilling = false + _tokensPerSecond = 0 + } + lock.unlock() + } + + func reset() { + lock.lock() + _activeRequests = 0 + _promptTokens = 0 + _generationTokens = 0 + _tokensPerSecond = 0 + _isPrefilling = false + _isGenerating = false + _contextMax = 0 + _totalRequests = 0 + _totalPromptTokens = 0 + _totalGenerationTokens = 0 + lock.unlock() + } + + /// Atomic snapshot for the UI timer. + func snapshot() -> Snapshot { + lock.lock() + let s = Snapshot( + activeRequests: _activeRequests, + promptTokens: _promptTokens, + generationTokens: _generationTokens, + tokensPerSecond: _tokensPerSecond, + isPrefilling: _isPrefilling, + isGenerating: _isGenerating, + contextMax: _contextMax, + totalRequests: _totalRequests, + totalPromptTokens: _totalPromptTokens, + totalGenerationTokens: _totalGenerationTokens + ) + lock.unlock() + return s + } + + struct Snapshot { + let activeRequests: Int + let promptTokens: Int + let generationTokens: Int + let tokensPerSecond: Double + let isPrefilling: Bool + let isGenerating: Bool + let contextMax: Int + let totalRequests: Int + let totalPromptTokens: Int + let totalGenerationTokens: Int + } +} + +// MARK: - Observable stats for the UI (polls LiveCounters at 1Hz) -/// Lightweight stats collector for inference activity visualization. -/// All mutations happen on @MainActor to avoid locks. @Observable @MainActor final class InferenceStats { - // MARK: - Current request state + // MARK: - Current request state (refreshed from LiveCounters) var activeRequests: Int = 0 var currentPromptTokens: Int = 0 @@ -40,11 +153,9 @@ final class InferenceStats { private var sampleTimer: Timer? 
private var lastGenerationTokenCount: Int = 0 private var lastPromptTokenCount: Int = 0 - private var lastSampleTime: Date = .now func startSampling() { guard sampleTimer == nil else { return } - lastSampleTime = .now sampleTimer = Timer.scheduledTimer(withTimeInterval: 1.0, repeats: true) { [weak self] _ in Task { @MainActor in self?.recordSample() @@ -58,19 +169,31 @@ final class InferenceStats { } private func recordSample() { + // Pull live values from the thread-safe counters + let snap = LiveCounters.shared.snapshot() + + activeRequests = snap.activeRequests + currentPromptTokens = snap.promptTokens + currentGenerationTokens = snap.generationTokens + currentTokensPerSecond = snap.tokensPerSecond + isPrefilling = snap.isPrefilling + isGenerating = snap.isGenerating + contextMax = snap.contextMax + contextUsed = snap.promptTokens + snap.generationTokens + totalRequests = snap.totalRequests + totalPromptTokens = snap.totalPromptTokens + totalGenerationTokens = snap.totalGenerationTokens + let now = Date.now + let genDelta = snap.totalGenerationTokens - lastGenerationTokenCount + let promptDelta = snap.totalPromptTokens - lastPromptTokenCount + lastGenerationTokenCount = snap.totalGenerationTokens + lastPromptTokenCount = snap.totalPromptTokens - // Token rate: tokens generated since last sample - let genDelta = totalGenerationTokens - lastGenerationTokenCount - let promptDelta = totalPromptTokens - lastPromptTokenCount - lastGenerationTokenCount = totalGenerationTokens - lastPromptTokenCount = totalPromptTokens - - tokenRateHistory.append(DataPoint(timestamp: now, value: currentTokensPerSecond)) + tokenRateHistory.append(DataPoint(timestamp: now, value: snap.tokensPerSecond)) generationTokenHistory.append(DataPoint(timestamp: now, value: Double(genDelta))) promptTokenHistory.append(DataPoint(timestamp: now, value: Double(promptDelta))) - // Trim to ring buffer size if tokenRateHistory.count > Self.maxHistoryPoints { 
tokenRateHistory.removeFirst(tokenRateHistory.count - Self.maxHistoryPoints) } @@ -82,45 +205,8 @@ final class InferenceStats { } } - // MARK: - Event recording (called from APIServer) - - func requestStarted(contextLength: Int) { - activeRequests += 1 - totalRequests += 1 - isPrefilling = true - isGenerating = false - currentPromptTokens = 0 - currentGenerationTokens = 0 - currentTokensPerSecond = 0 - contextMax = contextLength - contextUsed = 0 - } - - func prefillCompleted(promptTokens: Int) { - isPrefilling = false - isGenerating = true - currentPromptTokens = promptTokens - totalPromptTokens += promptTokens - contextUsed = promptTokens - } - - func tokenGenerated(tokensPerSecond: Double, totalGenerated: Int) { - currentGenerationTokens = totalGenerated - currentTokensPerSecond = tokensPerSecond - contextUsed = currentPromptTokens + totalGenerated - } - - func requestCompleted(promptTokens: Int, generationTokens: Int) { - activeRequests = max(0, activeRequests - 1) - totalGenerationTokens += generationTokens - if activeRequests == 0 { - isGenerating = false - isPrefilling = false - currentTokensPerSecond = 0 - } - } - func reset() { + LiveCounters.shared.reset() activeRequests = 0 currentPromptTokens = 0 currentGenerationTokens = 0 diff --git a/MLXServer/Server/APIServer.swift b/MLXServer/Server/APIServer.swift index 2a36415..f570899 100644 --- a/MLXServer/Server/APIServer.swift +++ b/MLXServer/Server/APIServer.swift @@ -20,6 +20,7 @@ final class APIServer { private var cachedSession: ChatSession? private var cachedMessages: [Chat.Message]? private var cachedModelId: String? 
+ private var cachedInstructions: String = "" func start(modelManager: ModelManager, port: Int = 1234) { guard !isRunning else { return } @@ -29,6 +30,10 @@ final class APIServer { do { let params = NWParameters.tcp params.allowLocalEndpointReuse = true + // Disable Nagle's algorithm so small SSE events go out immediately + if let tcpOptions = params.defaultProtocolStack.transportProtocol as? NWProtocolTCP.Options { + tcpOptions.noDelay = true + } listener = try NWListener(using: params, on: NWEndpoint.Port(integerLiteral: UInt16(port))) listener?.stateUpdateHandler = { [weak self] state in @@ -68,6 +73,7 @@ final class APIServer { cachedSession = nil cachedMessages = nil cachedModelId = nil + cachedInstructions = "" inferenceStats.stopSampling() } @@ -183,6 +189,7 @@ final class APIServer { cachedSession = nil cachedMessages = nil cachedModelId = nil + cachedInstructions = "" await modelManager.loadModel(targetConfig) } } @@ -196,6 +203,7 @@ final class APIServer { cachedSession = nil cachedMessages = nil cachedModelId = nil + cachedInstructions = "" await modelManager.loadModel(config) } @@ -220,32 +228,33 @@ final class APIServer { var images: [UserInput.Image] = [] let currentModelRepoId = modelManager.currentModel?.repoId ?? modelName - // Inject tool definitions into the system prompt if tools are provided + // Build the instructions string (system prompt + tool definitions). + // This is passed to ChatSession via `instructions:` rather than injected + // as history messages, so it avoids an expensive history-replay prefill. + var instructions: String = "" + + // Collect system message text from the request + for msg in request.messages where msg.role == "system" { + let text = msg.content?.textContent ?? 
"" + if !text.isEmpty { + if !instructions.isEmpty { instructions += "\n\n" } + instructions += text + } + } + + // Append tool definitions to instructions if let tools = request.tools, !tools.isEmpty { let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId) - - // Check if there's already a system message - let hasSystem = request.messages.contains { $0.role == "system" } - if hasSystem { - // Append tool prompt to existing system message (handled below during conversion) - } else { - // For Gemma: inject as user message (Gemma doesn't support system role natively) - // For Qwen: inject as system message - if currentModelRepoId.lowercased().contains("qwen") { - chatMessages.append(Chat.Message(role: .system, content: toolSystemPrompt)) - } else { - chatMessages.append(Chat.Message(role: .user, content: toolSystemPrompt)) - chatMessages.append(Chat.Message(role: .assistant, content: "Understood. I will use the provided tools when appropriate.")) - } - } + if !instructions.isEmpty { instructions += "\n\n" } + instructions += toolSystemPrompt } let toolsForInjection = request.tools let isQwen = currentModelRepoId.lowercased().contains("qwen") - for msg in request.messages { + // Convert non-system messages to Chat.Message + for msg in request.messages where msg.role != "system" { let role: Chat.Message.Role = switch msg.role { - case "system": .system case "assistant": .assistant case "tool": .user default: .user @@ -253,12 +262,6 @@ final class APIServer { var text = msg.content?.textContent ?? 
"" - // If this is a system message and tools are provided, append tool definitions - if msg.role == "system", let tools = toolsForInjection, !tools.isEmpty { - let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId) - text = text + "\n\n" + toolSystemPrompt - } - // Format tool_call_id responses as tool_output for the model if msg.role == "tool" { if isQwen { @@ -328,6 +331,7 @@ final class APIServer { let canReuse = cachedSession != nil && cachedModelId == currentModelId && cachedMessages != nil + && cachedInstructions == instructions && messagesMatch(cachedMessages!, allButLast) let session: ChatSession @@ -339,15 +343,21 @@ final class APIServer { if cachedSession != nil { print("[APIServer] History diverged, creating fresh session") } + // Use `instructions:` for system/tool prompt (matches internal chat pattern). + // Only conversation turns go in `history:` — this avoids replaying the + // large tool prompt as history on every new session. + let instr = instructions.isEmpty ? nil : instructions if !allButLast.isEmpty { session = ChatSession( container, + instructions: instr, history: allButLast, generateParameters: generateParams ) } else { session = ChatSession( container, + instructions: instr, generateParameters: generateParams ) } @@ -356,7 +366,7 @@ final class APIServer { // Extract images from the last message only (ChatSession.streamDetails takes images separately) let lastImages = lastMessage.images - inferenceStats.requestStarted(contextLength: contextLength) + LiveCounters.shared.requestStarted(contextLength: contextLength) if isStream { await handleStreamingResponse( @@ -387,6 +397,7 @@ final class APIServer { cachedSession = session cachedMessages = chatMessages // full messages including the one just sent cachedModelId = currentModelId + cachedInstructions = instructions } /// Decode a base64 data URI (data:image/png;base64,...) into a UserInput.Image. 
@@ -439,20 +450,20 @@ final class APIServer { case .chunk(let text): fullText += text completionTokens += 1 - inferenceStats.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens) + LiveCounters.shared.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens) case .info(let info): promptTokens = info.promptTokenCount completionTokens = info.generationTokenCount - inferenceStats.prefillCompleted(promptTokens: promptTokens) + LiveCounters.shared.prefillCompleted(promptTokens: promptTokens) if info.tokensPerSecond > 0 { - inferenceStats.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens) + LiveCounters.shared.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens) } case .toolCall(let call): frameworkToolCalls.append(call) } } - inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens) + LiveCounters.shared.requestCompleted(generationTokens: completionTokens) // Parse tool calls: first check framework-detected ones, then our own text parser var finishReason = "stop" @@ -524,7 +535,7 @@ final class APIServer { sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}") } } catch { - inferenceStats.requestCompleted(promptTokens: 0, generationTokens: 0) + LiveCounters.shared.requestCompleted(generationTokens: 0) sendResponse(connection: connection, status: 500, body: #"{"error":"\#(error.localizedDescription)"}"#) } } @@ -552,15 +563,10 @@ final class APIServer { "", ].joined(separator: "\r\n") - let headerSent = await withCheckedContinuation { continuation in - connection.send(content: header.data(using: .utf8), completion: .contentProcessed({ _ in - continuation.resume(returning: true) - })) - } - guard headerSent else { return } + await Self.sendData(connection: connection, data: header.data(using: .utf8)!) 
// Send initial role chunk - sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( id: requestId, object: "chat.completion.chunk", created: created, @@ -569,78 +575,39 @@ final class APIServer { usage: nil )) - // When tools are available, buffer full response to parse tool calls - // (otherwise raw tool-call markup leaks into streamed text) - let bufferForTools = tools != nil && !(tools?.isEmpty ?? true) + let hasTools = tools != nil && !(tools?.isEmpty ?? true) - var promptTokens = 0 - var completionTokens = 0 - var fullText = "" - var frameworkToolCalls: [MLXLMCommon.ToolCall] = [] - - do { - let stream = session.streamDetails( - to: prompt, - images: images, - videos: [] + // Run the generation loop OFF MainActor. + // ChatSession and NWConnection don't need MainActor. + // Running on MainActor caused every token to compete with SwiftUI + // rendering, creating back-pressure that coalesced all output. + let stream = session.streamDetails( + to: prompt, + images: images, + videos: [] + ) + // Transfer non-Sendable values to the nonisolated loop. + // Safe because we don't touch session/images again until after the loop. 
+ let result = await { + nonisolated(unsafe) let stream = stream + return await Self.runStreamingLoop( + connection: connection, + stream: stream, + requestId: requestId, + created: created, + modelName: modelName ) + }() - for try await generation in stream { - switch generation { - case .chunk(let text): - completionTokens += 1 - fullText += text - inferenceStats.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens) + let (promptTokens, completionTokens, fullText, frameworkToolCalls) = result - if !bufferForTools { - sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( - id: requestId, - object: "chat.completion.chunk", - created: created, - model: modelName, - choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)], - usage: nil - )) - } - - case .info(let info): - promptTokens = info.promptTokenCount - completionTokens = info.generationTokenCount - inferenceStats.prefillCompleted(promptTokens: promptTokens) - if info.tokensPerSecond > 0 { - inferenceStats.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens) - } - - case .toolCall(let call): - frameworkToolCalls.append(call) - } - } - } catch { - inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens) - let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n" - connection.send(content: errorEvent.data(using: .utf8), completion: .contentProcessed({ _ in })) - } + // Stats were already updated by LiveCounters inside the loop // Post-generation: handle tool calls (framework-detected or text-parsed) var finishReason = "stop" if !frameworkToolCalls.isEmpty { - // Framework natively detected tool calls (e.g. 
Qwen) finishReason = "tool_calls" - - // Emit any buffered text content - if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { - sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( - id: requestId, - object: "chat.completion.chunk", - created: created, - model: modelName, - choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)], - usage: nil - )) - } - - // Emit tool call chunks for (i, tc) in frameworkToolCalls.enumerated() { let argsDict = tc.function.arguments.mapValues { $0.anyValue } let argsJSON: String @@ -657,7 +624,7 @@ final class APIServer { type: "function", function: APIFunctionCall(name: tc.function.name, arguments: argsJSON) ) - sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( id: requestId, object: "chat.completion.chunk", created: created, @@ -666,27 +633,11 @@ final class APIServer { usage: nil )) } - } else if bufferForTools { - // Text-parsed tool calls (e.g. 
Gemma tool_code blocks) - let (cleanText, parsed) = ToolCallParser.parse(text: fullText, tools: tools) + } else if hasTools { + let (_, parsed) = ToolCallParser.parse(text: fullText, tools: tools) if !parsed.isEmpty { finishReason = "tool_calls" - fullText = cleanText } - - // Emit buffered content (cleaned of tool-call markup) - if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { - sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( - id: requestId, - object: "chat.completion.chunk", - created: created, - model: modelName, - choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)], - usage: nil - )) - } - - // Emit tool call chunks for (i, tc) in parsed.enumerated() { let apiToolCall = APIToolCall( index: i, @@ -694,7 +645,7 @@ final class APIServer { type: "function", function: APIFunctionCall(name: tc.name, arguments: tc.arguments) ) - sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( id: requestId, object: "chat.completion.chunk", created: created, @@ -706,7 +657,7 @@ final class APIServer { } // Final chunk with finish_reason and usage - sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( id: requestId, object: "chat.completion.chunk", created: created, @@ -719,20 +670,83 @@ final class APIServer { ) )) - inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens) + LiveCounters.shared.requestCompleted(generationTokens: completionTokens) // Send [DONE] and close - let done = "data: [DONE]\n\n" - connection.send(content: done.data(using: .utf8), completion: .contentProcessed({ _ in - connection.cancel() - })) + await Self.sendData(connection: connection, data: "data: [DONE]\n\n".data(using: .utf8)!) 
+ connection.cancel() } - private func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) { + /// Run the token generation + SSE send loop entirely off MainActor. + /// This is critical: if the loop runs on MainActor, every token requires + /// multiple actor hops competing with SwiftUI, causing all output to batch. + nonisolated private static func runStreamingLoop( + connection: NWConnection, + stream: AsyncThrowingStream, + requestId: String, + created: Int, + modelName: String + ) async -> (Int, Int, String, [MLXLMCommon.ToolCall]) { + var promptTokens = 0 + var completionTokens = 0 + var fullText = "" + var frameworkToolCalls: [MLXLMCommon.ToolCall] = [] + + do { + for try await generation in stream { + switch generation { + case .chunk(let text): + completionTokens += 1 + fullText += text + + // Update live counters directly — no MainActor hop needed + LiveCounters.shared.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens) + + // Send directly — no MainActor hop. + await sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + id: requestId, + object: "chat.completion.chunk", + created: created, + model: modelName, + choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)], + usage: nil + )) + + case .info(let info): + promptTokens = info.promptTokenCount + completionTokens = info.generationTokenCount + LiveCounters.shared.prefillCompleted(promptTokens: promptTokens) + if info.tokensPerSecond > 0 { + LiveCounters.shared.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens) + } + + case .toolCall(let call): + frameworkToolCalls.append(call) + } + } + } catch { + let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n" + await sendData(connection: connection, data: errorEvent.data(using: .utf8)!) 
+ } + + return (promptTokens, completionTokens, fullText, frameworkToolCalls) + } + + /// Send an SSE event and wait for the protocol stack to process it. + nonisolated private static func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) async { guard let json = try? JSONEncoder().encode(chunk), let jsonString = String(data: json, encoding: .utf8) else { return } let event = "data: \(jsonString)\n\n" - connection.send(content: event.data(using: .utf8), completion: .contentProcessed({ _ in })) + await sendData(connection: connection, data: event.data(using: .utf8)!) + } + + /// Send raw data on the connection and wait for the protocol stack to process it. + nonisolated private static func sendData(connection: NWConnection, data: Data) async { + await withCheckedContinuation { (continuation: CheckedContinuation) in + connection.send(content: data, completion: .contentProcessed({ _ in + continuation.resume() + })) + } } // MARK: - HTTP helpers @@ -778,19 +792,19 @@ final class APIServer { ] } - /// Check if cached messages are a prefix of new messages (for KV cache reuse). - /// The cached messages include the full history from the previous request. - /// For cache reuse, all but the last message of the new request must match - /// all but the last message of the cached messages (the cached last was the - /// previous user prompt, which is now part of the history). + /// Check if the cached session can be reused for the new history. + /// + /// After a request the session's KV cache contains: + /// cachedMessages (history + user prompt) + the generated assistant response. + /// On the next request the client sends back the full conversation, so + /// `newHistory` (allButLast) is typically `cachedMessages` + 1 assistant reply. + /// We allow reuse when `cached` is a prefix of `newHistory` and there is at most + /// one extra message (the assistant response the session already generated). + /// More than one extra message (e.g. 
injected tool results) means the session + /// hasn't processed them, so we must create a fresh session. private func messagesMatch(_ cached: [Chat.Message], _ newHistory: [Chat.Message]) -> Bool { - // The cached messages are the full chatMessages from the previous request. - // For the cache to be reusable, the new history (allButLast) must match - // exactly what the session has already processed. - // After a request, the session has seen: cachedMessages' history + prompt + response. - // So on the next request, if newHistory == cachedMessages, the session already - // contains all of those turns and we can just send the new last message. - guard cached.count == newHistory.count else { return false } + guard cached.count <= newHistory.count, + newHistory.count <= cached.count + 1 else { return false } for (a, b) in zip(cached, newHistory) { if a.role != b.role || a.content != b.content { return false } } diff --git a/MLXServer/Server/ToolCallParser.swift b/MLXServer/Server/ToolCallParser.swift index 96f1d18..b4d8e4a 100644 --- a/MLXServer/Server/ToolCallParser.swift +++ b/MLXServer/Server/ToolCallParser.swift @@ -24,6 +24,12 @@ enum ToolCallParser { return (gemmaClean, gemmaCalls) } + // Try bare function calls matching known tool names: tool_name(args...) + let (bareClean, bareCalls) = parseBareToolCalls(text: text, tools: tools) + if !bareCalls.isEmpty { + return (bareClean, bareCalls) + } + return (text, []) } @@ -187,4 +193,57 @@ enum ToolCallParser { return (cleanText, toolCalls) } + + // MARK: - Bare function calls: tool_name(args...) + + /// Parse bare function calls that match known tool names. + /// Handles models that output tool calls without fences or XML tags. + private static func parseBareToolCalls(text: String, tools: [APIToolDefinition]?) 
-> (String, [ParsedToolCall]) { + guard let tools, !tools.isEmpty else { return (text, []) } + + let toolNames = tools.map { $0.function.name } + guard !toolNames.isEmpty else { return (text, []) } + + // Build regex: (tool_name1|tool_name2)\s*\(.*\) + let escapedNames = toolNames.map { NSRegularExpression.escapedPattern(for: $0) } + let pattern = "(" + escapedNames.joined(separator: "|") + #")\s*\((.*)\)"# + guard let regex = try? NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else { + return (text, []) + } + + let nsText = text as NSString + let matches = regex.matches(in: text, range: NSRange(location: 0, length: nsText.length)) + guard !matches.isEmpty else { return (text, []) } + + var toolDefs: [String: [String]] = [:] + for tool in tools { + let paramNames = tool.function.parameters?["properties"]?.value as? [String: Any] + toolDefs[tool.function.name] = paramNames.map { Array($0.keys).sorted() } ?? [] + } + + var toolCalls: [ParsedToolCall] = [] + for (i, match) in matches.enumerated() { + let name = nsText.substring(with: match.range(at: 1)) + let argsStr = nsText.substring(with: match.range(at: 2)).trimmingCharacters(in: .whitespacesAndNewlines) + + var args: [String: Any] = [:] + if !argsStr.isEmpty { + if let (_, parsed) = parsePythonCall("\(name)(\(argsStr))", toolDefs: toolDefs) as (String, [String: Any])? { + args = parsed + } + } + + let argsJSON = (try? JSONSerialization.data(withJSONObject: args)) + .flatMap { String(data: $0, encoding: .utf8) } ?? "{}" + let callId = String(format: "call_%d_%08d", i, abs(name.hashValue) % 100_000_000) + toolCalls.append(ParsedToolCall(id: callId, name: name, arguments: argsJSON)) + } + + let cleanText = regex.stringByReplacingMatches( + in: text, range: NSRange(location: 0, length: nsText.length), + withTemplate: "" + ).trimmingCharacters(in: .whitespacesAndNewlines) + + return (cleanText, toolCalls) + } }