fix: app-close-crash

This commit is contained in:
2026-03-20 12:20:26 +01:00
parent 1f12fac5e2
commit e59be9df1a
4 changed files with 167 additions and 11 deletions

View File

@@ -4,11 +4,26 @@ import MLX
@MainActor @MainActor
final class AppDelegate: NSObject, NSApplicationDelegate { final class AppDelegate: NSObject, NSApplicationDelegate {
var chatViewModel: ChatViewModel? var chatViewModel: ChatViewModel?
private var terminationTask: Task<Void, Never>?
func application(_ application: NSApplication, open urls: [URL]) { func application(_ application: NSApplication, open urls: [URL]) {
ChatDocumentController.shared.enqueueOpenRequests(urls) ChatDocumentController.shared.enqueueOpenRequests(urls)
} }
/// Defers app termination until the chat view model has finished its asynchronous teardown.
///
/// Every call answers `.terminateLater`. The first call starts a main-actor task that awaits
/// `chatViewModel?.prepareForTermination()` and then tells `sender` to go ahead with
/// termination; calls that arrive while that task is still pending start no extra work.
///
/// - Parameter sender: The application asking whether it may terminate.
/// - Returns: `.terminateLater` in every case; the real answer is delivered through
///   `reply(toApplicationShouldTerminate:)`.
func applicationShouldTerminate(_ sender: NSApplication) -> NSApplication.TerminateReply {
    // A teardown task is already pending and will deliver the reply itself.
    guard terminationTask == nil else { return .terminateLater }

    let teardown: Task<Void, Never> = Task { @MainActor [weak self] in
        await self?.chatViewModel?.prepareForTermination()
        // The reply is sent even if the delegate or the view model no longer exists,
        // so the pending termination is never left unanswered.
        sender.reply(toApplicationShouldTerminate: true)
        self?.terminationTask = nil
    }
    terminationTask = teardown

    return .terminateLater
}
func applicationWillTerminate(_ notification: Notification) { func applicationWillTerminate(_ notification: Notification) {
chatViewModel?.autosaveToSandbox() chatViewModel?.autosaveToSandbox()
} }

View File

@@ -15,6 +15,11 @@ final class APIServer {
let matchedTokenCount: Int let matchedTokenCount: Int
} }
/// Bookkeeping entry for one request tracked in `activeRequests`, keyed there by request id.
private struct ActiveRequest {
/// Connection the request is being served on; `beginShutdown()` cancels it.
let connection: NWConnection
/// Cancellation token registered for the request; `beginShutdown()` cancels it.
let cancellation: CancellationToken
}
nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)? nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)?
var isRunning = false var isRunning = false
@@ -24,11 +29,14 @@ final class APIServer {
private var listener: NWListener? private var listener: NWListener?
private var modelManager: ModelManager? private var modelManager: ModelManager?
private var activeRequests: [String: ActiveRequest] = [:]
private var isShuttingDown = false
func start(modelManager: ModelManager, port: Int = 1234) { func start(modelManager: ModelManager, port: Int = 1234) {
guard !isRunning else { return } guard !isRunning else { return }
self.modelManager = modelManager self.modelManager = modelManager
self.port = port self.port = port
self.isShuttingDown = false
do { do {
let params = NWParameters.tcp let params = NWParameters.tcp
@@ -70,11 +78,46 @@ final class APIServer {
} }
/// Stops the server without waiting for in-flight requests to finish.
///
/// Runs `beginShutdown()`, then invalidates the shared token-prefix cache and stops
/// inference-stats sampling. `shutdown(timeoutSeconds:)` performs the same steps but also
/// waits for `activeRequests` to drain first.
func stop() {
    beginShutdown()

    // Cleanup that follows the shutdown request; order kept as-is.
    TokenPrefixCache.shared.invalidateAll()
    inferenceStats.stopSampling()
}
/// Shuts the server down and waits, up to a time limit, for in-flight requests to finish.
///
/// Calls `beginShutdown()`, then polls `activeRequests` every 10 ms until it is empty or
/// `timeoutSeconds` has elapsed. Afterwards it invalidates the shared token-prefix cache and
/// stops inference-stats sampling. If the calling task is cancelled while waiting, the wait
/// ends early and the final cleanup still runs.
///
/// - Parameter timeoutSeconds: Maximum time to wait for active requests to drain.
///   Defaults to 2 seconds; a value of zero or less skips the wait entirely.
func shutdown(timeoutSeconds: TimeInterval = 2.0) async {
    beginShutdown()

    let deadline = Date().addingTimeInterval(timeoutSeconds)
    while !activeRequests.isEmpty && Date() < deadline {
        do {
            try await Task.sleep(nanoseconds: 10_000_000)
        } catch {
            // `Task.sleep` throws only on cancellation. Once the task is cancelled every
            // further sleep would fail again right away, so continuing to loop would
            // busy-spin until the deadline instead of actually waiting. Stop polling.
            break
        }
    }

    TokenPrefixCache.shared.invalidateAll()
    inferenceStats.stopSampling()
}
/// Puts the server into the shutting-down state and cancels all tracked requests.
///
/// Does nothing if a shutdown is already in progress. Otherwise it sets `isShuttingDown`,
/// cancels and drops the listener, clears `isRunning`, and cancels both the cancellation
/// token and the connection of every entry in `activeRequests`. Entries are not removed
/// here; removal happens through `unregisterActiveRequest(requestId:)`.
private func beginShutdown() {
    if isShuttingDown { return }
    isShuttingDown = true

    // Stop accepting new connections.
    listener?.cancel()
    listener = nil
    isRunning = false

    // Signal every in-flight request to stop, then tear down its connection.
    activeRequests.values.forEach { entry in
        entry.cancellation.cancel()
        entry.connection.cancel()
    }
}
/// Records a request as in flight so that `beginShutdown()` can cancel it.
///
/// An existing entry stored under the same `requestId` is replaced.
///
/// - Parameters:
///   - requestId: Key under which the request is tracked in `activeRequests`.
///   - connection: Connection the request is being served on.
///   - cancellation: Token used to cancel the request's work.
private func registerActiveRequest(
    requestId: String,
    connection: NWConnection,
    cancellation: CancellationToken
) {
    let entry = ActiveRequest(connection: connection, cancellation: cancellation)
    activeRequests.updateValue(entry, forKey: requestId)
}
/// Stops tracking the request stored under `requestId`; a no-op if no such entry exists.
///
/// - Parameter requestId: Key the request was registered under.
private func unregisterActiveRequest(requestId: String) {
    activeRequests[requestId] = nil
}
// MARK: - Connection handling // MARK: - Connection handling
@@ -171,6 +214,11 @@ final class APIServer {
// MARK: - POST /v1/chat/completions // MARK: - POST /v1/chat/completions
private func handleChatCompletions(connection: NWConnection, body: Data?) async { private func handleChatCompletions(connection: NWConnection, body: Data?) async {
guard !isShuttingDown else {
sendResponse(connection: connection, status: 503, body: #"{"error":"Server is shutting down"}"#)
return
}
guard let body, let request = try? JSONDecoder().decode(APIChatCompletionRequest.self, from: body) else { guard let body, let request = try? JSONDecoder().decode(APIChatCompletionRequest.self, from: body) else {
sendResponse(connection: connection, status: 400, body: #"{"error":"Invalid request body"}"#) sendResponse(connection: connection, status: 400, body: #"{"error":"Invalid request body"}"#)
return return
@@ -312,6 +360,11 @@ final class APIServer {
LiveCounters.shared.requestPhaseChanged(requestId: requestId, phase: .prefilling) LiveCounters.shared.requestPhaseChanged(requestId: requestId, phase: .prefilling)
let cancellation = CancellationToken() let cancellation = CancellationToken()
registerActiveRequest(requestId: requestId, connection: connection, cancellation: cancellation)
defer {
unregisterActiveRequest(requestId: requestId)
}
let streamHandle: InferenceEngine.StreamHandle let streamHandle: InferenceEngine.StreamHandle
do { do {
streamHandle = try await engine.stream( streamHandle = try await engine.stream(
@@ -353,6 +406,7 @@ final class APIServer {
} }
if let cacheKey, if let cacheKey,
!isShuttingDown,
result.succeeded || result.cancelled { result.succeeded || result.cancelled {
Self.storePromptCache( Self.storePromptCache(
streamHandle.workingCache, streamHandle.workingCache,

View File

@@ -181,15 +181,18 @@ final class ChatViewModel {
} }
/// Stops the generation that is currently in progress, if any.
///
/// Delegates to `cancelActiveGeneration()`, which cancels the generation task, clears the
/// generating state and finalizes a trailing message that is still streaming. The cancelled
/// task handle it returns is not needed here; the helper is `@discardableResult`, so the
/// result is dropped without an explicit `_ =`.
func stop() {
    cancelActiveGeneration()
}
/// Performs the asynchronous teardown that must finish before the app terminates.
///
/// Steps, in order: autosave to the sandbox, cancel any active generation, shut down the
/// API server (waiting for it to finish), wait for the cancelled generation task to
/// complete, reset the session, and unload the model.
func prepareForTermination() async {
    // NOTE(review): this autosave runs before `cancelActiveGeneration()` finalizes a message
    // that is still streaming — confirm a later save captures the finalized state.
    autosaveToSandbox()

    let pendingGeneration = cancelActiveGeneration()
    await apiServer.shutdown()

    // Let the cancelled generation finish unwinding before the session and model go away.
    if let pendingGeneration {
        await pendingGeneration.value
    }

    resetSession()
    modelManager.unloadModel()
}
func attachImage(_ image: NSImage) { func attachImage(_ image: NSImage) {
@@ -564,4 +567,20 @@ final class ChatViewModel {
func stopAPIServer() { func stopAPIServer() {
apiServer.stop() apiServer.stop()
} }
/// Cancels the current generation task and brings the view model back to an idle state.
///
/// Clears `generationTask` and `isGenerating`. If the last message of the conversation is
/// still marked as streaming, it is finalized and the document is marked dirty if needed.
///
/// - Returns: The task that was cancelled, so callers can await its completion, or `nil`
///   when no generation task was stored.
@discardableResult
private func cancelActiveGeneration() -> Task<Void, Never>? {
    let cancelledTask = generationTask
    cancelledTask?.cancel()
    generationTask = nil
    isGenerating = false

    // Only a trailing message that is still streaming needs to be closed out.
    guard let lastIndex = conversation.messages.indices.last,
          conversation.messages[lastIndex].isStreaming else {
        return cancelledTask
    }

    conversation.finalizeMessage(at: lastIndex)
    markDirtyIfNeeded()
    return cancelledTask
}
} }

View File

@@ -998,6 +998,74 @@ final class APIServerRewriteTests: XCTestCase {
XCTAssertLessThan(elapsed, 0.2) XCTAssertLessThan(elapsed, 0.2)
} }
/// Verifies that `shutdown(timeoutSeconds:)` cancels a streaming chat completion that is
/// still in flight and leaves no active requests behind.
///
/// A client starts a streaming request and, after receiving its first content chunk,
/// stalls for 30 seconds so the stream stays open. The server is then shut down with a
/// 2-second limit, after which the live counters must report zero active requests and the
/// server must no longer be running.
func testServerShutdownCancelsInFlightStreamAndDrainsActiveRequests() async throws {
    let harness = try await makeHarness()

    let request = APIChatCompletionRequest(
        model: "gemma",
        messages: [
            APIChatMessage(role: "user", content: .text("Count from one to fifty with commas, using many tokens."), name: nil, tool_calls: nil, tool_call_id: nil)
        ],
        temperature: 0,
        top_p: 1,
        max_tokens: 128,
        stream: true,
        stop: nil,
        tools: nil,
        tool_choice: nil,
        frequency_penalty: nil,
        presence_penalty: nil,
        n: nil
    )

    let url = try XCTUnwrap(URL(string: "http://127.0.0.1:\(harness.port)/v1/chat/completions"))
    var urlRequest = URLRequest(url: url)
    urlRequest.httpMethod = "POST"
    urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
    urlRequest.httpBody = try JSONEncoder().encode(request)

    let observer = StreamCancellationObserver()
    let session = URLSession(configuration: .ephemeral)
    let task = Task {
        let (bytes, response) = try await session.bytes(for: urlRequest)
        let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
        XCTAssertEqual(httpResponse.statusCode, 200)

        for try await line in bytes.lines {
            guard line.hasPrefix("data: ") else { continue }
            let payload = String(line.dropFirst(6))
            if payload == "[DONE]" {
                break
            }
            guard let data = payload.data(using: .utf8) else { continue }
            let chunk = try JSONDecoder().decode(APIChatCompletionChunk.self, from: data)
            if let deltaContent = chunk.choices.first?.delta.content, !deltaContent.isEmpty {
                await observer.markFirstContentSeen()
                // Stall the reader so the stream is still in flight when the server shuts down.
                try await Task.sleep(nanoseconds: 30_000_000_000)
            }
        }
    }
    // Safety net: if anything below throws (for example the `waitUntil` timeout), the
    // client session and the stalled reader task must not outlive this test.
    // Both calls are safe to repeat.
    defer {
        session.invalidateAndCancel()
        task.cancel()
    }

    try await waitUntil(timeoutSeconds: 10) {
        await observer.hasSeenFirstContent
    }

    await harness.server.shutdown(timeoutSeconds: 2.0)

    let liveSnapshot = LiveCounters.shared.snapshot()
    XCTAssertEqual(liveSnapshot.activeRequests, 0)
    let isRunning = await MainActor.run { harness.server.isRunning }
    XCTAssertFalse(isRunning)

    // Cancel explicitly before awaiting the reader, otherwise the await below would sit
    // through the 30-second stall.
    session.invalidateAndCancel()
    task.cancel()
    _ = try? await task.value

    await MainActor.run {
        harness.modelManager.unloadModel()
    }
    TokenPrefixCache.shared.reset()
}
func testRepeatedStreamingDisconnectsDoNotBreakSubsequentGeneration() async throws { func testRepeatedStreamingDisconnectsDoNotBreakSubsequentGeneration() async throws {
let harness = try await makeHarness() let harness = try await makeHarness()
defer { harness.stop() } defer { harness.stop() }