fix: better handling of API stuff, still not where internal chat is
This commit is contained in:
@@ -36,6 +36,15 @@ open "build/Debug/MLX Server.app"
|
||||
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks (128k context) |
|
||||
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags (256k context) |
|
||||
|
||||
## Critical Performance Rule
|
||||
|
||||
**Inference speed is the #1 priority.** The token generation loop must never be blocked or slowed by anything else — no MainActor hops, no SwiftUI observation, no synchronous I/O. Everything that isn't inference (stats collection, UI updates, logging) must run on separate threads via loose coupling:
|
||||
|
||||
- **`LiveCounters`** (thread-safe singleton with `OSAllocatedUnfairLock`) is the bridge: generation code writes to it directly from any thread with zero actor overhead.
|
||||
- **`InferenceStats`** (UI-side, `@Observable @MainActor`) polls `LiveCounters` at 1Hz via a timer — never the other way around.
|
||||
- SSE streaming (`sendSSEEvent`/`sendData`) runs nonisolated off MainActor so token sends don't compete with SwiftUI rendering.
|
||||
- Never gate token output on UI state, analytics, or any `@MainActor`-isolated code.
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) as the inference backend — supports both text and vision in a single model load
|
||||
|
||||
@@ -1,11 +1,124 @@
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
// MARK: - Thread-safe live counters (written from any thread, no actor isolation)
|
||||
|
||||
/// Lock-protected counters that the generation loop writes to directly.
|
||||
/// No MainActor requirement — the UI polls these via the 1Hz timer.
|
||||
final class LiveCounters: @unchecked Sendable {
|
||||
static let shared = LiveCounters()
|
||||
|
||||
private let lock = OSAllocatedUnfairLock()
|
||||
|
||||
// Current request
|
||||
private var _activeRequests: Int = 0
|
||||
private var _promptTokens: Int = 0
|
||||
private var _generationTokens: Int = 0
|
||||
private var _tokensPerSecond: Double = 0
|
||||
private var _isPrefilling: Bool = false
|
||||
private var _isGenerating: Bool = false
|
||||
private var _contextMax: Int = 0
|
||||
|
||||
// Cumulative
|
||||
private var _totalRequests: Int = 0
|
||||
private var _totalPromptTokens: Int = 0
|
||||
private var _totalGenerationTokens: Int = 0
|
||||
|
||||
func requestStarted(contextLength: Int) {
|
||||
lock.lock()
|
||||
_activeRequests += 1
|
||||
_totalRequests += 1
|
||||
_isPrefilling = true
|
||||
_isGenerating = false
|
||||
_promptTokens = 0
|
||||
_generationTokens = 0
|
||||
_tokensPerSecond = 0
|
||||
_contextMax = contextLength
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
func prefillCompleted(promptTokens: Int) {
|
||||
lock.lock()
|
||||
_isPrefilling = false
|
||||
_isGenerating = true
|
||||
_promptTokens = promptTokens
|
||||
_totalPromptTokens += promptTokens
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
func tokenGenerated(tokensPerSecond: Double, totalGenerated: Int) {
|
||||
lock.lock()
|
||||
_generationTokens = totalGenerated
|
||||
_tokensPerSecond = tokensPerSecond
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
func requestCompleted(generationTokens: Int) {
|
||||
lock.lock()
|
||||
_activeRequests = max(0, _activeRequests - 1)
|
||||
_totalGenerationTokens += generationTokens
|
||||
if _activeRequests == 0 {
|
||||
_isGenerating = false
|
||||
_isPrefilling = false
|
||||
_tokensPerSecond = 0
|
||||
}
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
func reset() {
|
||||
lock.lock()
|
||||
_activeRequests = 0
|
||||
_promptTokens = 0
|
||||
_generationTokens = 0
|
||||
_tokensPerSecond = 0
|
||||
_isPrefilling = false
|
||||
_isGenerating = false
|
||||
_contextMax = 0
|
||||
_totalRequests = 0
|
||||
_totalPromptTokens = 0
|
||||
_totalGenerationTokens = 0
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
/// Atomic snapshot for the UI timer.
|
||||
func snapshot() -> Snapshot {
|
||||
lock.lock()
|
||||
let s = Snapshot(
|
||||
activeRequests: _activeRequests,
|
||||
promptTokens: _promptTokens,
|
||||
generationTokens: _generationTokens,
|
||||
tokensPerSecond: _tokensPerSecond,
|
||||
isPrefilling: _isPrefilling,
|
||||
isGenerating: _isGenerating,
|
||||
contextMax: _contextMax,
|
||||
totalRequests: _totalRequests,
|
||||
totalPromptTokens: _totalPromptTokens,
|
||||
totalGenerationTokens: _totalGenerationTokens
|
||||
)
|
||||
lock.unlock()
|
||||
return s
|
||||
}
|
||||
|
||||
struct Snapshot {
|
||||
let activeRequests: Int
|
||||
let promptTokens: Int
|
||||
let generationTokens: Int
|
||||
let tokensPerSecond: Double
|
||||
let isPrefilling: Bool
|
||||
let isGenerating: Bool
|
||||
let contextMax: Int
|
||||
let totalRequests: Int
|
||||
let totalPromptTokens: Int
|
||||
let totalGenerationTokens: Int
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Observable stats for the UI (polls LiveCounters at 1Hz)
|
||||
|
||||
/// Lightweight stats collector for inference activity visualization.
|
||||
/// All mutations happen on @MainActor to avoid locks.
|
||||
@Observable
|
||||
@MainActor
|
||||
final class InferenceStats {
|
||||
// MARK: - Current request state
|
||||
// MARK: - Current request state (refreshed from LiveCounters)
|
||||
|
||||
var activeRequests: Int = 0
|
||||
var currentPromptTokens: Int = 0
|
||||
@@ -40,11 +153,9 @@ final class InferenceStats {
|
||||
private var sampleTimer: Timer?
|
||||
private var lastGenerationTokenCount: Int = 0
|
||||
private var lastPromptTokenCount: Int = 0
|
||||
private var lastSampleTime: Date = .now
|
||||
|
||||
func startSampling() {
|
||||
guard sampleTimer == nil else { return }
|
||||
lastSampleTime = .now
|
||||
sampleTimer = Timer.scheduledTimer(withTimeInterval: 1.0, repeats: true) { [weak self] _ in
|
||||
Task { @MainActor in
|
||||
self?.recordSample()
|
||||
@@ -58,19 +169,31 @@ final class InferenceStats {
|
||||
}
|
||||
|
||||
private func recordSample() {
|
||||
// Pull live values from the thread-safe counters
|
||||
let snap = LiveCounters.shared.snapshot()
|
||||
|
||||
activeRequests = snap.activeRequests
|
||||
currentPromptTokens = snap.promptTokens
|
||||
currentGenerationTokens = snap.generationTokens
|
||||
currentTokensPerSecond = snap.tokensPerSecond
|
||||
isPrefilling = snap.isPrefilling
|
||||
isGenerating = snap.isGenerating
|
||||
contextMax = snap.contextMax
|
||||
contextUsed = snap.promptTokens + snap.generationTokens
|
||||
totalRequests = snap.totalRequests
|
||||
totalPromptTokens = snap.totalPromptTokens
|
||||
totalGenerationTokens = snap.totalGenerationTokens
|
||||
|
||||
let now = Date.now
|
||||
let genDelta = snap.totalGenerationTokens - lastGenerationTokenCount
|
||||
let promptDelta = snap.totalPromptTokens - lastPromptTokenCount
|
||||
lastGenerationTokenCount = snap.totalGenerationTokens
|
||||
lastPromptTokenCount = snap.totalPromptTokens
|
||||
|
||||
// Token rate: tokens generated since last sample
|
||||
let genDelta = totalGenerationTokens - lastGenerationTokenCount
|
||||
let promptDelta = totalPromptTokens - lastPromptTokenCount
|
||||
lastGenerationTokenCount = totalGenerationTokens
|
||||
lastPromptTokenCount = totalPromptTokens
|
||||
|
||||
tokenRateHistory.append(DataPoint(timestamp: now, value: currentTokensPerSecond))
|
||||
tokenRateHistory.append(DataPoint(timestamp: now, value: snap.tokensPerSecond))
|
||||
generationTokenHistory.append(DataPoint(timestamp: now, value: Double(genDelta)))
|
||||
promptTokenHistory.append(DataPoint(timestamp: now, value: Double(promptDelta)))
|
||||
|
||||
// Trim to ring buffer size
|
||||
if tokenRateHistory.count > Self.maxHistoryPoints {
|
||||
tokenRateHistory.removeFirst(tokenRateHistory.count - Self.maxHistoryPoints)
|
||||
}
|
||||
@@ -82,45 +205,8 @@ final class InferenceStats {
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Event recording (called from APIServer)
|
||||
|
||||
func requestStarted(contextLength: Int) {
|
||||
activeRequests += 1
|
||||
totalRequests += 1
|
||||
isPrefilling = true
|
||||
isGenerating = false
|
||||
currentPromptTokens = 0
|
||||
currentGenerationTokens = 0
|
||||
currentTokensPerSecond = 0
|
||||
contextMax = contextLength
|
||||
contextUsed = 0
|
||||
}
|
||||
|
||||
func prefillCompleted(promptTokens: Int) {
|
||||
isPrefilling = false
|
||||
isGenerating = true
|
||||
currentPromptTokens = promptTokens
|
||||
totalPromptTokens += promptTokens
|
||||
contextUsed = promptTokens
|
||||
}
|
||||
|
||||
func tokenGenerated(tokensPerSecond: Double, totalGenerated: Int) {
|
||||
currentGenerationTokens = totalGenerated
|
||||
currentTokensPerSecond = tokensPerSecond
|
||||
contextUsed = currentPromptTokens + totalGenerated
|
||||
}
|
||||
|
||||
func requestCompleted(promptTokens: Int, generationTokens: Int) {
|
||||
activeRequests = max(0, activeRequests - 1)
|
||||
totalGenerationTokens += generationTokens
|
||||
if activeRequests == 0 {
|
||||
isGenerating = false
|
||||
isPrefilling = false
|
||||
currentTokensPerSecond = 0
|
||||
}
|
||||
}
|
||||
|
||||
func reset() {
|
||||
LiveCounters.shared.reset()
|
||||
activeRequests = 0
|
||||
currentPromptTokens = 0
|
||||
currentGenerationTokens = 0
|
||||
|
||||
@@ -20,6 +20,7 @@ final class APIServer {
|
||||
private var cachedSession: ChatSession?
|
||||
private var cachedMessages: [Chat.Message]?
|
||||
private var cachedModelId: String?
|
||||
private var cachedInstructions: String = ""
|
||||
|
||||
func start(modelManager: ModelManager, port: Int = 1234) {
|
||||
guard !isRunning else { return }
|
||||
@@ -29,6 +30,10 @@ final class APIServer {
|
||||
do {
|
||||
let params = NWParameters.tcp
|
||||
params.allowLocalEndpointReuse = true
|
||||
// Disable Nagle's algorithm so small SSE events go out immediately
|
||||
if let tcpOptions = params.defaultProtocolStack.transportProtocol as? NWProtocolTCP.Options {
|
||||
tcpOptions.noDelay = true
|
||||
}
|
||||
listener = try NWListener(using: params, on: NWEndpoint.Port(integerLiteral: UInt16(port)))
|
||||
|
||||
listener?.stateUpdateHandler = { [weak self] state in
|
||||
@@ -68,6 +73,7 @@ final class APIServer {
|
||||
cachedSession = nil
|
||||
cachedMessages = nil
|
||||
cachedModelId = nil
|
||||
cachedInstructions = ""
|
||||
inferenceStats.stopSampling()
|
||||
}
|
||||
|
||||
@@ -183,6 +189,7 @@ final class APIServer {
|
||||
cachedSession = nil
|
||||
cachedMessages = nil
|
||||
cachedModelId = nil
|
||||
cachedInstructions = ""
|
||||
await modelManager.loadModel(targetConfig)
|
||||
}
|
||||
}
|
||||
@@ -196,6 +203,7 @@ final class APIServer {
|
||||
cachedSession = nil
|
||||
cachedMessages = nil
|
||||
cachedModelId = nil
|
||||
cachedInstructions = ""
|
||||
await modelManager.loadModel(config)
|
||||
}
|
||||
|
||||
@@ -220,32 +228,33 @@ final class APIServer {
|
||||
var images: [UserInput.Image] = []
|
||||
let currentModelRepoId = modelManager.currentModel?.repoId ?? modelName
|
||||
|
||||
// Inject tool definitions into the system prompt if tools are provided
|
||||
// Build the instructions string (system prompt + tool definitions).
|
||||
// This is passed to ChatSession via `instructions:` rather than injected
|
||||
// as history messages, so it avoids an expensive history-replay prefill.
|
||||
var instructions: String = ""
|
||||
|
||||
// Collect system message text from the request
|
||||
for msg in request.messages where msg.role == "system" {
|
||||
let text = msg.content?.textContent ?? ""
|
||||
if !text.isEmpty {
|
||||
if !instructions.isEmpty { instructions += "\n\n" }
|
||||
instructions += text
|
||||
}
|
||||
}
|
||||
|
||||
// Append tool definitions to instructions
|
||||
if let tools = request.tools, !tools.isEmpty {
|
||||
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
|
||||
|
||||
// Check if there's already a system message
|
||||
let hasSystem = request.messages.contains { $0.role == "system" }
|
||||
if hasSystem {
|
||||
// Append tool prompt to existing system message (handled below during conversion)
|
||||
} else {
|
||||
// For Gemma: inject as user message (Gemma doesn't support system role natively)
|
||||
// For Qwen: inject as system message
|
||||
if currentModelRepoId.lowercased().contains("qwen") {
|
||||
chatMessages.append(Chat.Message(role: .system, content: toolSystemPrompt))
|
||||
} else {
|
||||
chatMessages.append(Chat.Message(role: .user, content: toolSystemPrompt))
|
||||
chatMessages.append(Chat.Message(role: .assistant, content: "Understood. I will use the provided tools when appropriate."))
|
||||
}
|
||||
}
|
||||
if !instructions.isEmpty { instructions += "\n\n" }
|
||||
instructions += toolSystemPrompt
|
||||
}
|
||||
|
||||
let toolsForInjection = request.tools
|
||||
let isQwen = currentModelRepoId.lowercased().contains("qwen")
|
||||
|
||||
for msg in request.messages {
|
||||
// Convert non-system messages to Chat.Message
|
||||
for msg in request.messages where msg.role != "system" {
|
||||
let role: Chat.Message.Role = switch msg.role {
|
||||
case "system": .system
|
||||
case "assistant": .assistant
|
||||
case "tool": .user
|
||||
default: .user
|
||||
@@ -253,12 +262,6 @@ final class APIServer {
|
||||
|
||||
var text = msg.content?.textContent ?? ""
|
||||
|
||||
// If this is a system message and tools are provided, append tool definitions
|
||||
if msg.role == "system", let tools = toolsForInjection, !tools.isEmpty {
|
||||
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
|
||||
text = text + "\n\n" + toolSystemPrompt
|
||||
}
|
||||
|
||||
// Format tool_call_id responses as tool_output for the model
|
||||
if msg.role == "tool" {
|
||||
if isQwen {
|
||||
@@ -328,6 +331,7 @@ final class APIServer {
|
||||
let canReuse = cachedSession != nil
|
||||
&& cachedModelId == currentModelId
|
||||
&& cachedMessages != nil
|
||||
&& cachedInstructions == instructions
|
||||
&& messagesMatch(cachedMessages!, allButLast)
|
||||
|
||||
let session: ChatSession
|
||||
@@ -339,15 +343,21 @@ final class APIServer {
|
||||
if cachedSession != nil {
|
||||
print("[APIServer] History diverged, creating fresh session")
|
||||
}
|
||||
// Use `instructions:` for system/tool prompt (matches internal chat pattern).
|
||||
// Only conversation turns go in `history:` — this avoids replaying the
|
||||
// large tool prompt as history on every new session.
|
||||
let instr = instructions.isEmpty ? nil : instructions
|
||||
if !allButLast.isEmpty {
|
||||
session = ChatSession(
|
||||
container,
|
||||
instructions: instr,
|
||||
history: allButLast,
|
||||
generateParameters: generateParams
|
||||
)
|
||||
} else {
|
||||
session = ChatSession(
|
||||
container,
|
||||
instructions: instr,
|
||||
generateParameters: generateParams
|
||||
)
|
||||
}
|
||||
@@ -356,7 +366,7 @@ final class APIServer {
|
||||
// Extract images from the last message only (ChatSession.streamDetails takes images separately)
|
||||
let lastImages = lastMessage.images
|
||||
|
||||
inferenceStats.requestStarted(contextLength: contextLength)
|
||||
LiveCounters.shared.requestStarted(contextLength: contextLength)
|
||||
|
||||
if isStream {
|
||||
await handleStreamingResponse(
|
||||
@@ -387,6 +397,7 @@ final class APIServer {
|
||||
cachedSession = session
|
||||
cachedMessages = chatMessages // full messages including the one just sent
|
||||
cachedModelId = currentModelId
|
||||
cachedInstructions = instructions
|
||||
}
|
||||
|
||||
/// Decode a base64 data URI (data:image/png;base64,...) into a UserInput.Image.
|
||||
@@ -439,20 +450,20 @@ final class APIServer {
|
||||
case .chunk(let text):
|
||||
fullText += text
|
||||
completionTokens += 1
|
||||
inferenceStats.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
|
||||
LiveCounters.shared.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
|
||||
case .info(let info):
|
||||
promptTokens = info.promptTokenCount
|
||||
completionTokens = info.generationTokenCount
|
||||
inferenceStats.prefillCompleted(promptTokens: promptTokens)
|
||||
LiveCounters.shared.prefillCompleted(promptTokens: promptTokens)
|
||||
if info.tokensPerSecond > 0 {
|
||||
inferenceStats.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
|
||||
LiveCounters.shared.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
|
||||
}
|
||||
case .toolCall(let call):
|
||||
frameworkToolCalls.append(call)
|
||||
}
|
||||
}
|
||||
|
||||
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens)
|
||||
LiveCounters.shared.requestCompleted(generationTokens: completionTokens)
|
||||
|
||||
// Parse tool calls: first check framework-detected ones, then our own text parser
|
||||
var finishReason = "stop"
|
||||
@@ -524,7 +535,7 @@ final class APIServer {
|
||||
sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}")
|
||||
}
|
||||
} catch {
|
||||
inferenceStats.requestCompleted(promptTokens: 0, generationTokens: 0)
|
||||
LiveCounters.shared.requestCompleted(generationTokens: 0)
|
||||
sendResponse(connection: connection, status: 500, body: #"{"error":"\#(error.localizedDescription)"}"#)
|
||||
}
|
||||
}
|
||||
@@ -552,15 +563,10 @@ final class APIServer {
|
||||
"",
|
||||
].joined(separator: "\r\n")
|
||||
|
||||
let headerSent = await withCheckedContinuation { continuation in
|
||||
connection.send(content: header.data(using: .utf8), completion: .contentProcessed({ _ in
|
||||
continuation.resume(returning: true)
|
||||
}))
|
||||
}
|
||||
guard headerSent else { return }
|
||||
await Self.sendData(connection: connection, data: header.data(using: .utf8)!)
|
||||
|
||||
// Send initial role chunk
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
@@ -569,78 +575,39 @@ final class APIServer {
|
||||
usage: nil
|
||||
))
|
||||
|
||||
// When tools are available, buffer full response to parse tool calls
|
||||
// (otherwise raw tool-call markup leaks into streamed text)
|
||||
let bufferForTools = tools != nil && !(tools?.isEmpty ?? true)
|
||||
let hasTools = tools != nil && !(tools?.isEmpty ?? true)
|
||||
|
||||
var promptTokens = 0
|
||||
var completionTokens = 0
|
||||
var fullText = ""
|
||||
var frameworkToolCalls: [MLXLMCommon.ToolCall] = []
|
||||
|
||||
do {
|
||||
// Run the generation loop OFF MainActor.
|
||||
// ChatSession and NWConnection don't need MainActor.
|
||||
// Running on MainActor caused every token to compete with SwiftUI
|
||||
// rendering, creating back-pressure that coalesced all output.
|
||||
let stream = session.streamDetails(
|
||||
to: prompt,
|
||||
images: images,
|
||||
videos: []
|
||||
)
|
||||
|
||||
for try await generation in stream {
|
||||
switch generation {
|
||||
case .chunk(let text):
|
||||
completionTokens += 1
|
||||
fullText += text
|
||||
inferenceStats.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
|
||||
|
||||
if !bufferForTools {
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
// Transfer non-Sendable values to the nonisolated loop.
|
||||
// Safe because we don't touch session/images again until after the loop.
|
||||
let result = await {
|
||||
nonisolated(unsafe) let stream = stream
|
||||
return await Self.runStreamingLoop(
|
||||
connection: connection,
|
||||
stream: stream,
|
||||
requestId: requestId,
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
modelName: modelName
|
||||
)
|
||||
}()
|
||||
|
||||
case .info(let info):
|
||||
promptTokens = info.promptTokenCount
|
||||
completionTokens = info.generationTokenCount
|
||||
inferenceStats.prefillCompleted(promptTokens: promptTokens)
|
||||
if info.tokensPerSecond > 0 {
|
||||
inferenceStats.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
|
||||
}
|
||||
let (promptTokens, completionTokens, fullText, frameworkToolCalls) = result
|
||||
|
||||
case .toolCall(let call):
|
||||
frameworkToolCalls.append(call)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens)
|
||||
let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n"
|
||||
connection.send(content: errorEvent.data(using: .utf8), completion: .contentProcessed({ _ in }))
|
||||
}
|
||||
// Stats were already updated by LiveCounters inside the loop
|
||||
|
||||
// Post-generation: handle tool calls (framework-detected or text-parsed)
|
||||
var finishReason = "stop"
|
||||
|
||||
if !frameworkToolCalls.isEmpty {
|
||||
// Framework natively detected tool calls (e.g. Qwen)
|
||||
finishReason = "tool_calls"
|
||||
|
||||
// Emit any buffered text content
|
||||
if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
|
||||
// Emit tool call chunks
|
||||
for (i, tc) in frameworkToolCalls.enumerated() {
|
||||
let argsDict = tc.function.arguments.mapValues { $0.anyValue }
|
||||
let argsJSON: String
|
||||
@@ -657,7 +624,7 @@ final class APIServer {
|
||||
type: "function",
|
||||
function: APIFunctionCall(name: tc.function.name, arguments: argsJSON)
|
||||
)
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
@@ -666,27 +633,11 @@ final class APIServer {
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
} else if bufferForTools {
|
||||
// Text-parsed tool calls (e.g. Gemma tool_code blocks)
|
||||
let (cleanText, parsed) = ToolCallParser.parse(text: fullText, tools: tools)
|
||||
} else if hasTools {
|
||||
let (_, parsed) = ToolCallParser.parse(text: fullText, tools: tools)
|
||||
if !parsed.isEmpty {
|
||||
finishReason = "tool_calls"
|
||||
fullText = cleanText
|
||||
}
|
||||
|
||||
// Emit buffered content (cleaned of tool-call markup)
|
||||
if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
|
||||
// Emit tool call chunks
|
||||
for (i, tc) in parsed.enumerated() {
|
||||
let apiToolCall = APIToolCall(
|
||||
index: i,
|
||||
@@ -694,7 +645,7 @@ final class APIServer {
|
||||
type: "function",
|
||||
function: APIFunctionCall(name: tc.name, arguments: tc.arguments)
|
||||
)
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
@@ -706,7 +657,7 @@ final class APIServer {
|
||||
}
|
||||
|
||||
// Final chunk with finish_reason and usage
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
@@ -719,20 +670,83 @@ final class APIServer {
|
||||
)
|
||||
))
|
||||
|
||||
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens)
|
||||
LiveCounters.shared.requestCompleted(generationTokens: completionTokens)
|
||||
|
||||
// Send [DONE] and close
|
||||
let done = "data: [DONE]\n\n"
|
||||
connection.send(content: done.data(using: .utf8), completion: .contentProcessed({ _ in
|
||||
await Self.sendData(connection: connection, data: "data: [DONE]\n\n".data(using: .utf8)!)
|
||||
connection.cancel()
|
||||
}))
|
||||
}
|
||||
|
||||
private func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) {
|
||||
/// Run the token generation + SSE send loop entirely off MainActor.
|
||||
/// This is critical: if the loop runs on MainActor, every token requires
|
||||
/// multiple actor hops competing with SwiftUI, causing all output to batch.
|
||||
nonisolated private static func runStreamingLoop(
|
||||
connection: NWConnection,
|
||||
stream: AsyncThrowingStream<Generation, any Error>,
|
||||
requestId: String,
|
||||
created: Int,
|
||||
modelName: String
|
||||
) async -> (Int, Int, String, [MLXLMCommon.ToolCall]) {
|
||||
var promptTokens = 0
|
||||
var completionTokens = 0
|
||||
var fullText = ""
|
||||
var frameworkToolCalls: [MLXLMCommon.ToolCall] = []
|
||||
|
||||
do {
|
||||
for try await generation in stream {
|
||||
switch generation {
|
||||
case .chunk(let text):
|
||||
completionTokens += 1
|
||||
fullText += text
|
||||
|
||||
// Update live counters directly — no MainActor hop needed
|
||||
LiveCounters.shared.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
|
||||
|
||||
// Send directly — no MainActor hop.
|
||||
await sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
|
||||
case .info(let info):
|
||||
promptTokens = info.promptTokenCount
|
||||
completionTokens = info.generationTokenCount
|
||||
LiveCounters.shared.prefillCompleted(promptTokens: promptTokens)
|
||||
if info.tokensPerSecond > 0 {
|
||||
LiveCounters.shared.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
|
||||
}
|
||||
|
||||
case .toolCall(let call):
|
||||
frameworkToolCalls.append(call)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n"
|
||||
await sendData(connection: connection, data: errorEvent.data(using: .utf8)!)
|
||||
}
|
||||
|
||||
return (promptTokens, completionTokens, fullText, frameworkToolCalls)
|
||||
}
|
||||
|
||||
/// Send an SSE event and wait for the protocol stack to process it.
|
||||
nonisolated private static func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) async {
|
||||
guard let json = try? JSONEncoder().encode(chunk),
|
||||
let jsonString = String(data: json, encoding: .utf8) else { return }
|
||||
let event = "data: \(jsonString)\n\n"
|
||||
connection.send(content: event.data(using: .utf8), completion: .contentProcessed({ _ in }))
|
||||
await sendData(connection: connection, data: event.data(using: .utf8)!)
|
||||
}
|
||||
|
||||
/// Send raw data on the connection and wait for the protocol stack to process it.
|
||||
nonisolated private static func sendData(connection: NWConnection, data: Data) async {
|
||||
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
|
||||
connection.send(content: data, completion: .contentProcessed({ _ in
|
||||
continuation.resume()
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - HTTP helpers
|
||||
@@ -778,19 +792,19 @@ final class APIServer {
|
||||
]
|
||||
}
|
||||
|
||||
/// Check if cached messages are a prefix of new messages (for KV cache reuse).
|
||||
/// The cached messages include the full history from the previous request.
|
||||
/// For cache reuse, all but the last message of the new request must match
|
||||
/// all but the last message of the cached messages (the cached last was the
|
||||
/// previous user prompt, which is now part of the history).
|
||||
/// Check if the cached session can be reused for the new history.
|
||||
///
|
||||
/// After a request the session's KV cache contains:
|
||||
/// cachedMessages (history + user prompt) + the generated assistant response.
|
||||
/// On the next request the client sends back the full conversation, so
|
||||
/// `newHistory` (allButLast) is typically `cachedMessages` + 1 assistant reply.
|
||||
/// We allow reuse when `cached` is a prefix of `newHistory` and there is at most
|
||||
/// one extra message (the assistant response the session already generated).
|
||||
/// More than one extra message (e.g. injected tool results) means the session
|
||||
/// hasn't processed them, so we must create a fresh session.
|
||||
private func messagesMatch(_ cached: [Chat.Message], _ newHistory: [Chat.Message]) -> Bool {
|
||||
// The cached messages are the full chatMessages from the previous request.
|
||||
// For the cache to be reusable, the new history (allButLast) must match
|
||||
// exactly what the session has already processed.
|
||||
// After a request, the session has seen: cachedMessages' history + prompt + response.
|
||||
// So on the next request, if newHistory == cachedMessages, the session already
|
||||
// contains all of those turns and we can just send the new last message.
|
||||
guard cached.count == newHistory.count else { return false }
|
||||
guard cached.count <= newHistory.count,
|
||||
newHistory.count <= cached.count + 1 else { return false }
|
||||
for (a, b) in zip(cached, newHistory) {
|
||||
if a.role != b.role || a.content != b.content { return false }
|
||||
}
|
||||
|
||||
@@ -24,6 +24,12 @@ enum ToolCallParser {
|
||||
return (gemmaClean, gemmaCalls)
|
||||
}
|
||||
|
||||
// Try bare function calls matching known tool names: tool_name(args...)
|
||||
let (bareClean, bareCalls) = parseBareToolCalls(text: text, tools: tools)
|
||||
if !bareCalls.isEmpty {
|
||||
return (bareClean, bareCalls)
|
||||
}
|
||||
|
||||
return (text, [])
|
||||
}
|
||||
|
||||
@@ -187,4 +193,57 @@ enum ToolCallParser {
|
||||
|
||||
return (cleanText, toolCalls)
|
||||
}
|
||||
|
||||
// MARK: - Bare function calls: tool_name(args...)
|
||||
|
||||
/// Parse bare function calls that match known tool names.
|
||||
/// Handles models that output tool calls without fences or XML tags.
|
||||
private static func parseBareToolCalls(text: String, tools: [APIToolDefinition]?) -> (String, [ParsedToolCall]) {
|
||||
guard let tools, !tools.isEmpty else { return (text, []) }
|
||||
|
||||
let toolNames = tools.map { $0.function.name }
|
||||
guard !toolNames.isEmpty else { return (text, []) }
|
||||
|
||||
// Build regex: (tool_name1|tool_name2)\s*\(.*\)
|
||||
let escapedNames = toolNames.map { NSRegularExpression.escapedPattern(for: $0) }
|
||||
let pattern = "(" + escapedNames.joined(separator: "|") + #")\s*\((.*)\)"#
|
||||
guard let regex = try? NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else {
|
||||
return (text, [])
|
||||
}
|
||||
|
||||
let nsText = text as NSString
|
||||
let matches = regex.matches(in: text, range: NSRange(location: 0, length: nsText.length))
|
||||
guard !matches.isEmpty else { return (text, []) }
|
||||
|
||||
var toolDefs: [String: [String]] = [:]
|
||||
for tool in tools {
|
||||
let paramNames = tool.function.parameters?["properties"]?.value as? [String: Any]
|
||||
toolDefs[tool.function.name] = paramNames.map { Array($0.keys).sorted() } ?? []
|
||||
}
|
||||
|
||||
var toolCalls: [ParsedToolCall] = []
|
||||
for (i, match) in matches.enumerated() {
|
||||
let name = nsText.substring(with: match.range(at: 1))
|
||||
let argsStr = nsText.substring(with: match.range(at: 2)).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
var args: [String: Any] = [:]
|
||||
if !argsStr.isEmpty {
|
||||
if let (_, parsed) = parsePythonCall("\(name)(\(argsStr))", toolDefs: toolDefs) as (String, [String: Any])? {
|
||||
args = parsed
|
||||
}
|
||||
}
|
||||
|
||||
let argsJSON = (try? JSONSerialization.data(withJSONObject: args))
|
||||
.flatMap { String(data: $0, encoding: .utf8) } ?? "{}"
|
||||
let callId = String(format: "call_%d_%08d", i, abs(name.hashValue) % 100_000_000)
|
||||
toolCalls.append(ParsedToolCall(id: callId, name: name, arguments: argsJSON))
|
||||
}
|
||||
|
||||
let cleanText = regex.stringByReplacingMatches(
|
||||
in: text, range: NSRange(location: 0, length: nsText.length),
|
||||
withTemplate: ""
|
||||
).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
return (cleanText, toolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user