fix: app-close-crash
This commit is contained in:
@@ -4,11 +4,26 @@ import MLX
|
||||
@MainActor
|
||||
final class AppDelegate: NSObject, NSApplicationDelegate {
|
||||
var chatViewModel: ChatViewModel?
|
||||
private var terminationTask: Task<Void, Never>?
|
||||
|
||||
func application(_ application: NSApplication, open urls: [URL]) {
|
||||
ChatDocumentController.shared.enqueueOpenRequests(urls)
|
||||
}
|
||||
|
||||
func applicationShouldTerminate(_ sender: NSApplication) -> NSApplication.TerminateReply {
|
||||
if terminationTask != nil {
|
||||
return .terminateLater
|
||||
}
|
||||
|
||||
terminationTask = Task { @MainActor [weak self] in
|
||||
await self?.chatViewModel?.prepareForTermination()
|
||||
sender.reply(toApplicationShouldTerminate: true)
|
||||
self?.terminationTask = nil
|
||||
}
|
||||
|
||||
return .terminateLater
|
||||
}
|
||||
|
||||
func applicationWillTerminate(_ notification: Notification) {
|
||||
chatViewModel?.autosaveToSandbox()
|
||||
}
|
||||
|
||||
@@ -15,6 +15,11 @@ final class APIServer {
|
||||
let matchedTokenCount: Int
|
||||
}
|
||||
|
||||
private struct ActiveRequest {
|
||||
let connection: NWConnection
|
||||
let cancellation: CancellationToken
|
||||
}
|
||||
|
||||
nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)?
|
||||
|
||||
var isRunning = false
|
||||
@@ -24,11 +29,14 @@ final class APIServer {
|
||||
|
||||
private var listener: NWListener?
|
||||
private var modelManager: ModelManager?
|
||||
private var activeRequests: [String: ActiveRequest] = [:]
|
||||
private var isShuttingDown = false
|
||||
|
||||
func start(modelManager: ModelManager, port: Int = 1234) {
|
||||
guard !isRunning else { return }
|
||||
self.modelManager = modelManager
|
||||
self.port = port
|
||||
self.isShuttingDown = false
|
||||
|
||||
do {
|
||||
let params = NWParameters.tcp
|
||||
@@ -70,11 +78,46 @@ final class APIServer {
|
||||
}
|
||||
|
||||
func stop() {
|
||||
beginShutdown()
|
||||
TokenPrefixCache.shared.invalidateAll()
|
||||
inferenceStats.stopSampling()
|
||||
}
|
||||
|
||||
func shutdown(timeoutSeconds: TimeInterval = 2.0) async {
|
||||
beginShutdown()
|
||||
|
||||
let deadline = Date().addingTimeInterval(timeoutSeconds)
|
||||
while !activeRequests.isEmpty && Date() < deadline {
|
||||
try? await Task.sleep(nanoseconds: 10_000_000)
|
||||
}
|
||||
|
||||
TokenPrefixCache.shared.invalidateAll()
|
||||
inferenceStats.stopSampling()
|
||||
}
|
||||
|
||||
private func beginShutdown() {
|
||||
guard !isShuttingDown else { return }
|
||||
isShuttingDown = true
|
||||
listener?.cancel()
|
||||
listener = nil
|
||||
isRunning = false
|
||||
TokenPrefixCache.shared.invalidateAll()
|
||||
inferenceStats.stopSampling()
|
||||
|
||||
for activeRequest in activeRequests.values {
|
||||
activeRequest.cancellation.cancel()
|
||||
activeRequest.connection.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
private func registerActiveRequest(
|
||||
requestId: String,
|
||||
connection: NWConnection,
|
||||
cancellation: CancellationToken
|
||||
) {
|
||||
activeRequests[requestId] = ActiveRequest(connection: connection, cancellation: cancellation)
|
||||
}
|
||||
|
||||
private func unregisterActiveRequest(requestId: String) {
|
||||
activeRequests.removeValue(forKey: requestId)
|
||||
}
|
||||
|
||||
// MARK: - Connection handling
|
||||
@@ -171,6 +214,11 @@ final class APIServer {
|
||||
// MARK: - POST /v1/chat/completions
|
||||
|
||||
private func handleChatCompletions(connection: NWConnection, body: Data?) async {
|
||||
guard !isShuttingDown else {
|
||||
sendResponse(connection: connection, status: 503, body: #"{"error":"Server is shutting down"}"#)
|
||||
return
|
||||
}
|
||||
|
||||
guard let body, let request = try? JSONDecoder().decode(APIChatCompletionRequest.self, from: body) else {
|
||||
sendResponse(connection: connection, status: 400, body: #"{"error":"Invalid request body"}"#)
|
||||
return
|
||||
@@ -312,6 +360,11 @@ final class APIServer {
|
||||
LiveCounters.shared.requestPhaseChanged(requestId: requestId, phase: .prefilling)
|
||||
|
||||
let cancellation = CancellationToken()
|
||||
registerActiveRequest(requestId: requestId, connection: connection, cancellation: cancellation)
|
||||
defer {
|
||||
unregisterActiveRequest(requestId: requestId)
|
||||
}
|
||||
|
||||
let streamHandle: InferenceEngine.StreamHandle
|
||||
do {
|
||||
streamHandle = try await engine.stream(
|
||||
@@ -353,6 +406,7 @@ final class APIServer {
|
||||
}
|
||||
|
||||
if let cacheKey,
|
||||
!isShuttingDown,
|
||||
result.succeeded || result.cancelled {
|
||||
Self.storePromptCache(
|
||||
streamHandle.workingCache,
|
||||
|
||||
@@ -181,15 +181,18 @@ final class ChatViewModel {
|
||||
}
|
||||
|
||||
func stop() {
|
||||
generationTask?.cancel()
|
||||
generationTask = nil
|
||||
isGenerating = false
|
||||
|
||||
if let last = conversation.messages.indices.last,
|
||||
conversation.messages[last].isStreaming {
|
||||
conversation.finalizeMessage(at: last)
|
||||
markDirtyIfNeeded()
|
||||
_ = cancelActiveGeneration()
|
||||
}
|
||||
|
||||
func prepareForTermination() async {
|
||||
autosaveToSandbox()
|
||||
|
||||
let activeGeneration = cancelActiveGeneration()
|
||||
await apiServer.shutdown()
|
||||
await activeGeneration?.value
|
||||
|
||||
resetSession()
|
||||
modelManager.unloadModel()
|
||||
}
|
||||
|
||||
func attachImage(_ image: NSImage) {
|
||||
@@ -564,4 +567,20 @@ final class ChatViewModel {
|
||||
func stopAPIServer() {
|
||||
apiServer.stop()
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
private func cancelActiveGeneration() -> Task<Void, Never>? {
|
||||
let activeGeneration = generationTask
|
||||
activeGeneration?.cancel()
|
||||
generationTask = nil
|
||||
isGenerating = false
|
||||
|
||||
if let last = conversation.messages.indices.last,
|
||||
conversation.messages[last].isStreaming {
|
||||
conversation.finalizeMessage(at: last)
|
||||
markDirtyIfNeeded()
|
||||
}
|
||||
|
||||
return activeGeneration
|
||||
}
|
||||
}
|
||||
|
||||
@@ -998,6 +998,74 @@ final class APIServerRewriteTests: XCTestCase {
|
||||
XCTAssertLessThan(elapsed, 0.2)
|
||||
}
|
||||
|
||||
func testServerShutdownCancelsInFlightStreamAndDrainsActiveRequests() async throws {
|
||||
let harness = try await makeHarness()
|
||||
|
||||
let request = APIChatCompletionRequest(
|
||||
model: "gemma",
|
||||
messages: [
|
||||
APIChatMessage(role: "user", content: .text("Count from one to fifty with commas, using many tokens."), name: nil, tool_calls: nil, tool_call_id: nil)
|
||||
],
|
||||
temperature: 0,
|
||||
top_p: 1,
|
||||
max_tokens: 128,
|
||||
stream: true,
|
||||
stop: nil,
|
||||
tools: nil,
|
||||
tool_choice: nil,
|
||||
frequency_penalty: nil,
|
||||
presence_penalty: nil,
|
||||
n: nil
|
||||
)
|
||||
|
||||
let url = URL(string: "http://127.0.0.1:\(harness.port)/v1/chat/completions")!
|
||||
var urlRequest = URLRequest(url: url)
|
||||
urlRequest.httpMethod = "POST"
|
||||
urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
urlRequest.httpBody = try JSONEncoder().encode(request)
|
||||
|
||||
let observer = StreamCancellationObserver()
|
||||
let session = URLSession(configuration: .ephemeral)
|
||||
let task = Task {
|
||||
let (bytes, response) = try await session.bytes(for: urlRequest)
|
||||
let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
|
||||
XCTAssertEqual(httpResponse.statusCode, 200)
|
||||
|
||||
for try await line in bytes.lines {
|
||||
guard line.hasPrefix("data: ") else { continue }
|
||||
let payload = String(line.dropFirst(6))
|
||||
if payload == "[DONE]" {
|
||||
break
|
||||
}
|
||||
guard let data = payload.data(using: .utf8) else { continue }
|
||||
let chunk = try JSONDecoder().decode(APIChatCompletionChunk.self, from: data)
|
||||
if let deltaContent = chunk.choices.first?.delta.content, !deltaContent.isEmpty {
|
||||
await observer.markFirstContentSeen()
|
||||
try await Task.sleep(nanoseconds: 30_000_000_000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try await waitUntil(timeoutSeconds: 10) {
|
||||
await observer.hasSeenFirstContent
|
||||
}
|
||||
|
||||
await harness.server.shutdown(timeoutSeconds: 2.0)
|
||||
|
||||
let liveSnapshot = LiveCounters.shared.snapshot()
|
||||
XCTAssertEqual(liveSnapshot.activeRequests, 0)
|
||||
let isRunning = await MainActor.run { harness.server.isRunning }
|
||||
XCTAssertFalse(isRunning)
|
||||
|
||||
session.invalidateAndCancel()
|
||||
task.cancel()
|
||||
_ = try? await task.value
|
||||
await MainActor.run {
|
||||
harness.modelManager.unloadModel()
|
||||
}
|
||||
TokenPrefixCache.shared.reset()
|
||||
}
|
||||
|
||||
func testRepeatedStreamingDisconnectsDoNotBreakSubsequentGeneration() async throws {
|
||||
let harness = try await makeHarness()
|
||||
defer { harness.stop() }
|
||||
|
||||
Reference in New Issue
Block a user