fix: improve API request handling (stats and SSE streaming moved off MainActor); the API path is still not on par with the internal chat
This commit is contained in:
@@ -36,6 +36,15 @@ open "build/Debug/MLX Server.app"
|
|||||||
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks (128k context) |
|
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks (128k context) |
|
||||||
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags (256k context) |
|
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags (256k context) |
|
||||||
|
|
||||||
|
## Critical Performance Rule
|
||||||
|
|
||||||
|
**Inference speed is the #1 priority.** The token generation loop must never be blocked or slowed by anything else — no MainActor hops, no SwiftUI observation, no synchronous I/O. Everything that isn't inference (stats collection, UI updates, logging) must run on separate threads via loose coupling:
|
||||||
|
|
||||||
|
- **`LiveCounters`** (thread-safe singleton with `OSAllocatedUnfairLock`) is the bridge: generation code writes to it directly from any thread with zero actor overhead.
|
||||||
|
- **`InferenceStats`** (UI-side, `@Observable @MainActor`) polls `LiveCounters` at 1Hz via a timer — never the other way around.
|
||||||
|
- SSE streaming (`sendSSEEvent`/`sendData`) runs nonisolated off MainActor so token sends don't compete with SwiftUI rendering.
|
||||||
|
- Never gate token output on UI state, analytics, or any `@MainActor`-isolated code.
|
||||||
|
|
||||||
## Key Design Decisions
|
## Key Design Decisions
|
||||||
|
|
||||||
- Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) as the inference backend — supports both text and vision in a single model load
|
- Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) as the inference backend — supports both text and vision in a single model load
|
||||||
|
|||||||
@@ -1,11 +1,124 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
|
import os
|
||||||
|
|
||||||
|
// MARK: - Thread-safe live counters (written from any thread, no actor isolation)
|
||||||
|
|
||||||
|
/// Lock-protected counters that the generation loop writes to directly.
|
||||||
|
/// No MainActor requirement — the UI polls these via the 1Hz timer.
|
||||||
|
final class LiveCounters: Sendable {
    static let shared = LiveCounters()

    /// All mutable counters. They live inside the lock's protected state, so the
    /// only way to read or write them is through `withLock` — the compiler, not
    /// convention, guarantees every access is synchronized and every lock is released.
    private struct State: Sendable {
        // Current request
        var activeRequests: Int = 0
        var promptTokens: Int = 0
        var generationTokens: Int = 0
        var tokensPerSecond: Double = 0
        var isPrefilling: Bool = false
        var isGenerating: Bool = false
        var contextMax: Int = 0

        // Cumulative
        var totalRequests: Int = 0
        var totalPromptTokens: Int = 0
        var totalGenerationTokens: Int = 0
    }

    private let state = OSAllocatedUnfairLock(initialState: State())

    /// Registers a new request: bumps the active and cumulative request counts,
    /// marks the request as prefilling, and clears the per-request figures.
    /// - Parameter contextLength: Stored as `contextMax` for this request.
    func requestStarted(contextLength: Int) {
        state.withLock { s in
            s.activeRequests += 1
            s.totalRequests += 1
            s.isPrefilling = true
            s.isGenerating = false
            s.promptTokens = 0
            s.generationTokens = 0
            s.tokensPerSecond = 0
            s.contextMax = contextLength
        }
    }

    /// Switches the phase flags from prefilling to generating and records the
    /// prompt size, both for the current request and in the cumulative total.
    /// - Parameter promptTokens: Prompt token count reported for the request.
    func prefillCompleted(promptTokens: Int) {
        state.withLock { s in
            s.isPrefilling = false
            s.isGenerating = true
            s.promptTokens = promptTokens
            s.totalPromptTokens += promptTokens
        }
    }

    /// Records generation progress for the current request.
    /// - Parameters:
    ///   - tokensPerSecond: Latest throughput figure; stored as-is.
    ///   - totalGenerated: Running count of tokens generated so far (absolute, not a delta).
    func tokenGenerated(tokensPerSecond: Double, totalGenerated: Int) {
        state.withLock { s in
            s.generationTokens = totalGenerated
            s.tokensPerSecond = tokensPerSecond
        }
    }

    /// Marks one request as finished and adds its output to the cumulative total.
    /// The phase flags and throughput are cleared only once no request remains active.
    /// - Parameter generationTokens: Tokens generated by the finished request.
    func requestCompleted(generationTokens: Int) {
        state.withLock { s in
            // Clamp at zero so a completion arriving after `reset()` cannot drive the count negative.
            s.activeRequests = max(0, s.activeRequests - 1)
            s.totalGenerationTokens += generationTokens
            if s.activeRequests == 0 {
                s.isGenerating = false
                s.isPrefilling = false
                s.tokensPerSecond = 0
            }
        }
    }

    /// Returns every counter, per-request and cumulative, to its initial value.
    func reset() {
        state.withLock { s in
            s = State()
        }
    }

    /// Atomic snapshot for the UI timer: all fields are copied under a single
    /// lock acquisition, so they are mutually consistent.
    func snapshot() -> Snapshot {
        state.withLock { s in
            Snapshot(
                activeRequests: s.activeRequests,
                promptTokens: s.promptTokens,
                generationTokens: s.generationTokens,
                tokensPerSecond: s.tokensPerSecond,
                isPrefilling: s.isPrefilling,
                isGenerating: s.isGenerating,
                contextMax: s.contextMax,
                totalRequests: s.totalRequests,
                totalPromptTokens: s.totalPromptTokens,
                totalGenerationTokens: s.totalGenerationTokens
            )
        }
    }

    /// Immutable copy of the counters at one instant.
    struct Snapshot: Sendable {
        let activeRequests: Int
        let promptTokens: Int
        let generationTokens: Int
        let tokensPerSecond: Double
        let isPrefilling: Bool
        let isGenerating: Bool
        let contextMax: Int
        let totalRequests: Int
        let totalPromptTokens: Int
        let totalGenerationTokens: Int
    }
}
|
||||||
|
|
||||||
|
// MARK: - Observable stats for the UI (polls LiveCounters at 1Hz)
|
||||||
|
|
||||||
/// Lightweight stats collector for inference activity visualization.
|
|
||||||
/// All mutations happen on @MainActor to avoid locks.
|
|
||||||
@Observable
|
@Observable
|
||||||
@MainActor
|
@MainActor
|
||||||
final class InferenceStats {
|
final class InferenceStats {
|
||||||
// MARK: - Current request state
|
// MARK: - Current request state (refreshed from LiveCounters)
|
||||||
|
|
||||||
var activeRequests: Int = 0
|
var activeRequests: Int = 0
|
||||||
var currentPromptTokens: Int = 0
|
var currentPromptTokens: Int = 0
|
||||||
@@ -40,11 +153,9 @@ final class InferenceStats {
|
|||||||
private var sampleTimer: Timer?
|
private var sampleTimer: Timer?
|
||||||
private var lastGenerationTokenCount: Int = 0
|
private var lastGenerationTokenCount: Int = 0
|
||||||
private var lastPromptTokenCount: Int = 0
|
private var lastPromptTokenCount: Int = 0
|
||||||
private var lastSampleTime: Date = .now
|
|
||||||
|
|
||||||
func startSampling() {
|
func startSampling() {
|
||||||
guard sampleTimer == nil else { return }
|
guard sampleTimer == nil else { return }
|
||||||
lastSampleTime = .now
|
|
||||||
sampleTimer = Timer.scheduledTimer(withTimeInterval: 1.0, repeats: true) { [weak self] _ in
|
sampleTimer = Timer.scheduledTimer(withTimeInterval: 1.0, repeats: true) { [weak self] _ in
|
||||||
Task { @MainActor in
|
Task { @MainActor in
|
||||||
self?.recordSample()
|
self?.recordSample()
|
||||||
@@ -58,19 +169,31 @@ final class InferenceStats {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private func recordSample() {
|
private func recordSample() {
|
||||||
|
// Pull live values from the thread-safe counters
|
||||||
|
let snap = LiveCounters.shared.snapshot()
|
||||||
|
|
||||||
|
activeRequests = snap.activeRequests
|
||||||
|
currentPromptTokens = snap.promptTokens
|
||||||
|
currentGenerationTokens = snap.generationTokens
|
||||||
|
currentTokensPerSecond = snap.tokensPerSecond
|
||||||
|
isPrefilling = snap.isPrefilling
|
||||||
|
isGenerating = snap.isGenerating
|
||||||
|
contextMax = snap.contextMax
|
||||||
|
contextUsed = snap.promptTokens + snap.generationTokens
|
||||||
|
totalRequests = snap.totalRequests
|
||||||
|
totalPromptTokens = snap.totalPromptTokens
|
||||||
|
totalGenerationTokens = snap.totalGenerationTokens
|
||||||
|
|
||||||
let now = Date.now
|
let now = Date.now
|
||||||
|
let genDelta = snap.totalGenerationTokens - lastGenerationTokenCount
|
||||||
|
let promptDelta = snap.totalPromptTokens - lastPromptTokenCount
|
||||||
|
lastGenerationTokenCount = snap.totalGenerationTokens
|
||||||
|
lastPromptTokenCount = snap.totalPromptTokens
|
||||||
|
|
||||||
// Token rate: tokens generated since last sample
|
tokenRateHistory.append(DataPoint(timestamp: now, value: snap.tokensPerSecond))
|
||||||
let genDelta = totalGenerationTokens - lastGenerationTokenCount
|
|
||||||
let promptDelta = totalPromptTokens - lastPromptTokenCount
|
|
||||||
lastGenerationTokenCount = totalGenerationTokens
|
|
||||||
lastPromptTokenCount = totalPromptTokens
|
|
||||||
|
|
||||||
tokenRateHistory.append(DataPoint(timestamp: now, value: currentTokensPerSecond))
|
|
||||||
generationTokenHistory.append(DataPoint(timestamp: now, value: Double(genDelta)))
|
generationTokenHistory.append(DataPoint(timestamp: now, value: Double(genDelta)))
|
||||||
promptTokenHistory.append(DataPoint(timestamp: now, value: Double(promptDelta)))
|
promptTokenHistory.append(DataPoint(timestamp: now, value: Double(promptDelta)))
|
||||||
|
|
||||||
// Trim to ring buffer size
|
|
||||||
if tokenRateHistory.count > Self.maxHistoryPoints {
|
if tokenRateHistory.count > Self.maxHistoryPoints {
|
||||||
tokenRateHistory.removeFirst(tokenRateHistory.count - Self.maxHistoryPoints)
|
tokenRateHistory.removeFirst(tokenRateHistory.count - Self.maxHistoryPoints)
|
||||||
}
|
}
|
||||||
@@ -82,45 +205,8 @@ final class InferenceStats {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - Event recording (called from APIServer)
|
|
||||||
|
|
||||||
func requestStarted(contextLength: Int) {
|
|
||||||
activeRequests += 1
|
|
||||||
totalRequests += 1
|
|
||||||
isPrefilling = true
|
|
||||||
isGenerating = false
|
|
||||||
currentPromptTokens = 0
|
|
||||||
currentGenerationTokens = 0
|
|
||||||
currentTokensPerSecond = 0
|
|
||||||
contextMax = contextLength
|
|
||||||
contextUsed = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func prefillCompleted(promptTokens: Int) {
|
|
||||||
isPrefilling = false
|
|
||||||
isGenerating = true
|
|
||||||
currentPromptTokens = promptTokens
|
|
||||||
totalPromptTokens += promptTokens
|
|
||||||
contextUsed = promptTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
func tokenGenerated(tokensPerSecond: Double, totalGenerated: Int) {
|
|
||||||
currentGenerationTokens = totalGenerated
|
|
||||||
currentTokensPerSecond = tokensPerSecond
|
|
||||||
contextUsed = currentPromptTokens + totalGenerated
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestCompleted(promptTokens: Int, generationTokens: Int) {
|
|
||||||
activeRequests = max(0, activeRequests - 1)
|
|
||||||
totalGenerationTokens += generationTokens
|
|
||||||
if activeRequests == 0 {
|
|
||||||
isGenerating = false
|
|
||||||
isPrefilling = false
|
|
||||||
currentTokensPerSecond = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func reset() {
|
func reset() {
|
||||||
|
LiveCounters.shared.reset()
|
||||||
activeRequests = 0
|
activeRequests = 0
|
||||||
currentPromptTokens = 0
|
currentPromptTokens = 0
|
||||||
currentGenerationTokens = 0
|
currentGenerationTokens = 0
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ final class APIServer {
|
|||||||
private var cachedSession: ChatSession?
|
private var cachedSession: ChatSession?
|
||||||
private var cachedMessages: [Chat.Message]?
|
private var cachedMessages: [Chat.Message]?
|
||||||
private var cachedModelId: String?
|
private var cachedModelId: String?
|
||||||
|
private var cachedInstructions: String = ""
|
||||||
|
|
||||||
func start(modelManager: ModelManager, port: Int = 1234) {
|
func start(modelManager: ModelManager, port: Int = 1234) {
|
||||||
guard !isRunning else { return }
|
guard !isRunning else { return }
|
||||||
@@ -29,6 +30,10 @@ final class APIServer {
|
|||||||
do {
|
do {
|
||||||
let params = NWParameters.tcp
|
let params = NWParameters.tcp
|
||||||
params.allowLocalEndpointReuse = true
|
params.allowLocalEndpointReuse = true
|
||||||
|
// Disable Nagle's algorithm so small SSE events go out immediately
|
||||||
|
if let tcpOptions = params.defaultProtocolStack.transportProtocol as? NWProtocolTCP.Options {
|
||||||
|
tcpOptions.noDelay = true
|
||||||
|
}
|
||||||
listener = try NWListener(using: params, on: NWEndpoint.Port(integerLiteral: UInt16(port)))
|
listener = try NWListener(using: params, on: NWEndpoint.Port(integerLiteral: UInt16(port)))
|
||||||
|
|
||||||
listener?.stateUpdateHandler = { [weak self] state in
|
listener?.stateUpdateHandler = { [weak self] state in
|
||||||
@@ -68,6 +73,7 @@ final class APIServer {
|
|||||||
cachedSession = nil
|
cachedSession = nil
|
||||||
cachedMessages = nil
|
cachedMessages = nil
|
||||||
cachedModelId = nil
|
cachedModelId = nil
|
||||||
|
cachedInstructions = ""
|
||||||
inferenceStats.stopSampling()
|
inferenceStats.stopSampling()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,6 +189,7 @@ final class APIServer {
|
|||||||
cachedSession = nil
|
cachedSession = nil
|
||||||
cachedMessages = nil
|
cachedMessages = nil
|
||||||
cachedModelId = nil
|
cachedModelId = nil
|
||||||
|
cachedInstructions = ""
|
||||||
await modelManager.loadModel(targetConfig)
|
await modelManager.loadModel(targetConfig)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -196,6 +203,7 @@ final class APIServer {
|
|||||||
cachedSession = nil
|
cachedSession = nil
|
||||||
cachedMessages = nil
|
cachedMessages = nil
|
||||||
cachedModelId = nil
|
cachedModelId = nil
|
||||||
|
cachedInstructions = ""
|
||||||
await modelManager.loadModel(config)
|
await modelManager.loadModel(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,32 +228,33 @@ final class APIServer {
|
|||||||
var images: [UserInput.Image] = []
|
var images: [UserInput.Image] = []
|
||||||
let currentModelRepoId = modelManager.currentModel?.repoId ?? modelName
|
let currentModelRepoId = modelManager.currentModel?.repoId ?? modelName
|
||||||
|
|
||||||
// Inject tool definitions into the system prompt if tools are provided
|
// Build the instructions string (system prompt + tool definitions).
|
||||||
|
// This is passed to ChatSession via `instructions:` rather than injected
|
||||||
|
// as history messages, so it avoids an expensive history-replay prefill.
|
||||||
|
var instructions: String = ""
|
||||||
|
|
||||||
|
// Collect system message text from the request
|
||||||
|
for msg in request.messages where msg.role == "system" {
|
||||||
|
let text = msg.content?.textContent ?? ""
|
||||||
|
if !text.isEmpty {
|
||||||
|
if !instructions.isEmpty { instructions += "\n\n" }
|
||||||
|
instructions += text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append tool definitions to instructions
|
||||||
if let tools = request.tools, !tools.isEmpty {
|
if let tools = request.tools, !tools.isEmpty {
|
||||||
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
|
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
|
||||||
|
if !instructions.isEmpty { instructions += "\n\n" }
|
||||||
// Check if there's already a system message
|
instructions += toolSystemPrompt
|
||||||
let hasSystem = request.messages.contains { $0.role == "system" }
|
|
||||||
if hasSystem {
|
|
||||||
// Append tool prompt to existing system message (handled below during conversion)
|
|
||||||
} else {
|
|
||||||
// For Gemma: inject as user message (Gemma doesn't support system role natively)
|
|
||||||
// For Qwen: inject as system message
|
|
||||||
if currentModelRepoId.lowercased().contains("qwen") {
|
|
||||||
chatMessages.append(Chat.Message(role: .system, content: toolSystemPrompt))
|
|
||||||
} else {
|
|
||||||
chatMessages.append(Chat.Message(role: .user, content: toolSystemPrompt))
|
|
||||||
chatMessages.append(Chat.Message(role: .assistant, content: "Understood. I will use the provided tools when appropriate."))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let toolsForInjection = request.tools
|
let toolsForInjection = request.tools
|
||||||
let isQwen = currentModelRepoId.lowercased().contains("qwen")
|
let isQwen = currentModelRepoId.lowercased().contains("qwen")
|
||||||
|
|
||||||
for msg in request.messages {
|
// Convert non-system messages to Chat.Message
|
||||||
|
for msg in request.messages where msg.role != "system" {
|
||||||
let role: Chat.Message.Role = switch msg.role {
|
let role: Chat.Message.Role = switch msg.role {
|
||||||
case "system": .system
|
|
||||||
case "assistant": .assistant
|
case "assistant": .assistant
|
||||||
case "tool": .user
|
case "tool": .user
|
||||||
default: .user
|
default: .user
|
||||||
@@ -253,12 +262,6 @@ final class APIServer {
|
|||||||
|
|
||||||
var text = msg.content?.textContent ?? ""
|
var text = msg.content?.textContent ?? ""
|
||||||
|
|
||||||
// If this is a system message and tools are provided, append tool definitions
|
|
||||||
if msg.role == "system", let tools = toolsForInjection, !tools.isEmpty {
|
|
||||||
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
|
|
||||||
text = text + "\n\n" + toolSystemPrompt
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format tool_call_id responses as tool_output for the model
|
// Format tool_call_id responses as tool_output for the model
|
||||||
if msg.role == "tool" {
|
if msg.role == "tool" {
|
||||||
if isQwen {
|
if isQwen {
|
||||||
@@ -328,6 +331,7 @@ final class APIServer {
|
|||||||
let canReuse = cachedSession != nil
|
let canReuse = cachedSession != nil
|
||||||
&& cachedModelId == currentModelId
|
&& cachedModelId == currentModelId
|
||||||
&& cachedMessages != nil
|
&& cachedMessages != nil
|
||||||
|
&& cachedInstructions == instructions
|
||||||
&& messagesMatch(cachedMessages!, allButLast)
|
&& messagesMatch(cachedMessages!, allButLast)
|
||||||
|
|
||||||
let session: ChatSession
|
let session: ChatSession
|
||||||
@@ -339,15 +343,21 @@ final class APIServer {
|
|||||||
if cachedSession != nil {
|
if cachedSession != nil {
|
||||||
print("[APIServer] History diverged, creating fresh session")
|
print("[APIServer] History diverged, creating fresh session")
|
||||||
}
|
}
|
||||||
|
// Use `instructions:` for system/tool prompt (matches internal chat pattern).
|
||||||
|
// Only conversation turns go in `history:` — this avoids replaying the
|
||||||
|
// large tool prompt as history on every new session.
|
||||||
|
let instr = instructions.isEmpty ? nil : instructions
|
||||||
if !allButLast.isEmpty {
|
if !allButLast.isEmpty {
|
||||||
session = ChatSession(
|
session = ChatSession(
|
||||||
container,
|
container,
|
||||||
|
instructions: instr,
|
||||||
history: allButLast,
|
history: allButLast,
|
||||||
generateParameters: generateParams
|
generateParameters: generateParams
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
session = ChatSession(
|
session = ChatSession(
|
||||||
container,
|
container,
|
||||||
|
instructions: instr,
|
||||||
generateParameters: generateParams
|
generateParameters: generateParams
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -356,7 +366,7 @@ final class APIServer {
|
|||||||
// Extract images from the last message only (ChatSession.streamDetails takes images separately)
|
// Extract images from the last message only (ChatSession.streamDetails takes images separately)
|
||||||
let lastImages = lastMessage.images
|
let lastImages = lastMessage.images
|
||||||
|
|
||||||
inferenceStats.requestStarted(contextLength: contextLength)
|
LiveCounters.shared.requestStarted(contextLength: contextLength)
|
||||||
|
|
||||||
if isStream {
|
if isStream {
|
||||||
await handleStreamingResponse(
|
await handleStreamingResponse(
|
||||||
@@ -387,6 +397,7 @@ final class APIServer {
|
|||||||
cachedSession = session
|
cachedSession = session
|
||||||
cachedMessages = chatMessages // full messages including the one just sent
|
cachedMessages = chatMessages // full messages including the one just sent
|
||||||
cachedModelId = currentModelId
|
cachedModelId = currentModelId
|
||||||
|
cachedInstructions = instructions
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decode a base64 data URI (data:image/png;base64,...) into a UserInput.Image.
|
/// Decode a base64 data URI (data:image/png;base64,...) into a UserInput.Image.
|
||||||
@@ -439,20 +450,20 @@ final class APIServer {
|
|||||||
case .chunk(let text):
|
case .chunk(let text):
|
||||||
fullText += text
|
fullText += text
|
||||||
completionTokens += 1
|
completionTokens += 1
|
||||||
inferenceStats.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
|
LiveCounters.shared.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
|
||||||
case .info(let info):
|
case .info(let info):
|
||||||
promptTokens = info.promptTokenCount
|
promptTokens = info.promptTokenCount
|
||||||
completionTokens = info.generationTokenCount
|
completionTokens = info.generationTokenCount
|
||||||
inferenceStats.prefillCompleted(promptTokens: promptTokens)
|
LiveCounters.shared.prefillCompleted(promptTokens: promptTokens)
|
||||||
if info.tokensPerSecond > 0 {
|
if info.tokensPerSecond > 0 {
|
||||||
inferenceStats.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
|
LiveCounters.shared.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
|
||||||
}
|
}
|
||||||
case .toolCall(let call):
|
case .toolCall(let call):
|
||||||
frameworkToolCalls.append(call)
|
frameworkToolCalls.append(call)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens)
|
LiveCounters.shared.requestCompleted(generationTokens: completionTokens)
|
||||||
|
|
||||||
// Parse tool calls: first check framework-detected ones, then our own text parser
|
// Parse tool calls: first check framework-detected ones, then our own text parser
|
||||||
var finishReason = "stop"
|
var finishReason = "stop"
|
||||||
@@ -524,7 +535,7 @@ final class APIServer {
|
|||||||
sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}")
|
sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}")
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
inferenceStats.requestCompleted(promptTokens: 0, generationTokens: 0)
|
LiveCounters.shared.requestCompleted(generationTokens: 0)
|
||||||
sendResponse(connection: connection, status: 500, body: #"{"error":"\#(error.localizedDescription)"}"#)
|
sendResponse(connection: connection, status: 500, body: #"{"error":"\#(error.localizedDescription)"}"#)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -552,15 +563,10 @@ final class APIServer {
|
|||||||
"",
|
"",
|
||||||
].joined(separator: "\r\n")
|
].joined(separator: "\r\n")
|
||||||
|
|
||||||
let headerSent = await withCheckedContinuation { continuation in
|
await Self.sendData(connection: connection, data: header.data(using: .utf8)!)
|
||||||
connection.send(content: header.data(using: .utf8), completion: .contentProcessed({ _ in
|
|
||||||
continuation.resume(returning: true)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
guard headerSent else { return }
|
|
||||||
|
|
||||||
// Send initial role chunk
|
// Send initial role chunk
|
||||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||||
id: requestId,
|
id: requestId,
|
||||||
object: "chat.completion.chunk",
|
object: "chat.completion.chunk",
|
||||||
created: created,
|
created: created,
|
||||||
@@ -569,78 +575,39 @@ final class APIServer {
|
|||||||
usage: nil
|
usage: nil
|
||||||
))
|
))
|
||||||
|
|
||||||
// When tools are available, buffer full response to parse tool calls
|
let hasTools = tools != nil && !(tools?.isEmpty ?? true)
|
||||||
// (otherwise raw tool-call markup leaks into streamed text)
|
|
||||||
let bufferForTools = tools != nil && !(tools?.isEmpty ?? true)
|
|
||||||
|
|
||||||
var promptTokens = 0
|
// Run the generation loop OFF MainActor.
|
||||||
var completionTokens = 0
|
// ChatSession and NWConnection don't need MainActor.
|
||||||
var fullText = ""
|
// Running on MainActor caused every token to compete with SwiftUI
|
||||||
var frameworkToolCalls: [MLXLMCommon.ToolCall] = []
|
// rendering, creating back-pressure that coalesced all output.
|
||||||
|
|
||||||
do {
|
|
||||||
let stream = session.streamDetails(
|
let stream = session.streamDetails(
|
||||||
to: prompt,
|
to: prompt,
|
||||||
images: images,
|
images: images,
|
||||||
videos: []
|
videos: []
|
||||||
)
|
)
|
||||||
|
// Transfer non-Sendable values to the nonisolated loop.
|
||||||
for try await generation in stream {
|
// Safe because we don't touch session/images again until after the loop.
|
||||||
switch generation {
|
let result = await {
|
||||||
case .chunk(let text):
|
nonisolated(unsafe) let stream = stream
|
||||||
completionTokens += 1
|
return await Self.runStreamingLoop(
|
||||||
fullText += text
|
connection: connection,
|
||||||
inferenceStats.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)
|
stream: stream,
|
||||||
|
requestId: requestId,
|
||||||
if !bufferForTools {
|
|
||||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
|
||||||
id: requestId,
|
|
||||||
object: "chat.completion.chunk",
|
|
||||||
created: created,
|
created: created,
|
||||||
model: modelName,
|
modelName: modelName
|
||||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)],
|
)
|
||||||
usage: nil
|
}()
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
case .info(let info):
|
let (promptTokens, completionTokens, fullText, frameworkToolCalls) = result
|
||||||
promptTokens = info.promptTokenCount
|
|
||||||
completionTokens = info.generationTokenCount
|
|
||||||
inferenceStats.prefillCompleted(promptTokens: promptTokens)
|
|
||||||
if info.tokensPerSecond > 0 {
|
|
||||||
inferenceStats.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
case .toolCall(let call):
|
// Stats were already updated by LiveCounters inside the loop
|
||||||
frameworkToolCalls.append(call)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch {
|
|
||||||
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens)
|
|
||||||
let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n"
|
|
||||||
connection.send(content: errorEvent.data(using: .utf8), completion: .contentProcessed({ _ in }))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Post-generation: handle tool calls (framework-detected or text-parsed)
|
// Post-generation: handle tool calls (framework-detected or text-parsed)
|
||||||
var finishReason = "stop"
|
var finishReason = "stop"
|
||||||
|
|
||||||
if !frameworkToolCalls.isEmpty {
|
if !frameworkToolCalls.isEmpty {
|
||||||
// Framework natively detected tool calls (e.g. Qwen)
|
|
||||||
finishReason = "tool_calls"
|
finishReason = "tool_calls"
|
||||||
|
|
||||||
// Emit any buffered text content
|
|
||||||
if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
|
||||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
|
||||||
id: requestId,
|
|
||||||
object: "chat.completion.chunk",
|
|
||||||
created: created,
|
|
||||||
model: modelName,
|
|
||||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)],
|
|
||||||
usage: nil
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit tool call chunks
|
|
||||||
for (i, tc) in frameworkToolCalls.enumerated() {
|
for (i, tc) in frameworkToolCalls.enumerated() {
|
||||||
let argsDict = tc.function.arguments.mapValues { $0.anyValue }
|
let argsDict = tc.function.arguments.mapValues { $0.anyValue }
|
||||||
let argsJSON: String
|
let argsJSON: String
|
||||||
@@ -657,7 +624,7 @@ final class APIServer {
|
|||||||
type: "function",
|
type: "function",
|
||||||
function: APIFunctionCall(name: tc.function.name, arguments: argsJSON)
|
function: APIFunctionCall(name: tc.function.name, arguments: argsJSON)
|
||||||
)
|
)
|
||||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||||
id: requestId,
|
id: requestId,
|
||||||
object: "chat.completion.chunk",
|
object: "chat.completion.chunk",
|
||||||
created: created,
|
created: created,
|
||||||
@@ -666,27 +633,11 @@ final class APIServer {
|
|||||||
usage: nil
|
usage: nil
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
} else if bufferForTools {
|
} else if hasTools {
|
||||||
// Text-parsed tool calls (e.g. Gemma tool_code blocks)
|
let (_, parsed) = ToolCallParser.parse(text: fullText, tools: tools)
|
||||||
let (cleanText, parsed) = ToolCallParser.parse(text: fullText, tools: tools)
|
|
||||||
if !parsed.isEmpty {
|
if !parsed.isEmpty {
|
||||||
finishReason = "tool_calls"
|
finishReason = "tool_calls"
|
||||||
fullText = cleanText
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emit buffered content (cleaned of tool-call markup)
|
|
||||||
if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
|
||||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
|
||||||
id: requestId,
|
|
||||||
object: "chat.completion.chunk",
|
|
||||||
created: created,
|
|
||||||
model: modelName,
|
|
||||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)],
|
|
||||||
usage: nil
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit tool call chunks
|
|
||||||
for (i, tc) in parsed.enumerated() {
|
for (i, tc) in parsed.enumerated() {
|
||||||
let apiToolCall = APIToolCall(
|
let apiToolCall = APIToolCall(
|
||||||
index: i,
|
index: i,
|
||||||
@@ -694,7 +645,7 @@ final class APIServer {
|
|||||||
type: "function",
|
type: "function",
|
||||||
function: APIFunctionCall(name: tc.name, arguments: tc.arguments)
|
function: APIFunctionCall(name: tc.name, arguments: tc.arguments)
|
||||||
)
|
)
|
||||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||||
id: requestId,
|
id: requestId,
|
||||||
object: "chat.completion.chunk",
|
object: "chat.completion.chunk",
|
||||||
created: created,
|
created: created,
|
||||||
@@ -706,7 +657,7 @@ final class APIServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Final chunk with finish_reason and usage
|
// Final chunk with finish_reason and usage
|
||||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
await Self.sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||||
id: requestId,
|
id: requestId,
|
||||||
object: "chat.completion.chunk",
|
object: "chat.completion.chunk",
|
||||||
created: created,
|
created: created,
|
||||||
@@ -719,20 +670,83 @@ final class APIServer {
|
|||||||
)
|
)
|
||||||
))
|
))
|
||||||
|
|
||||||
inferenceStats.requestCompleted(promptTokens: promptTokens, generationTokens: completionTokens)
|
LiveCounters.shared.requestCompleted(generationTokens: completionTokens)
|
||||||
|
|
||||||
// Send [DONE] and close
|
// Send [DONE] and close
|
||||||
let done = "data: [DONE]\n\n"
|
await Self.sendData(connection: connection, data: "data: [DONE]\n\n".data(using: .utf8)!)
|
||||||
connection.send(content: done.data(using: .utf8), completion: .contentProcessed({ _ in
|
|
||||||
connection.cancel()
|
connection.cancel()
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) {
|
/// Run the token generation + SSE send loop entirely off MainActor.
///
/// This is critical: if the loop runs on MainActor, every token requires
/// multiple actor hops competing with SwiftUI, causing all output to batch.
///
/// Each `.chunk` is forwarded to the client as a `chat.completion.chunk` SSE event
/// as soon as it arrives. If the stream throws, a single `data: {"error": ...}` event
/// is sent and whatever was accumulated so far is returned.
///
/// - Parameters:
///   - connection: Connection the SSE events are written to.
///   - stream: Generation stream to consume.
///   - requestId: Value used for the `id` field of every emitted chunk.
///   - created: Value used for the `created` field of every emitted chunk.
///   - modelName: Value used for the `model` field of every emitted chunk.
/// - Returns: `(promptTokens, completionTokens, fullText, frameworkToolCalls)`.
///   `completionTokens` is counted per chunk and then overwritten by the count
///   reported in the stream's `.info` event, when one arrives.
nonisolated private static func runStreamingLoop(
    connection: NWConnection,
    stream: AsyncThrowingStream<Generation, any Error>,
    requestId: String,
    created: Int,
    modelName: String
) async -> (Int, Int, String, [MLXLMCommon.ToolCall]) {
    var promptTokens = 0
    var completionTokens = 0
    var fullText = ""
    var frameworkToolCalls: [MLXLMCommon.ToolCall] = []

    do {
        for try await generation in stream {
            switch generation {
            case .chunk(let text):
                completionTokens += 1
                fullText += text

                // Update live counters directly — no MainActor hop needed
                LiveCounters.shared.tokenGenerated(tokensPerSecond: 0, totalGenerated: completionTokens)

                // Send directly — no MainActor hop.
                await sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
                    id: requestId,
                    object: "chat.completion.chunk",
                    created: created,
                    model: modelName,
                    choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)],
                    usage: nil
                ))

            case .info(let info):
                promptTokens = info.promptTokenCount
                completionTokens = info.generationTokenCount
                LiveCounters.shared.prefillCompleted(promptTokens: promptTokens)
                if info.tokensPerSecond > 0 {
                    LiveCounters.shared.tokenGenerated(tokensPerSecond: info.tokensPerSecond, totalGenerated: completionTokens)
                }

            case .toolCall(let call):
                frameworkToolCalls.append(call)
            }
        }
    } catch {
        // Encode the message instead of interpolating it: a quote, backslash or
        // newline in the description would otherwise produce invalid JSON or
        // terminate the SSE event early.
        let payload = (try? JSONEncoder().encode(["error": error.localizedDescription]))
            ?? Data(#"{"error":"generation failed"}"#.utf8)
        var errorEvent = Data("data: ".utf8)
        errorEvent.append(payload)
        errorEvent.append(Data("\n\n".utf8))
        await sendData(connection: connection, data: errorEvent)
    }

    return (promptTokens, completionTokens, fullText, frameworkToolCalls)
}
|
||||||
|
|
||||||
|
/// Send an SSE event and wait for the protocol stack to process it.
///
/// The chunk is JSON-encoded and framed as `data: <json>\n\n`. If encoding
/// fails, nothing is sent and the call returns immediately.
///
/// - Parameters:
///   - connection: Connection the event is written to.
///   - chunk: Completion chunk to serialize as the event payload.
nonisolated private static func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) async {
    guard let json = try? JSONEncoder().encode(chunk) else { return }

    // Frame the encoder output directly as bytes: this runs once per token, so
    // skip the Data -> String -> Data round trip (and its force unwrap).
    var event = Data("data: ".utf8)
    event.append(json)
    event.append(Data("\n\n".utf8))
    await sendData(connection: connection, data: event)
}
|
||||||
|
|
||||||
|
/// Write `data` to `connection` and suspend until the send's
/// `.contentProcessed` completion fires.
///
/// The error value handed to the completion handler is discarded; the caller
/// is resumed either way.
nonisolated private static func sendData(connection: NWConnection, data: Data) async {
    await withCheckedContinuation { (sent: CheckedContinuation<Void, Never>) in
        let onProcessed = NWConnection.SendCompletion.contentProcessed({ _ in
            sent.resume()
        })
        connection.send(content: data, completion: onProcessed)
    }
}
|
||||||
|
|
||||||
// MARK: - HTTP helpers
|
// MARK: - HTTP helpers
|
||||||
@@ -778,19 +792,19 @@ final class APIServer {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if cached messages are a prefix of new messages (for KV cache reuse).
|
/// Check if the cached session can be reused for the new history.
|
||||||
/// The cached messages include the full history from the previous request.
|
///
|
||||||
/// For cache reuse, all but the last message of the new request must match
|
/// After a request the session's KV cache contains:
|
||||||
/// all but the last message of the cached messages (the cached last was the
|
/// cachedMessages (history + user prompt) + the generated assistant response.
|
||||||
/// previous user prompt, which is now part of the history).
|
/// On the next request the client sends back the full conversation, so
|
||||||
|
/// `newHistory` (allButLast) is typically `cachedMessages` + 1 assistant reply.
|
||||||
|
/// We allow reuse when `cached` is a prefix of `newHistory` and there is at most
|
||||||
|
/// one extra message (the assistant response the session already generated).
|
||||||
|
/// More than one extra message (e.g. injected tool results) means the session
|
||||||
|
/// hasn't processed them, so we must create a fresh session.
|
||||||
private func messagesMatch(_ cached: [Chat.Message], _ newHistory: [Chat.Message]) -> Bool {
|
private func messagesMatch(_ cached: [Chat.Message], _ newHistory: [Chat.Message]) -> Bool {
|
||||||
// The cached messages are the full chatMessages from the previous request.
|
guard cached.count <= newHistory.count,
|
||||||
// For the cache to be reusable, the new history (allButLast) must match
|
newHistory.count <= cached.count + 1 else { return false }
|
||||||
// exactly what the session has already processed.
|
|
||||||
// After a request, the session has seen: cachedMessages' history + prompt + response.
|
|
||||||
// So on the next request, if newHistory == cachedMessages, the session already
|
|
||||||
// contains all of those turns and we can just send the new last message.
|
|
||||||
guard cached.count == newHistory.count else { return false }
|
|
||||||
for (a, b) in zip(cached, newHistory) {
|
for (a, b) in zip(cached, newHistory) {
|
||||||
if a.role != b.role || a.content != b.content { return false }
|
if a.role != b.role || a.content != b.content { return false }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ enum ToolCallParser {
|
|||||||
return (gemmaClean, gemmaCalls)
|
return (gemmaClean, gemmaCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try bare function calls matching known tool names: tool_name(args...)
|
||||||
|
let (bareClean, bareCalls) = parseBareToolCalls(text: text, tools: tools)
|
||||||
|
if !bareCalls.isEmpty {
|
||||||
|
return (bareClean, bareCalls)
|
||||||
|
}
|
||||||
|
|
||||||
return (text, [])
|
return (text, [])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,4 +193,57 @@ enum ToolCallParser {
|
|||||||
|
|
||||||
return (cleanText, toolCalls)
|
return (cleanText, toolCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Bare function calls: tool_name(args...)
|
||||||
|
|
||||||
|
/// Parse bare function calls that match known tool names.
/// Handles models that output tool calls without fences or XML tags.
///
/// A call is recognized when a known tool name (not preceded by another
/// identifier character) is followed by `(` and a balanced closing `)`.
/// Parentheses inside single- or double-quoted strings are ignored while
/// balancing, so several calls in one response are extracted separately
/// and surrounding prose is left intact. Candidates whose parentheses
/// never balance are skipped.
///
/// - Parameters:
///   - text: Raw model output.
///   - tools: Tool definitions from the request; `nil` or empty disables parsing.
/// - Returns: The text with every recognized call removed (trimmed) and the
///   parsed calls, or the untouched `text` and `[]` when nothing was recognized.
private static func parseBareToolCalls(text: String, tools: [APIToolDefinition]?) -> (String, [ParsedToolCall]) {
    guard let tools, !tools.isEmpty else { return (text, []) }

    // Build regex: <not an identifier char>(tool_name1|tool_name2)\s*\(
    // Only the call head is matched here; the argument list is delimited by
    // paren balancing below, because a greedy `.*` would run to the last `)`
    // in the whole text.
    let escapedNames = tools.map { NSRegularExpression.escapedPattern(for: $0.function.name) }
    let pattern = #"(?<![A-Za-z0-9_])("# + escapedNames.joined(separator: "|") + #")\s*\("#
    guard let regex = try? NSRegularExpression(pattern: pattern) else {
        return (text, [])
    }

    let nsText = text as NSString
    let candidates = regex.matches(in: text, range: NSRange(location: 0, length: nsText.length))
    guard !candidates.isEmpty else { return (text, []) }

    // Sorted parameter names per tool, handed to parsePythonCall.
    var toolDefs: [String: [String]] = [:]
    for tool in tools {
        let properties = tool.function.parameters?["properties"]?.value as? [String: Any]
        toolDefs[tool.function.name] = properties.map { $0.keys.sorted() } ?? []
    }

    var toolCalls: [ParsedToolCall] = []
    var callRanges: [NSRange] = []
    var consumedEnd = 0

    for candidate in candidates {
        // A tool name inside the argument list of an already-accepted call is
        // part of that call, not a new one.
        guard candidate.range.location >= consumedEnd else { continue }

        let argsStart = NSMaxRange(candidate.range)
        guard let closeIndex = bareCallClosingParenIndex(in: nsText, from: argsStart) else { continue }

        let name = nsText.substring(with: candidate.range(at: 1))
        let argsStr = nsText
            .substring(with: NSRange(location: argsStart, length: closeIndex - argsStart))
            .trimmingCharacters(in: .whitespacesAndNewlines)

        var args: [String: Any] = [:]
        if !argsStr.isEmpty {
            if let (_, parsed) = parsePythonCall("\(name)(\(argsStr))", toolDefs: toolDefs) as (String, [String: Any])? {
                args = parsed
            }
        }

        let argsJSON = (try? JSONSerialization.data(withJSONObject: args))
            .flatMap { String(data: $0, encoding: .utf8) } ?? "{}"
        // `magnitude` instead of `abs`: abs(Int.min) overflows and traps.
        let hashSuffix = Int(name.hashValue.magnitude % 100_000_000)
        let callId = String(format: "call_%d_%08d", toolCalls.count, hashSuffix)
        toolCalls.append(ParsedToolCall(id: callId, name: name, arguments: argsJSON))

        consumedEnd = closeIndex + 1
        callRanges.append(NSRange(location: candidate.range.location,
                                  length: consumedEnd - candidate.range.location))
    }

    guard !toolCalls.isEmpty else { return (text, []) }

    // Delete from the back so earlier ranges keep their offsets.
    let remaining = NSMutableString(string: text)
    for range in callRanges.reversed() {
        remaining.deleteCharacters(in: range)
    }
    let cleanText = (remaining as String).trimmingCharacters(in: .whitespacesAndNewlines)

    return (cleanText, toolCalls)
}

/// Find the `)` that closes a call whose opening `(` sits just before `start`.
///
/// Scans UTF-16 units, tracking nesting depth and skipping over single- or
/// double-quoted strings (with backslash escapes).
///
/// - Returns: The UTF-16 index of the matching `)`, or `nil` if the text ends first.
private static func bareCallClosingParenIndex(in text: NSString, from start: Int) -> Int? {
    let openParen: unichar = 0x28
    let closeParen: unichar = 0x29
    let doubleQuote: unichar = 0x22
    let singleQuote: unichar = 0x27
    let backslash: unichar = 0x5C

    var depth = 1
    var activeQuote: unichar?
    var index = start

    while index < text.length {
        let unit = text.character(at: index)
        if let quote = activeQuote {
            if unit == backslash {
                index += 1  // skip the escaped character
            } else if unit == quote {
                activeQuote = nil
            }
        } else if unit == doubleQuote || unit == singleQuote {
            activeQuote = unit
        } else if unit == openParen {
            depth += 1
        } else if unit == closeParen {
            depth -= 1
            if depth == 0 { return index }
        }
        index += 1
    }
    return nil
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user