From e59be9df1a94ff1411fdc070494cbf62e536b074 Mon Sep 17 00:00:00 2001
From: Chili Palmer
Date: Fri, 20 Mar 2026 12:20:26 +0100
Subject: [PATCH] fix: app-close-crash

---
Review notes (this section is ignored by `git am` / `git apply`):
* Line breaks restored. Indentation and blank-line placement of unchanged
  context lines are reconstructed; apply with `git apply --ignore-whitespace`
  if the context does not match byte-for-byte.
* `Task<Void, Never>?` restored where the generic arguments had been stripped.
* APIServer.shutdown: the drain loop now exits when its sleep is cancelled
  instead of retrying a sleep that fails immediately until the deadline.
* ChatViewModel.stop: dropped the `_ =` made redundant by @discardableResult.
* Doc comments added on the new entry points.
* Hunk headers and the diffstat are recounted; `index` hashes are still those
  of the original commit.
* NOTE(review): prepareForTermination awaits the cancelled generation task
  without a timeout - confirm that task always finishes promptly once cancelled.

 MLXServer/MLXServerApp.swift                  | 17 ++++
 MLXServer/Server/APIServer.swift              | 69 ++++++++++++++++++-
 MLXServer/ViewModels/ChatViewModel.swift      | 39 ++++++++--
 .../Server/APIServerRewriteTests.swift        | 68 ++++++++++++++++++
 4 files changed, 182 insertions(+), 11 deletions(-)

diff --git a/MLXServer/MLXServerApp.swift b/MLXServer/MLXServerApp.swift
index 54b351f..cb37ff4 100644
--- a/MLXServer/MLXServerApp.swift
+++ b/MLXServer/MLXServerApp.swift
@@ -4,11 +4,28 @@ import MLX
 @MainActor
 final class AppDelegate: NSObject, NSApplicationDelegate {
     var chatViewModel: ChatViewModel?
+    private var terminationTask: Task<Void, Never>?
 
     func application(_ application: NSApplication, open urls: [URL]) {
         ChatDocumentController.shared.enqueueOpenRequests(urls)
     }
 
+    /// Defers the quit until `prepareForTermination()` has finished, then confirms it.
+    /// A quit request arriving while that task is pending only returns `.terminateLater` again.
+    func applicationShouldTerminate(_ sender: NSApplication) -> NSApplication.TerminateReply {
+        if terminationTask != nil {
+            return .terminateLater
+        }
+
+        terminationTask = Task { @MainActor [weak self] in
+            await self?.chatViewModel?.prepareForTermination()
+            sender.reply(toApplicationShouldTerminate: true)
+            self?.terminationTask = nil
+        }
+
+        return .terminateLater
+    }
+
     func applicationWillTerminate(_ notification: Notification) {
         chatViewModel?.autosaveToSandbox()
     }
diff --git a/MLXServer/Server/APIServer.swift b/MLXServer/Server/APIServer.swift
index c0d299b..7123c92 100644
--- a/MLXServer/Server/APIServer.swift
+++ b/MLXServer/Server/APIServer.swift
@@ -15,6 +15,11 @@ final class APIServer {
         let matchedTokenCount: Int
     }
 
+    private struct ActiveRequest {
+        let connection: NWConnection
+        let cancellation: CancellationToken
+    }
+
     nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)?
 
     var isRunning = false
@@ -24,11 +29,14 @@ final class APIServer {
 
     private var listener: NWListener?
     private var modelManager: ModelManager?
+    private var activeRequests: [String: ActiveRequest] = [:]
+    private var isShuttingDown = false
 
     func start(modelManager: ModelManager, port: Int = 1234) {
         guard !isRunning else { return }
         self.modelManager = modelManager
         self.port = port
+        self.isShuttingDown = false
 
         do {
             let params = NWParameters.tcp
@@ -70,11 +78,55 @@ final class APIServer {
     }
 
     func stop() {
+        beginShutdown()
+        TokenPrefixCache.shared.invalidateAll()
+        inferenceStats.stopSampling()
+    }
+
+    /// Stops accepting work, cancels registered requests, and waits up to
+    /// `timeoutSeconds` for them to unregister before invalidating the prefix cache.
+    func shutdown(timeoutSeconds: TimeInterval = 2.0) async {
+        beginShutdown()
+
+        let deadline = Date().addingTimeInterval(timeoutSeconds)
+        while !activeRequests.isEmpty && Date() < deadline {
+            do {
+                try await Task.sleep(nanoseconds: 10_000_000)
+            } catch {
+                // Sleep only throws when this task is cancelled; stop polling rather than spin.
+                break
+            }
+        }
+
+        TokenPrefixCache.shared.invalidateAll()
+        inferenceStats.stopSampling()
+    }
+
+    /// Closes the listener and cancels every registered request.
+    /// Does nothing if shutdown has already begun.
+    private func beginShutdown() {
+        guard !isShuttingDown else { return }
+        isShuttingDown = true
         listener?.cancel()
         listener = nil
         isRunning = false
-        TokenPrefixCache.shared.invalidateAll()
-        inferenceStats.stopSampling()
+
+        for activeRequest in activeRequests.values {
+            activeRequest.cancellation.cancel()
+            activeRequest.connection.cancel()
+        }
+    }
+
+    private func registerActiveRequest(
+        requestId: String,
+        connection: NWConnection,
+        cancellation: CancellationToken
+    ) {
+        activeRequests[requestId] = ActiveRequest(connection: connection, cancellation: cancellation)
+    }
+
+    private func unregisterActiveRequest(requestId: String) {
+        activeRequests.removeValue(forKey: requestId)
     }
 
     // MARK: - Connection handling
@@ -171,6 +223,11 @@ final class APIServer {
     // MARK: - POST /v1/chat/completions
 
     private func handleChatCompletions(connection: NWConnection, body: Data?) async {
+        guard !isShuttingDown else {
+            sendResponse(connection: connection, status: 503, body: #"{"error":"Server is shutting down"}"#)
+            return
+        }
+
         guard let body, let request = try? JSONDecoder().decode(APIChatCompletionRequest.self, from: body) else {
             sendResponse(connection: connection, status: 400, body: #"{"error":"Invalid request body"}"#)
             return
@@ -312,6 +369,11 @@ final class APIServer {
         LiveCounters.shared.requestPhaseChanged(requestId: requestId, phase: .prefilling)
 
         let cancellation = CancellationToken()
+        registerActiveRequest(requestId: requestId, connection: connection, cancellation: cancellation)
+        defer {
+            unregisterActiveRequest(requestId: requestId)
+        }
+
         let streamHandle: InferenceEngine.StreamHandle
         do {
             streamHandle = try await engine.stream(
@@ -352,7 +414,8 @@ final class APIServer {
             )
         }
 
-        if let cacheKey,
+        if let cacheKey,
+           !isShuttingDown,
            result.succeeded || result.cancelled {
             Self.storePromptCache(
                 streamHandle.workingCache,
diff --git a/MLXServer/ViewModels/ChatViewModel.swift b/MLXServer/ViewModels/ChatViewModel.swift
index 41fe770..bfeefed 100644
--- a/MLXServer/ViewModels/ChatViewModel.swift
+++ b/MLXServer/ViewModels/ChatViewModel.swift
@@ -181,15 +181,20 @@ final class ChatViewModel {
     }
 
     func stop() {
-        generationTask?.cancel()
-        generationTask = nil
-        isGenerating = false
+        cancelActiveGeneration()
+    }
 
-        if let last = conversation.messages.indices.last,
-           conversation.messages[last].isStreaming {
-            conversation.finalizeMessage(at: last)
-            markDirtyIfNeeded()
-        }
+    /// Autosaves, cancels the active generation, shuts the API server down and waits for
+    /// the cancelled generation to finish before resetting the session and unloading the model.
+    func prepareForTermination() async {
+        autosaveToSandbox()
+
+        let activeGeneration = cancelActiveGeneration()
+        await apiServer.shutdown()
+        await activeGeneration?.value
+
+        resetSession()
+        modelManager.unloadModel()
     }
 
     func attachImage(_ image: NSImage) {
@@ -564,4 +569,22 @@ final class ChatViewModel {
     func stopAPIServer() {
         apiServer.stop()
     }
+
+    /// Cancels the current generation task (if any), finalizes a still-streaming last
+    /// message, and returns the cancelled task so callers can await its completion.
+    @discardableResult
+    private func cancelActiveGeneration() -> Task<Void, Never>? {
+        let activeGeneration = generationTask
+        activeGeneration?.cancel()
+        generationTask = nil
+        isGenerating = false
+
+        if let last = conversation.messages.indices.last,
+           conversation.messages[last].isStreaming {
+            conversation.finalizeMessage(at: last)
+            markDirtyIfNeeded()
+        }
+
+        return activeGeneration
+    }
 }
diff --git a/MLXServerTests/Server/APIServerRewriteTests.swift b/MLXServerTests/Server/APIServerRewriteTests.swift
index 91a49e0..822e996 100644
--- a/MLXServerTests/Server/APIServerRewriteTests.swift
+++ b/MLXServerTests/Server/APIServerRewriteTests.swift
@@ -998,6 +998,74 @@ final class APIServerRewriteTests: XCTestCase {
         XCTAssertLessThan(elapsed, 0.2)
     }
 
+    func testServerShutdownCancelsInFlightStreamAndDrainsActiveRequests() async throws {
+        let harness = try await makeHarness()
+
+        let request = APIChatCompletionRequest(
+            model: "gemma",
+            messages: [
+                APIChatMessage(role: "user", content: .text("Count from one to fifty with commas, using many tokens."), name: nil, tool_calls: nil, tool_call_id: nil)
+            ],
+            temperature: 0,
+            top_p: 1,
+            max_tokens: 128,
+            stream: true,
+            stop: nil,
+            tools: nil,
+            tool_choice: nil,
+            frequency_penalty: nil,
+            presence_penalty: nil,
+            n: nil
+        )
+
+        let url = URL(string: "http://127.0.0.1:\(harness.port)/v1/chat/completions")!
+        var urlRequest = URLRequest(url: url)
+        urlRequest.httpMethod = "POST"
+        urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
+        urlRequest.httpBody = try JSONEncoder().encode(request)
+
+        let observer = StreamCancellationObserver()
+        let session = URLSession(configuration: .ephemeral)
+        let task = Task {
+            let (bytes, response) = try await session.bytes(for: urlRequest)
+            let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
+            XCTAssertEqual(httpResponse.statusCode, 200)
+
+            for try await line in bytes.lines {
+                guard line.hasPrefix("data: ") else { continue }
+                let payload = String(line.dropFirst(6))
+                if payload == "[DONE]" {
+                    break
+                }
+                guard let data = payload.data(using: .utf8) else { continue }
+                let chunk = try JSONDecoder().decode(APIChatCompletionChunk.self, from: data)
+                if let deltaContent = chunk.choices.first?.delta.content, !deltaContent.isEmpty {
+                    await observer.markFirstContentSeen()
+                    try await Task.sleep(nanoseconds: 30_000_000_000)
+                }
+            }
+        }
+
+        try await waitUntil(timeoutSeconds: 10) {
+            await observer.hasSeenFirstContent
+        }
+
+        await harness.server.shutdown(timeoutSeconds: 2.0)
+
+        let liveSnapshot = LiveCounters.shared.snapshot()
+        XCTAssertEqual(liveSnapshot.activeRequests, 0)
+        let isRunning = await MainActor.run { harness.server.isRunning }
+        XCTAssertFalse(isRunning)
+
+        session.invalidateAndCancel()
+        task.cancel()
+        _ = try? await task.value
+        await MainActor.run {
+            harness.modelManager.unloadModel()
+        }
+        TokenPrefixCache.shared.reset()
+    }
+
     func testRepeatedStreamingDisconnectsDoNotBreakSubsequentGeneration() async throws {
         let harness = try await makeHarness()
         defer { harness.stop() }