From 11300e3034cc95075df2ec4063d11b3f96c3dfbd Mon Sep 17 00:00:00 2001 From: Chili Palmer Date: Thu, 30 Apr 2026 11:58:53 +0200 Subject: [PATCH] feat: more on migration to v3 --- .vscode/settings.json | 3 +- MLXServer/Server/APIServer.swift | 18 +++++--- MLXServer/Utilities/LocalModelResolver.swift | 45 ++++++++++++++++---- MLXServer/ViewModels/ChatViewModel.swift | 15 ++++--- MLXServer/ViewModels/ModelManager.swift | 16 +++---- MLXServer/Views/DownloadModalView.swift | 19 +++++++-- 6 files changed, 81 insertions(+), 35 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 84c9d3b..4ba81ef 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,7 @@ { "chat.tools.terminal.autoApprove": { "./test.sh": true, - "setopt": true + "setopt": true, + "./build.sh": true } } \ No newline at end of file diff --git a/MLXServer/Server/APIServer.swift b/MLXServer/Server/APIServer.swift index 28fba32..f3b9c8c 100644 --- a/MLXServer/Server/APIServer.swift +++ b/MLXServer/Server/APIServer.swift @@ -335,19 +335,23 @@ final class APIServer { } } + // NOTE: repetition / presence / frequency penalties are intentionally + // not forwarded to GenerateParameters. mlx-swift-lm 3.31.3's + // PenaltyProcessor uses TokenRing.loadPrompt, which assumes a 1-D + // prompt MLXArray. VLM models (Gemma3, Qwen-VL, …) hand it a 2-D + // [1, N] tokens array, so the ring buffer ends up the wrong size and + // every later MLX.where in TokenRing.append crashes via fatalError. + // Re-enable once upstream fixes TokenRing to flatten the prompt. 
let generateParams = GenerateParameters( maxTokens: maxTokens, temperature: Float(generationSettings.temperature), topP: Float(generationSettings.topP), topK: generationSettings.topK, - minP: Float(generationSettings.minP), - repetitionPenalty: generationSettings.repetitionPenalty.map(Float.init), - repetitionContextSize: 128, - presencePenalty: generationSettings.presencePenalty.map(Float.init), - presenceContextSize: 128, - frequencyPenalty: generationSettings.frequencyPenalty.map(Float.init), - frequencyContextSize: 128 + minP: Float(generationSettings.minP) ) + _ = generationSettings.repetitionPenalty + _ = generationSettings.presencePenalty + _ = generationSettings.frequencyPenalty let currentModelId = modelManager.currentModel?.id ?? modelName let engine = InferenceEngine(container: container) let preparedInference: InferenceEngine.PreparedInference diff --git a/MLXServer/Utilities/LocalModelResolver.swift b/MLXServer/Utilities/LocalModelResolver.swift index a946089..57d516d 100644 --- a/MLXServer/Utilities/LocalModelResolver.swift +++ b/MLXServer/Utilities/LocalModelResolver.swift @@ -51,10 +51,13 @@ enum LocalModelResolver { print("[LocalModelResolver] Found \(snapshotDirs.count) snapshots") for snapshotDir in snapshotDirs where isDirectory(snapshotDir) { let configPath = snapshotDir.appendingPathComponent("config.json") - if FileManager.default.fileExists(atPath: configPath.path) { - print("[LocalModelResolver] Found valid snapshot: \(snapshotDir.path)") - return snapshotDir + guard FileManager.default.fileExists(atPath: configPath.path) else { continue } + guard hasCompleteWeights(at: snapshotDir) else { + print("[LocalModelResolver] Snapshot missing weight files (incomplete download): \(snapshotDir.path)") + continue } + print("[LocalModelResolver] Found valid snapshot: \(snapshotDir.path)") + return snapshotDir } } @@ -155,12 +158,36 @@ enum LocalModelResolver { } private static func containsModelArtifacts(at directory: URL) -> Bool { - let 
requiredPaths = [ - directory.appendingPathComponent("config.json").path, - directory.appendingPathComponent("model.safetensors").path, - directory.appendingPathComponent("model.safetensors.index.json").path, - ] - return requiredPaths.contains { FileManager.default.fileExists(atPath: $0) } + let configExists = FileManager.default.fileExists( + atPath: directory.appendingPathComponent("config.json").path + ) + return configExists && hasCompleteWeights(at: directory) + } + + /// Returns true when the snapshot has the actual weight files on disk: + /// either a single `model.safetensors`, or every shard listed in + /// `model.safetensors.index.json`. Returns false for partial/interrupted downloads. + static func hasCompleteWeights(at directory: URL) -> Bool { + let fm = FileManager.default + let single = directory.appendingPathComponent("model.safetensors") + if fm.fileExists(atPath: single.path) { + return true + } + + let indexURL = directory.appendingPathComponent("model.safetensors.index.json") + guard fm.fileExists(atPath: indexURL.path), + let data = try? Data(contentsOf: indexURL), + let json = (try? JSONSerialization.jsonObject(with: data)) as? [String: Any], + let weightMap = json["weight_map"] as? [String: Any] + else { + return false + } + + let shardNames = Set(weightMap.values.compactMap { $0 as? String }) + guard !shardNames.isEmpty else { return false } + return shardNames.allSatisfy { name in + fm.fileExists(atPath: directory.appendingPathComponent(name).path) + } } /// Delete the local cache for a model so it will be re-downloaded next time. diff --git a/MLXServer/ViewModels/ChatViewModel.swift b/MLXServer/ViewModels/ChatViewModel.swift index 48ff9f9..fbbc98c 100644 --- a/MLXServer/ViewModels/ChatViewModel.swift +++ b/MLXServer/ViewModels/ChatViewModel.swift @@ -88,18 +88,19 @@ final class ChatViewModel { let thinkingContext: [String: any Sendable]? = generationSettings.thinkingEnabled ? 
nil : ["enable_thinking": false] + // NOTE: repetition / presence / frequency penalties are intentionally + // not forwarded to GenerateParameters. mlx-swift-lm 3.31.3's + // PenaltyProcessor uses TokenRing.loadPrompt, which assumes a 1-D + // prompt MLXArray. VLM models (Gemma3, Qwen-VL, …) hand it a 2-D + // [1, N] tokens array, so the ring buffer ends up the wrong size and + // every later MLX.where in TokenRing.append crashes via fatalError. + // Re-enable once upstream fixes TokenRing to flatten the prompt. let generateParameters = GenerateParameters( maxTokens: generationSettings.maxTokens, temperature: Float(generationSettings.temperature), topP: Float(generationSettings.topP), topK: generationSettings.topK, - minP: Float(generationSettings.minP), - repetitionPenalty: generationSettings.repetitionPenalty.map(Float.init), - repetitionContextSize: 128, - presencePenalty: generationSettings.presencePenalty.map(Float.init), - presenceContextSize: 128, - frequencyPenalty: generationSettings.frequencyPenalty.map(Float.init), - frequencyContextSize: 128 + minP: Float(generationSettings.minP) ) let history = conversation.messages.compactMap(historyMessage(from:)) if history.isEmpty { diff --git a/MLXServer/ViewModels/ModelManager.swift b/MLXServer/ViewModels/ModelManager.swift index ba00223..11f3457 100644 --- a/MLXServer/ViewModels/ModelManager.swift +++ b/MLXServer/ViewModels/ModelManager.swift @@ -25,8 +25,8 @@ final class ModelManager { // Download-specific state for the modal var isDownloading = false - var downloadFilesTotal: Int64 = 0 - var downloadFilesCompleted: Int64 = 0 + var downloadBytesTotal: Int64 = 0 + var downloadBytesCompleted: Int64 = 0 var downloadSpeed: Double = 0 // bytes/sec private var idleTimer: Timer? 
@@ -87,8 +87,8 @@ final class ModelManager { isDownloading = false downloadProgress = 0 loadingModelName = "" - downloadFilesTotal = 0 - downloadFilesCompleted = 0 + downloadBytesTotal = 0 + downloadBytesCompleted = 0 downloadSpeed = 0 } @@ -116,8 +116,8 @@ final class ModelManager { let needsDownload = !effectiveConfig.isLocal if needsDownload { isDownloading = true - downloadFilesTotal = 0 - downloadFilesCompleted = 0 + downloadBytesTotal = 0 + downloadBytesCompleted = 0 downloadSpeed = 0 } @@ -126,8 +126,8 @@ final class ModelManager { Task { @MainActor in self.downloadProgress = progress.fractionCompleted if self.isDownloading { - self.downloadFilesTotal = progress.totalUnitCount - self.downloadFilesCompleted = progress.completedUnitCount + self.downloadBytesTotal = progress.totalUnitCount + self.downloadBytesCompleted = progress.completedUnitCount if let speed = progress.userInfo[.throughputKey] as? Double { self.downloadSpeed = speed } diff --git a/MLXServer/Views/DownloadModalView.swift b/MLXServer/Views/DownloadModalView.swift index 7c38dd4..7ca1b33 100644 --- a/MLXServer/Views/DownloadModalView.swift +++ b/MLXServer/Views/DownloadModalView.swift @@ -20,9 +20,9 @@ struct DownloadModalView: View { .progressViewStyle(.linear) HStack { - // Files progress - if modelManager.downloadFilesTotal > 0 { - Text("File \(modelManager.downloadFilesCompleted)/\(modelManager.downloadFilesTotal)") + // Bytes progress + if modelManager.downloadBytesTotal > 0 { + Text("\(formatBytes(modelManager.downloadBytesCompleted)) / \(formatBytes(modelManager.downloadBytesTotal))") .font(.caption.monospacedDigit()) .foregroundStyle(.secondary) } @@ -65,4 +65,17 @@ struct DownloadModalView: View { return String(format: "%.0f B/s", bytesPerSec) } } + + private func formatBytes(_ bytes: Int64) -> String { + let value = Double(bytes) + if value >= 1_073_741_824 { + return String(format: "%.2f GB", value / 1_073_741_824) + } else if value >= 1_048_576 { + return String(format: "%.0f MB", 
value / 1_048_576) + } else if value >= 1024 { + return String(format: "%.0f KB", value / 1024) + } else { + return "\(bytes) B" + } + } }