fix: better handling of API stuff, still not where internal chat is

This commit is contained in:
2026-03-17 21:24:04 +01:00
parent 20f9c0bcc4
commit ed6cc5f5d1
4 changed files with 358 additions and 190 deletions

View File

@@ -36,6 +36,15 @@ open "build/Debug/MLX Server.app"
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks (128k context) | | `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks (128k context) |
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags (256k context) | | `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags (256k context) |
## Critical Performance Rule
**Inference speed is the #1 priority.** The token generation loop must never be blocked or slowed by anything else — no MainActor hops, no SwiftUI observation, no synchronous I/O. Everything that isn't inference (stats collection, UI updates, logging) must run on separate threads via loose coupling:
- **`LiveCounters`** (thread-safe singleton with `OSAllocatedUnfairLock`) is the bridge: generation code writes to it directly from any thread with zero actor overhead.
- **`InferenceStats`** (UI-side, `@Observable @MainActor`) polls `LiveCounters` at 1Hz via a timer — never the other way around.
- SSE streaming (`sendSSEEvent`/`sendData`) runs nonisolated off MainActor so token sends don't compete with SwiftUI rendering.
- Never gate token output on UI state, analytics, or any `@MainActor`-isolated code.
## Key Design Decisions ## Key Design Decisions
- Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) as the inference backend — supports both text and vision in a single model load - Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) as the inference backend — supports both text and vision in a single model load

View File

@@ -1,11 +1,124 @@
import Foundation import Foundation
import os
// MARK: - Thread-safe live counters (written from any thread, no actor isolation)
/// Lock-protected counters that the generation loop writes to directly.
/// No MainActor requirement the UI polls these via the 1Hz timer.
final class LiveCounters: @unchecked Sendable {
static let shared = LiveCounters()
private let lock = OSAllocatedUnfairLock()
// Current request
private var _activeRequests: Int = 0
private var _promptTokens: Int = 0
private var _generationTokens: Int = 0
private var _tokensPerSecond: Double = 0
private var _isPrefilling: Bool = false
private var _isGenerating: Bool = false
private var _contextMax: Int = 0
// Cumulative
private var _totalRequests: Int = 0
private var _totalPromptTokens: Int = 0
private var _totalGenerationTokens: Int = 0
func requestStarted(contextLength: Int) {
lock.lock()
_activeRequests += 1
_totalRequests += 1
_isPrefilling = true
_isGenerating = false
_promptTokens = 0
_generationTokens = 0
_tokensPerSecond = 0
_contextMax = contextLength
lock.unlock()
}
func prefillCompleted(promptTokens: Int) {
lock.lock()
_isPrefilling = false
_isGenerating = true
_promptTokens = promptTokens
_totalPromptTokens += promptTokens
lock.unlock()
}
func tokenGenerated(tokensPerSecond: Double, totalGenerated: Int) {
lock.lock()
_generationTokens = totalGenerated
_tokensPerSecond = tokensPerSecond
lock.unlock()
}
func requestCompleted(generationTokens: Int) {
lock.lock()
_activeRequests = max(0, _activeRequests - 1)
_totalGenerationTokens += generationTokens
if _activeRequests == 0 {
_isGenerating = false
_isPrefilling = false
_tokensPerSecond = 0
}
lock.unlock()
}
func reset() {
lock.lock()
_activeRequests = 0
_promptTokens = 0
_generationTokens = 0
_tokensPerSecond = 0
_isPrefilling = false
_isGenerating = false
_contextMax = 0
_totalRequests = 0
_totalPromptTokens = 0
_totalGenerationTokens = 0
lock.unlock()
}
/// Atomic snapshot for the UI timer.
func snapshot() -> Snapshot {
lock.lock()
let s = Snapshot(
activeRequests: _activeRequests,
promptTokens: _promptTokens,
generationTokens: _generationTokens,
tokensPerSecond: _tokensPerSecond,
isPrefilling: _isPrefilling,
isGenerating: _isGenerating,
contextMax: _contextMax,
totalRequests: _totalRequests,
totalPromptTokens: _totalPromptTokens,
totalGenerationTokens: _totalGenerationTokens
)
lock.unlock()
return s
}
struct Snapshot {
let activeRequests: Int
let promptTokens: Int
let generationTokens: Int
let tokensPerSecond: Double
let isPrefilling: Bool
let isGenerating: Bool
let contextMax: Int
let totalRequests: Int
let totalPromptTokens: Int
let totalGenerationTokens: Int
}
}
// MARK: - Observable stats for the UI (polls LiveCounters at 1Hz)
/// Lightweight stats collector for inference activity visualization.
/// All mutations happen on @MainActor to avoid locks.
@Observable @Observable
@MainActor @MainActor
final class InferenceStats { final class InferenceStats {
// MARK: - Current request state // MARK: - Current request state (refreshed from LiveCounters)
var activeRequests: Int = 0 var activeRequests: Int = 0
var currentPromptTokens: Int = 0 var currentPromptTokens: Int = 0
@@ -40,11 +153,9 @@ final class InferenceStats {
private var sampleTimer: Timer? private var sampleTimer: Timer?
private var lastGenerationTokenCount: Int = 0 private var lastGenerationTokenCount: Int = 0
private var lastPromptTokenCount: Int = 0 private var lastPromptTokenCount: Int = 0
private var lastSampleTime: Date = .now
func startSampling() { func startSampling() {
guard sampleTimer == nil else { return } guard sampleTimer == nil else { return }
lastSampleTime = .now
sampleTimer = Timer.scheduledTimer(withTimeInterval: 1.0, repeats: true) { [weak self] _ in sampleTimer = Timer.scheduledTimer(withTimeInterval: 1.0, repeats: true) { [weak self] _ in
Task { @MainActor in Task { @MainActor in
self?.recordSample() self?.recordSample()
@@ -58,19 +169,31 @@ final class InferenceStats {
} }
private func recordSample() { private func recordSample() {
// Pull live values from the thread-safe counters
let snap = LiveCounters.shared.snapshot()
activeRequests = snap.activeRequests
currentPromptTokens = snap.promptTokens
currentGenerationTokens = snap.generationTokens
currentTokensPerSecond = snap.tokensPerSecond
isPrefilling = snap.isPrefilling
isGenerating = snap.isGenerating
contextMax = snap.contextMax
contextUsed = snap.promptTokens + snap.generationTokens
totalRequests = snap.totalRequests
totalPromptTokens = snap.totalPromptTokens
totalGenerationTokens = snap.totalGenerationTokens
let now = Date.now let now = Date.now
let genDelta = snap.totalGenerationTokens - lastGenerationTokenCount
let promptDelta = snap.totalPromptTokens - lastPromptTokenCount
lastGenerationTokenCount = snap.totalGenerationTokens
lastPromptTokenCount = snap.totalPromptTokens
// Token rate: tokens generated since last sample tokenRateHistory.append(DataPoint(timestamp: now, value: snap.tokensPerSecond))
let genDelta = totalGenerationTokens - lastGenerationTokenCount
let promptDelta = totalPromptTokens - lastPromptTokenCount
lastGenerationTokenCount = totalGenerationTokens
lastPromptTokenCount = totalPromptTokens
tokenRateHistory.append(DataPoint(timestamp: now, value: currentTokensPerSecond))
generationTokenHistory.append(DataPoint(timestamp: now, value: Double(genDelta))) generationTokenHistory.append(DataPoint(timestamp: now, value: Double(genDelta)))
promptTokenHistory.append(DataPoint(timestamp: now, value: Double(promptDelta))) promptTokenHistory.append(DataPoint(timestamp: now, value: Double(promptDelta)))
// Trim to ring buffer size
if tokenRateHistory.count > Self.maxHistoryPoints { if tokenRateHistory.count > Self.maxHistoryPoints {
tokenRateHistory.removeFirst(tokenRateHistory.count - Self.maxHistoryPoints) tokenRateHistory.removeFirst(tokenRateHistory.count - Self.maxHistoryPoints)
} }
@@ -82,45 +205,8 @@ final class InferenceStats {
} }
} }
// MARK: - Event recording (called from APIServer)
func requestStarted(contextLength: Int) {
activeRequests += 1
totalRequests += 1
isPrefilling = true
isGenerating = false
currentPromptTokens = 0
currentGenerationTokens = 0
currentTokensPerSecond = 0
contextMax = contextLength
contextUsed = 0
}
func prefillCompleted(promptTokens: Int) {
isPrefilling = false
isGenerating = true
currentPromptTokens = promptTokens
totalPromptTokens += promptTokens
contextUsed = promptTokens
}
func tokenGenerated(tokensPerSecond: Double, totalGenerated: Int) {
currentGenerationTokens = totalGenerated
currentTokensPerSecond = tokensPerSecond
contextUsed = currentPromptTokens + totalGenerated
}
func requestCompleted(promptTokens: Int, generationTokens: Int) {
activeRequests = max(0, activeRequests - 1)
totalGenerationTokens += generationTokens
if activeRequests == 0 {
isGenerating = false
isPrefilling = false
currentTokensPerSecond = 0
}
}
func reset() { func reset() {
LiveCounters.shared.reset()
activeRequests = 0 activeRequests = 0
currentPromptTokens = 0 currentPromptTokens = 0
currentGenerationTokens = 0 currentGenerationTokens = 0

View File

@@ -20,6 +20,7 @@ final class APIServer {
private var cachedSession: ChatSession? private var cachedSession: ChatSession?
private var cachedMessages: [Chat.Message]? private var cachedMessages: [Chat.Message]?
private var cachedModelId: String? private var cachedModelId: String?
private var cachedInstructions: String = ""
func start(modelManager: ModelManager, port: Int = 1234) { func start(modelManager: ModelManager, port: Int = 1234) {
guard !isRunning else { return } guard !isRunning else { return }
@@ -29,6 +30,10 @@ final class APIServer {
do { do {
let params = NWParameters.tcp let params = NWParameters.tcp
params.allowLocalEndpointReuse = true params.allowLocalEndpointReuse = true
// Disable Nagle's algorithm so small SSE events go out immediately
if let tcpOptions = params.defaultProtocolStack.transportProtocol as? NWProtocolTCP.Options {
tcpOptions.noDelay = true
}
listener = try NWListener(using: params, on: NWEndpoint.Port(integerLiteral: UInt16(port))) listener = try NWListener(using: params, on: NWEndpoint.Port(integerLiteral: UInt16(port)))
listener?.stateUpdateHandler = { [weak self] state in listener?.stateUpdateHandler = { [weak self] state in
@@ -68,6 +73,7 @@ final class APIServer {
cachedSession = nil cachedSession = nil
cachedMessages = nil cachedMessages = nil
cachedModelId = nil cachedModelId = nil
cachedInstructions = ""
inferenceStats.stopSampling() inferenceStats.stopSampling()
} }
@@ -183,6 +189,7 @@ final class APIServer {
cachedSession = nil cachedSession = nil
cachedMessages = nil cachedMessages = nil
cachedModelId = nil cachedModelId = nil
cachedInstructions = ""
await modelManager.loadModel(targetConfig) await modelManager.loadModel(targetConfig)
} }
} }
@@ -196,6 +203,7 @@ final class APIServer {
cachedSession = nil cachedSession = nil
cachedMessages = nil cachedMessages = nil
cachedModelId = nil cachedModelId = nil
cachedInstructions = ""
await modelManager.loadModel(config) await modelManager.loadModel(config)
} }
@@ -220,32 +228,33 @@ final class APIServer {
var images: [UserInput.Image] = [] var images: [UserInput.Image] = []
let currentModelRepoId = modelManager.currentModel?.repoId ?? modelName let currentModelRepoId = modelManager.currentModel?.repoId ?? modelName
// Inject tool definitions into the system prompt if tools are provided // Build the instructions string (system prompt + tool definitions).
// This is passed to ChatSession via `instructions:` rather than injected
// as history messages, so it avoids an expensive history-replay prefill.
var instructions: String = ""
// Collect system message text from the request
for msg in request.messages where msg.role == "system" {
let text = msg.content?.textContent ?? ""
if !text.isEmpty {
if !instructions.isEmpty { instructions += "\n\n" }
instructions += text
}
}
// Append tool definitions to instructions
if let tools = request.tools, !tools.isEmpty { if let tools = request.tools, !tools.isEmpty {
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId) let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
if !instructions.isEmpty { instructions += "\n\n" }
// Check if there's already a system message instructions += toolSystemPrompt
let hasSystem = request.messages.contains { $0.role == "system" }
if hasSystem {
// Append tool prompt to existing system message (handled below during conversion)
} else {
// For Gemma: inject as user message (Gemma doesn't support system role natively)
// For Qwen: inject as system message
if currentModelRepoId.lowercased().contains("qwen") {
chatMessages.append(Chat.Message(role: .system, content: toolSystemPrompt))
} else {
chatMessages.append(Chat.Message(role: .user, content: toolSystemPrompt))
chatMessages.append(Chat.Message(role: .assistant, content: "Understood. I will use the provided tools when appropriate."))
}
}
} }
let toolsForInjection = request.tools let toolsForInjection = request.tools
let isQwen = currentModelRepoId.lowercased().contains("qwen") let isQwen = currentModelRepoId.lowercased().contains("qwen")
for msg in request.messages { // Convert non-system messages to Chat.Message
for msg in request.messages where msg.role != "system" {
let role: Chat.Message.Role = switch msg.role { let role: Chat.Message.Role = switch msg.role {
case "system": .system
case "assistant": .assistant case "assistant": .assistant
case "tool": .user case "tool": .user
default: .user default: .user
@@ -253,12 +262,6 @@ final class APIServer {
var text = msg.content?.textContent ?? "" var text = msg.content?.textContent ?? ""
// If this is a system message and tools are provided, append tool definitions
if msg.role == "system", let tools = toolsForInjection, !tools.isEmpty {
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
text = text + "\n\n" + toolSystemPrompt
}
// Format tool_call_id responses as tool_output for the model // Format tool_call_id responses as tool_output for the model
if msg.role == "tool" { if msg.role == "tool" {
if isQwen { if isQwen {
@@ -328,6 +331,7 @@ final class APIServer {
let canReuse = cachedSession != nil let canReuse = cachedSession != nil
&& cachedModelId == currentModelId && cachedModelId == currentModelId
&& cachedMessages != nil && cachedMessages != nil
&& cachedInstructions == instructions
&& messagesMatch(cachedMessages!, allButLast) && messagesMatch(cachedMessages!, allButLast)
let session: ChatSession let session: ChatSession
@@ -339,15 +343,21 @@ final class APIServer {
if cachedSession != nil { if cachedSession != nil {
print("[APIServer] History diverged, creating fresh session") print("[APIServer] History diverged, creating fresh session")
} }
// Use `instructions:` for system/tool prompt (matches internal chat pattern).
// Only conversation turns go in `history:` this avoids replaying the
// large tool prompt as history on every new session.
let instr = instructions.isEmpty ? nil : instructions
if !allButLast.isEmpty { if !allButLast.isEmpty {
session = ChatSession( session = ChatSession(
container, container,
instructions: instr,
history: allButLast, history: allButLast,
generateParameters: generateParams generateParameters: generateParams
) )
} else { } else {
session = ChatSession( session = ChatSession(
container, container,
instructions: instr,
generateParameters: generateParams generateParameters: generateParams
) )
} }
@@ -356,7 +366,7 @@ final class APIServer {
// Extract images from the last message only (ChatSession.streamDetails takes images separately) // Extract images from the last message only (ChatSession.streamDetails takes images separately)
let lastImages = lastMessage.images let lastImages = lastMessage.images
inferenceStats.requestStarted(contextLength: contextLength) LiveCounters.shared.requestStarted(contextLength: contextLength)
if isStream { if isStream {
await handleStreamingResponse( await handleStreamingResponse(
@@ -387,6 +397,7 @@ final class APIServer {
cachedSession = session cachedSession = session
cachedMessages = chatMessages // full messages including the one just sent cachedMessages = chatMessages // full messages including the one just sent
cachedModelId = currentModelId cachedModelId = currentModelId
cachedInstructions = instructions
} }
/// Decode a base64 data URI (data:image/png;base64,...) into a UserInput.Image. /// Decode a base64 data URI (data:image/png;base64,...) into a UserInput.Image.
@@ -439,20 +450,20 @@ final class APIServer {
case .chunk(let text): case .chunk(let text):
fullText += text fullText += text
completionTokens += 1 completionTokens += 1
inferenceStats.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens) LiveCounters.shared.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
case .info(let info): case .info(let info):
promptTokens = info.promptTokenCount promptTokens = info.promptTokenCount
completionTokens = info.generationTokenCount completionTokens = info.generationTokenCount
inferenceStats.prefillCompleted(promptTokens: promptTokens) LiveCounters.shared.prefillCompleted(promptTokens: promptTokens)
if info.tokensPerSecond > 0 { if info.tokensPerSecond > 0 {
inferenceStats.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens) LiveCounters.shared.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
} }
case .toolCall(let call): case .toolCall(let call):
frameworkToolCalls.append(call) frameworkToolCalls.append(call)
} }
} }
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens) LiveCounters.shared.requestCompleted(generationTokens: completionTokens)
// Parse tool calls: first check framework-detected ones, then our own text parser // Parse tool calls: first check framework-detected ones, then our own text parser
var finishReason = "stop" var finishReason = "stop"
@@ -524,7 +535,7 @@ final class APIServer {
sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}") sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}")
} }
} catch { } catch {
inferenceStats.requestCompleted(promptTokens: 0, generationTokens: 0) LiveCounters.shared.requestCompleted(generationTokens: 0)
sendResponse(connection: connection, status: 500, body: #"{"error":"\#(error.localizedDescription)"}"#) sendResponse(connection: connection, status: 500, body: #"{"error":"\#(error.localizedDescription)"}"#)
} }
} }
@@ -552,15 +563,10 @@ final class APIServer {
"", "",
].joined(separator: "\r\n") ].joined(separator: "\r\n")
let headerSent = await withCheckedContinuation { continuation in await Self.sendData(connection: connection, data: header.data(using: .utf8)!)
connection.send(content: header.data(using: .utf8), completion: .contentProcessed({ _ in
continuation.resume(returning: true)
}))
}
guard headerSent else { return }
// Send initial role chunk // Send initial role chunk
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
id: requestId, id: requestId,
object: "chat.completion.chunk", object: "chat.completion.chunk",
created: created, created: created,
@@ -569,78 +575,39 @@ final class APIServer {
usage: nil usage: nil
)) ))
// When tools are available, buffer full response to parse tool calls let hasTools = tools != nil && !(tools?.isEmpty ?? true)
// (otherwise raw tool-call markup leaks into streamed text)
let bufferForTools = tools != nil && !(tools?.isEmpty ?? true)
var promptTokens = 0 // Run the generation loop OFF MainActor.
var completionTokens = 0 // ChatSession and NWConnection don't need MainActor.
var fullText = "" // Running on MainActor caused every token to compete with SwiftUI
var frameworkToolCalls: [MLXLMCommon.ToolCall] = [] // rendering, creating back-pressure that coalesced all output.
let stream = session.streamDetails(
do { to: prompt,
let stream = session.streamDetails( images: images,
to: prompt, videos: []
images: images, )
videos: [] // Transfer non-Sendable values to the nonisolated loop.
// Safe because we don't touch session/images again until after the loop.
let result = await {
nonisolated(unsafe) let stream = stream
return await Self.runStreamingLoop(
connection: connection,
stream: stream,
requestId: requestId,
created: created,
modelName: modelName
) )
}()
for try await generation in stream { let (promptTokens, completionTokens, fullText, frameworkToolCalls) = result
switch generation {
case .chunk(let text):
completionTokens += 1
fullText += text
inferenceStats.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
if !bufferForTools { // Stats were already updated by LiveCounters inside the loop
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
id: requestId,
object: "chat.completion.chunk",
created: created,
model: modelName,
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)],
usage: nil
))
}
case .info(let info):
promptTokens = info.promptTokenCount
completionTokens = info.generationTokenCount
inferenceStats.prefillCompleted(promptTokens: promptTokens)
if info.tokensPerSecond > 0 {
inferenceStats.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
}
case .toolCall(let call):
frameworkToolCalls.append(call)
}
}
} catch {
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens)
let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n"
connection.send(content: errorEvent.data(using: .utf8), completion: .contentProcessed({ _ in }))
}
// Post-generation: handle tool calls (framework-detected or text-parsed) // Post-generation: handle tool calls (framework-detected or text-parsed)
var finishReason = "stop" var finishReason = "stop"
if !frameworkToolCalls.isEmpty { if !frameworkToolCalls.isEmpty {
// Framework natively detected tool calls (e.g. Qwen)
finishReason = "tool_calls" finishReason = "tool_calls"
// Emit any buffered text content
if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
id: requestId,
object: "chat.completion.chunk",
created: created,
model: modelName,
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)],
usage: nil
))
}
// Emit tool call chunks
for (i, tc) in frameworkToolCalls.enumerated() { for (i, tc) in frameworkToolCalls.enumerated() {
let argsDict = tc.function.arguments.mapValues { $0.anyValue } let argsDict = tc.function.arguments.mapValues { $0.anyValue }
let argsJSON: String let argsJSON: String
@@ -657,7 +624,7 @@ final class APIServer {
type: "function", type: "function",
function: APIFunctionCall(name: tc.function.name, arguments: argsJSON) function: APIFunctionCall(name: tc.function.name, arguments: argsJSON)
) )
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
id: requestId, id: requestId,
object: "chat.completion.chunk", object: "chat.completion.chunk",
created: created, created: created,
@@ -666,27 +633,11 @@ final class APIServer {
usage: nil usage: nil
)) ))
} }
} else if bufferForTools { } else if hasTools {
// Text-parsed tool calls (e.g. Gemma tool_code blocks) let (_, parsed) = ToolCallParser.parse(text: fullText, tools: tools)
let (cleanText, parsed) = ToolCallParser.parse(text: fullText, tools: tools)
if !parsed.isEmpty { if !parsed.isEmpty {
finishReason = "tool_calls" finishReason = "tool_calls"
fullText = cleanText
} }
// Emit buffered content (cleaned of tool-call markup)
if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
id: requestId,
object: "chat.completion.chunk",
created: created,
model: modelName,
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)],
usage: nil
))
}
// Emit tool call chunks
for (i, tc) in parsed.enumerated() { for (i, tc) in parsed.enumerated() {
let apiToolCall = APIToolCall( let apiToolCall = APIToolCall(
index: i, index: i,
@@ -694,7 +645,7 @@ final class APIServer {
type: "function", type: "function",
function: APIFunctionCall(name: tc.name, arguments: tc.arguments) function: APIFunctionCall(name: tc.name, arguments: tc.arguments)
) )
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
id: requestId, id: requestId,
object: "chat.completion.chunk", object: "chat.completion.chunk",
created: created, created: created,
@@ -706,7 +657,7 @@ final class APIServer {
} }
// Final chunk with finish_reason and usage // Final chunk with finish_reason and usage
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
id: requestId, id: requestId,
object: "chat.completion.chunk", object: "chat.completion.chunk",
created: created, created: created,
@@ -719,20 +670,83 @@ final class APIServer {
) )
)) ))
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens) LiveCounters.shared.requestCompleted(generationTokens: completionTokens)
// Send [DONE] and close // Send [DONE] and close
let done = "data: [DONE]\n\n" await Self.sendData(connection: connection, data: "data: [DONE]\n\n".data(using: .utf8)!)
connection.send(content: done.data(using: .utf8), completion: .contentProcessed({ _ in connection.cancel()
connection.cancel()
}))
} }
private func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) { /// Run the token generation + SSE send loop entirely off MainActor.
/// This is critical: if the loop runs on MainActor, every token requires
/// multiple actor hops competing with SwiftUI, causing all output to batch.
nonisolated private static func runStreamingLoop(
connection: NWConnection,
stream: AsyncThrowingStream<Generation, any Error>,
requestId: String,
created: Int,
modelName: String
) async -> (Int, Int, String, [MLXLMCommon.ToolCall]) {
var promptTokens = 0
var completionTokens = 0
var fullText = ""
var frameworkToolCalls: [MLXLMCommon.ToolCall] = []
do {
for try await generation in stream {
switch generation {
case .chunk(let text):
completionTokens += 1
fullText += text
// Update live counters directly no MainActor hop needed
LiveCounters.shared.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
// Send directly no MainActor hop.
await sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
id: requestId,
object: "chat.completion.chunk",
created: created,
model: modelName,
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)],
usage: nil
))
case .info(let info):
promptTokens = info.promptTokenCount
completionTokens = info.generationTokenCount
LiveCounters.shared.prefillCompleted(promptTokens: promptTokens)
if info.tokensPerSecond > 0 {
LiveCounters.shared.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
}
case .toolCall(let call):
frameworkToolCalls.append(call)
}
}
} catch {
let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n"
await sendData(connection: connection, data: errorEvent.data(using: .utf8)!)
}
return (promptTokens, completionTokens, fullText, frameworkToolCalls)
}
/// Send an SSE event and wait for the protocol stack to process it.
nonisolated private static func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) async {
guard let json = try? JSONEncoder().encode(chunk), guard let json = try? JSONEncoder().encode(chunk),
let jsonString = String(data: json, encoding: .utf8) else { return } let jsonString = String(data: json, encoding: .utf8) else { return }
let event = "data: \(jsonString)\n\n" let event = "data: \(jsonString)\n\n"
connection.send(content: event.data(using: .utf8), completion: .contentProcessed({ _ in })) await sendData(connection: connection, data: event.data(using: .utf8)!)
}
/// Send raw data on the connection and wait for the protocol stack to process it.
nonisolated private static func sendData(connection: NWConnection, data: Data) async {
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
connection.send(content: data, completion: .contentProcessed({ _ in
continuation.resume()
}))
}
} }
// MARK: - HTTP helpers // MARK: - HTTP helpers
@@ -778,19 +792,19 @@ final class APIServer {
] ]
} }
/// Check if cached messages are a prefix of new messages (for KV cache reuse). /// Check if the cached session can be reused for the new history.
/// The cached messages include the full history from the previous request. ///
/// For cache reuse, all but the last message of the new request must match /// After a request the session's KV cache contains:
/// all but the last message of the cached messages (the cached last was the /// cachedMessages (history + user prompt) + the generated assistant response.
/// previous user prompt, which is now part of the history). /// On the next request the client sends back the full conversation, so
/// `newHistory` (allButLast) is typically `cachedMessages` + 1 assistant reply.
/// We allow reuse when `cached` is a prefix of `newHistory` and there is at most
/// one extra message (the assistant response the session already generated).
/// More than one extra message (e.g. injected tool results) means the session
/// hasn't processed them, so we must create a fresh session.
private func messagesMatch(_ cached: [Chat.Message], _ newHistory: [Chat.Message]) -> Bool { private func messagesMatch(_ cached: [Chat.Message], _ newHistory: [Chat.Message]) -> Bool {
// The cached messages are the full chatMessages from the previous request. guard cached.count <= newHistory.count,
// For the cache to be reusable, the new history (allButLast) must match newHistory.count <= cached.count + 1 else { return false }
// exactly what the session has already processed.
// After a request, the session has seen: cachedMessages' history + prompt + response.
// So on the next request, if newHistory == cachedMessages, the session already
// contains all of those turns and we can just send the new last message.
guard cached.count == newHistory.count else { return false }
for (a, b) in zip(cached, newHistory) { for (a, b) in zip(cached, newHistory) {
if a.role != b.role || a.content != b.content { return false } if a.role != b.role || a.content != b.content { return false }
} }

View File

@@ -24,6 +24,12 @@ enum ToolCallParser {
return (gemmaClean, gemmaCalls) return (gemmaClean, gemmaCalls)
} }
// Try bare function calls matching known tool names: tool_name(args...)
let (bareClean, bareCalls) = parseBareToolCalls(text: text, tools: tools)
if !bareCalls.isEmpty {
return (bareClean, bareCalls)
}
return (text, []) return (text, [])
} }
@@ -187,4 +193,57 @@ enum ToolCallParser {
return (cleanText, toolCalls) return (cleanText, toolCalls)
} }
// MARK: - Bare function calls: tool_name(args...)
/// Parse bare function calls that match known tool names.
/// Handles models that output tool calls without fences or XML tags.
private static func parseBareToolCalls(text: String, tools: [APIToolDefinition]?) -> (String, [ParsedToolCall]) {
guard let tools, !tools.isEmpty else { return (text, []) }
let toolNames = tools.map { $0.function.name }
guard !toolNames.isEmpty else { return (text, []) }
// Build regex: (tool_name1|tool_name2)\s*\(.*\)
let escapedNames = toolNames.map { NSRegularExpression.escapedPattern(for: $0) }
let pattern = "(" + escapedNames.joined(separator: "|") + #")\s*\((.*)\)"#
guard let regex = try? NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else {
return (text, [])
}
let nsText = text as NSString
let matches = regex.matches(in: text, range: NSRange(location: 0, length: nsText.length))
guard !matches.isEmpty else { return (text, []) }
var toolDefs: [String: [String]] = [:]
for tool in tools {
let paramNames = tool.function.parameters?["properties"]?.value as? [String: Any]
toolDefs[tool.function.name] = paramNames.map { Array($0.keys).sorted() } ?? []
}
var toolCalls: [ParsedToolCall] = []
for (i, match) in matches.enumerated() {
let name = nsText.substring(with: match.range(at: 1))
let argsStr = nsText.substring(with: match.range(at: 2)).trimmingCharacters(in: .whitespacesAndNewlines)
var args: [String: Any] = [:]
if !argsStr.isEmpty {
if let (_, parsed) = parsePythonCall("\(name)(\(argsStr))", toolDefs: toolDefs) as (String, [String: Any])? {
args = parsed
}
}
let argsJSON = (try? JSONSerialization.data(withJSONObject: args))
.flatMap { String(data: $0, encoding: .utf8) } ?? "{}"
let callId = String(format: "call_%d_%08d", i, abs(name.hashValue) % 100_000_000)
toolCalls.append(ParsedToolCall(id: callId, name: name, arguments: argsJSON))
}
let cleanText = regex.stringByReplacingMatches(
in: text, range: NSRange(location: 0, length: nsText.length),
withTemplate: ""
).trimmingCharacters(in: .whitespacesAndNewlines)
return (cleanText, toolCalls)
}
} }