fix: app-close-crash
This commit is contained in:
@@ -4,11 +4,26 @@ import MLX
|
|||||||
@MainActor
|
@MainActor
|
||||||
final class AppDelegate: NSObject, NSApplicationDelegate {
|
final class AppDelegate: NSObject, NSApplicationDelegate {
|
||||||
var chatViewModel: ChatViewModel?
|
var chatViewModel: ChatViewModel?
|
||||||
|
private var terminationTask: Task<Void, Never>?
|
||||||
|
|
||||||
func application(_ application: NSApplication, open urls: [URL]) {
|
func application(_ application: NSApplication, open urls: [URL]) {
|
||||||
ChatDocumentController.shared.enqueueOpenRequests(urls)
|
ChatDocumentController.shared.enqueueOpenRequests(urls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applicationShouldTerminate(_ sender: NSApplication) -> NSApplication.TerminateReply {
|
||||||
|
if terminationTask != nil {
|
||||||
|
return .terminateLater
|
||||||
|
}
|
||||||
|
|
||||||
|
terminationTask = Task { @MainActor [weak self] in
|
||||||
|
await self?.chatViewModel?.prepareForTermination()
|
||||||
|
sender.reply(toApplicationShouldTerminate: true)
|
||||||
|
self?.terminationTask = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return .terminateLater
|
||||||
|
}
|
||||||
|
|
||||||
func applicationWillTerminate(_ notification: Notification) {
|
func applicationWillTerminate(_ notification: Notification) {
|
||||||
chatViewModel?.autosaveToSandbox()
|
chatViewModel?.autosaveToSandbox()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,11 @@ final class APIServer {
|
|||||||
let matchedTokenCount: Int
|
let matchedTokenCount: Int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private struct ActiveRequest {
|
||||||
|
let connection: NWConnection
|
||||||
|
let cancellation: CancellationToken
|
||||||
|
}
|
||||||
|
|
||||||
nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)?
|
nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)?
|
||||||
|
|
||||||
var isRunning = false
|
var isRunning = false
|
||||||
@@ -24,11 +29,14 @@ final class APIServer {
|
|||||||
|
|
||||||
private var listener: NWListener?
|
private var listener: NWListener?
|
||||||
private var modelManager: ModelManager?
|
private var modelManager: ModelManager?
|
||||||
|
private var activeRequests: [String: ActiveRequest] = [:]
|
||||||
|
private var isShuttingDown = false
|
||||||
|
|
||||||
func start(modelManager: ModelManager, port: Int = 1234) {
|
func start(modelManager: ModelManager, port: Int = 1234) {
|
||||||
guard !isRunning else { return }
|
guard !isRunning else { return }
|
||||||
self.modelManager = modelManager
|
self.modelManager = modelManager
|
||||||
self.port = port
|
self.port = port
|
||||||
|
self.isShuttingDown = false
|
||||||
|
|
||||||
do {
|
do {
|
||||||
let params = NWParameters.tcp
|
let params = NWParameters.tcp
|
||||||
@@ -70,11 +78,46 @@ final class APIServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func stop() {
|
func stop() {
|
||||||
|
beginShutdown()
|
||||||
|
TokenPrefixCache.shared.invalidateAll()
|
||||||
|
inferenceStats.stopSampling()
|
||||||
|
}
|
||||||
|
|
||||||
|
func shutdown(timeoutSeconds: TimeInterval = 2.0) async {
|
||||||
|
beginShutdown()
|
||||||
|
|
||||||
|
let deadline = Date().addingTimeInterval(timeoutSeconds)
|
||||||
|
while !activeRequests.isEmpty && Date() < deadline {
|
||||||
|
try? await Task.sleep(nanoseconds: 10_000_000)
|
||||||
|
}
|
||||||
|
|
||||||
|
TokenPrefixCache.shared.invalidateAll()
|
||||||
|
inferenceStats.stopSampling()
|
||||||
|
}
|
||||||
|
|
||||||
|
private func beginShutdown() {
|
||||||
|
guard !isShuttingDown else { return }
|
||||||
|
isShuttingDown = true
|
||||||
listener?.cancel()
|
listener?.cancel()
|
||||||
listener = nil
|
listener = nil
|
||||||
isRunning = false
|
isRunning = false
|
||||||
TokenPrefixCache.shared.invalidateAll()
|
|
||||||
inferenceStats.stopSampling()
|
for activeRequest in activeRequests.values {
|
||||||
|
activeRequest.cancellation.cancel()
|
||||||
|
activeRequest.connection.cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func registerActiveRequest(
|
||||||
|
requestId: String,
|
||||||
|
connection: NWConnection,
|
||||||
|
cancellation: CancellationToken
|
||||||
|
) {
|
||||||
|
activeRequests[requestId] = ActiveRequest(connection: connection, cancellation: cancellation)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func unregisterActiveRequest(requestId: String) {
|
||||||
|
activeRequests.removeValue(forKey: requestId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - Connection handling
|
// MARK: - Connection handling
|
||||||
@@ -171,6 +214,11 @@ final class APIServer {
|
|||||||
// MARK: - POST /v1/chat/completions
|
// MARK: - POST /v1/chat/completions
|
||||||
|
|
||||||
private func handleChatCompletions(connection: NWConnection, body: Data?) async {
|
private func handleChatCompletions(connection: NWConnection, body: Data?) async {
|
||||||
|
guard !isShuttingDown else {
|
||||||
|
sendResponse(connection: connection, status: 503, body: #"{"error":"Server is shutting down"}"#)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
guard let body, let request = try? JSONDecoder().decode(APIChatCompletionRequest.self, from: body) else {
|
guard let body, let request = try? JSONDecoder().decode(APIChatCompletionRequest.self, from: body) else {
|
||||||
sendResponse(connection: connection, status: 400, body: #"{"error":"Invalid request body"}"#)
|
sendResponse(connection: connection, status: 400, body: #"{"error":"Invalid request body"}"#)
|
||||||
return
|
return
|
||||||
@@ -312,6 +360,11 @@ final class APIServer {
|
|||||||
LiveCounters.shared.requestPhaseChanged(requestId: requestId, phase: .prefilling)
|
LiveCounters.shared.requestPhaseChanged(requestId: requestId, phase: .prefilling)
|
||||||
|
|
||||||
let cancellation = CancellationToken()
|
let cancellation = CancellationToken()
|
||||||
|
registerActiveRequest(requestId: requestId, connection: connection, cancellation: cancellation)
|
||||||
|
defer {
|
||||||
|
unregisterActiveRequest(requestId: requestId)
|
||||||
|
}
|
||||||
|
|
||||||
let streamHandle: InferenceEngine.StreamHandle
|
let streamHandle: InferenceEngine.StreamHandle
|
||||||
do {
|
do {
|
||||||
streamHandle = try await engine.stream(
|
streamHandle = try await engine.stream(
|
||||||
@@ -353,6 +406,7 @@ final class APIServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let cacheKey,
|
if let cacheKey,
|
||||||
|
!isShuttingDown,
|
||||||
result.succeeded || result.cancelled {
|
result.succeeded || result.cancelled {
|
||||||
Self.storePromptCache(
|
Self.storePromptCache(
|
||||||
streamHandle.workingCache,
|
streamHandle.workingCache,
|
||||||
|
|||||||
@@ -181,15 +181,18 @@ final class ChatViewModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func stop() {
|
func stop() {
|
||||||
generationTask?.cancel()
|
_ = cancelActiveGeneration()
|
||||||
generationTask = nil
|
|
||||||
isGenerating = false
|
|
||||||
|
|
||||||
if let last = conversation.messages.indices.last,
|
|
||||||
conversation.messages[last].isStreaming {
|
|
||||||
conversation.finalizeMessage(at: last)
|
|
||||||
markDirtyIfNeeded()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prepareForTermination() async {
|
||||||
|
autosaveToSandbox()
|
||||||
|
|
||||||
|
let activeGeneration = cancelActiveGeneration()
|
||||||
|
await apiServer.shutdown()
|
||||||
|
await activeGeneration?.value
|
||||||
|
|
||||||
|
resetSession()
|
||||||
|
modelManager.unloadModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func attachImage(_ image: NSImage) {
|
func attachImage(_ image: NSImage) {
|
||||||
@@ -564,4 +567,20 @@ final class ChatViewModel {
|
|||||||
func stopAPIServer() {
|
func stopAPIServer() {
|
||||||
apiServer.stop()
|
apiServer.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
private func cancelActiveGeneration() -> Task<Void, Never>? {
|
||||||
|
let activeGeneration = generationTask
|
||||||
|
activeGeneration?.cancel()
|
||||||
|
generationTask = nil
|
||||||
|
isGenerating = false
|
||||||
|
|
||||||
|
if let last = conversation.messages.indices.last,
|
||||||
|
conversation.messages[last].isStreaming {
|
||||||
|
conversation.finalizeMessage(at: last)
|
||||||
|
markDirtyIfNeeded()
|
||||||
|
}
|
||||||
|
|
||||||
|
return activeGeneration
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -998,6 +998,74 @@ final class APIServerRewriteTests: XCTestCase {
|
|||||||
XCTAssertLessThan(elapsed, 0.2)
|
XCTAssertLessThan(elapsed, 0.2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testServerShutdownCancelsInFlightStreamAndDrainsActiveRequests() async throws {
|
||||||
|
let harness = try await makeHarness()
|
||||||
|
|
||||||
|
let request = APIChatCompletionRequest(
|
||||||
|
model: "gemma",
|
||||||
|
messages: [
|
||||||
|
APIChatMessage(role: "user", content: .text("Count from one to fifty with commas, using many tokens."), name: nil, tool_calls: nil, tool_call_id: nil)
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
top_p: 1,
|
||||||
|
max_tokens: 128,
|
||||||
|
stream: true,
|
||||||
|
stop: nil,
|
||||||
|
tools: nil,
|
||||||
|
tool_choice: nil,
|
||||||
|
frequency_penalty: nil,
|
||||||
|
presence_penalty: nil,
|
||||||
|
n: nil
|
||||||
|
)
|
||||||
|
|
||||||
|
let url = URL(string: "http://127.0.0.1:\(harness.port)/v1/chat/completions")!
|
||||||
|
var urlRequest = URLRequest(url: url)
|
||||||
|
urlRequest.httpMethod = "POST"
|
||||||
|
urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||||
|
urlRequest.httpBody = try JSONEncoder().encode(request)
|
||||||
|
|
||||||
|
let observer = StreamCancellationObserver()
|
||||||
|
let session = URLSession(configuration: .ephemeral)
|
||||||
|
let task = Task {
|
||||||
|
let (bytes, response) = try await session.bytes(for: urlRequest)
|
||||||
|
let httpResponse = try XCTUnwrap(response as? HTTPURLResponse)
|
||||||
|
XCTAssertEqual(httpResponse.statusCode, 200)
|
||||||
|
|
||||||
|
for try await line in bytes.lines {
|
||||||
|
guard line.hasPrefix("data: ") else { continue }
|
||||||
|
let payload = String(line.dropFirst(6))
|
||||||
|
if payload == "[DONE]" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
guard let data = payload.data(using: .utf8) else { continue }
|
||||||
|
let chunk = try JSONDecoder().decode(APIChatCompletionChunk.self, from: data)
|
||||||
|
if let deltaContent = chunk.choices.first?.delta.content, !deltaContent.isEmpty {
|
||||||
|
await observer.markFirstContentSeen()
|
||||||
|
try await Task.sleep(nanoseconds: 30_000_000_000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try await waitUntil(timeoutSeconds: 10) {
|
||||||
|
await observer.hasSeenFirstContent
|
||||||
|
}
|
||||||
|
|
||||||
|
await harness.server.shutdown(timeoutSeconds: 2.0)
|
||||||
|
|
||||||
|
let liveSnapshot = LiveCounters.shared.snapshot()
|
||||||
|
XCTAssertEqual(liveSnapshot.activeRequests, 0)
|
||||||
|
let isRunning = await MainActor.run { harness.server.isRunning }
|
||||||
|
XCTAssertFalse(isRunning)
|
||||||
|
|
||||||
|
session.invalidateAndCancel()
|
||||||
|
task.cancel()
|
||||||
|
_ = try? await task.value
|
||||||
|
await MainActor.run {
|
||||||
|
harness.modelManager.unloadModel()
|
||||||
|
}
|
||||||
|
TokenPrefixCache.shared.reset()
|
||||||
|
}
|
||||||
|
|
||||||
func testRepeatedStreamingDisconnectsDoNotBreakSubsequentGeneration() async throws {
|
func testRepeatedStreamingDisconnectsDoNotBreakSubsequentGeneration() async throws {
|
||||||
let harness = try await makeHarness()
|
let harness = try await makeHarness()
|
||||||
defer { harness.stop() }
|
defer { harness.stop() }
|
||||||
|
|||||||
Reference in New Issue
Block a user