feat: more o n migration to v3
This commit is contained in:
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"chat.tools.terminal.autoApprove": {
|
||||
"./test.sh": true,
|
||||
"setopt": true
|
||||
"setopt": true,
|
||||
"./build.sh": true
|
||||
}
|
||||
}
|
||||
@@ -335,19 +335,23 @@ final class APIServer {
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: repetition / presence / frequency penalties are intentionally
|
||||
// not forwarded to GenerateParameters. mlx-swift-lm 3.31.3's
|
||||
// PenaltyProcessor uses TokenRing.loadPrompt, which assumes a 1-D
|
||||
// prompt MLXArray. VLM models (Gemma3, Qwen-VL, …) hand it a 2-D
|
||||
// [1, N] tokens array, so the ring buffer ends up the wrong size and
|
||||
// every later MLX.where in TokenRing.append crashes via fatalError.
|
||||
// Re-enable once upstream fixes TokenRing to flatten the prompt.
|
||||
let generateParams = GenerateParameters(
|
||||
maxTokens: maxTokens,
|
||||
temperature: Float(generationSettings.temperature),
|
||||
topP: Float(generationSettings.topP),
|
||||
topK: generationSettings.topK,
|
||||
minP: Float(generationSettings.minP),
|
||||
repetitionPenalty: generationSettings.repetitionPenalty.map(Float.init),
|
||||
repetitionContextSize: 128,
|
||||
presencePenalty: generationSettings.presencePenalty.map(Float.init),
|
||||
presenceContextSize: 128,
|
||||
frequencyPenalty: generationSettings.frequencyPenalty.map(Float.init),
|
||||
frequencyContextSize: 128
|
||||
minP: Float(generationSettings.minP)
|
||||
)
|
||||
_ = generationSettings.repetitionPenalty
|
||||
_ = generationSettings.presencePenalty
|
||||
_ = generationSettings.frequencyPenalty
|
||||
let currentModelId = modelManager.currentModel?.id ?? modelName
|
||||
let engine = InferenceEngine(container: container)
|
||||
let preparedInference: InferenceEngine.PreparedInference
|
||||
|
||||
@@ -51,12 +51,15 @@ enum LocalModelResolver {
|
||||
print("[LocalModelResolver] Found \(snapshotDirs.count) snapshots")
|
||||
for snapshotDir in snapshotDirs where isDirectory(snapshotDir) {
|
||||
let configPath = snapshotDir.appendingPathComponent("config.json")
|
||||
if FileManager.default.fileExists(atPath: configPath.path) {
|
||||
guard FileManager.default.fileExists(atPath: configPath.path) else { continue }
|
||||
guard hasCompleteWeights(at: snapshotDir) else {
|
||||
print("[LocalModelResolver] Snapshot missing weight files (incomplete download): \(snapshotDir.path)")
|
||||
continue
|
||||
}
|
||||
print("[LocalModelResolver] Found valid snapshot: \(snapshotDir.path)")
|
||||
return snapshotDir
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
print("[LocalModelResolver] Model not found locally")
|
||||
return nil
|
||||
@@ -155,12 +158,36 @@ enum LocalModelResolver {
|
||||
}
|
||||
|
||||
private static func containsModelArtifacts(at directory: URL) -> Bool {
|
||||
let requiredPaths = [
|
||||
directory.appendingPathComponent("config.json").path,
|
||||
directory.appendingPathComponent("model.safetensors").path,
|
||||
directory.appendingPathComponent("model.safetensors.index.json").path,
|
||||
]
|
||||
return requiredPaths.contains { FileManager.default.fileExists(atPath: $0) }
|
||||
let configExists = FileManager.default.fileExists(
|
||||
atPath: directory.appendingPathComponent("config.json").path
|
||||
)
|
||||
return configExists && hasCompleteWeights(at: directory)
|
||||
}
|
||||
|
||||
/// Returns true when the snapshot has the actual weight files on disk:
|
||||
/// either a single `model.safetensors`, or every shard listed in
|
||||
/// `model.safetensors.index.json`. Returns false for partial/interrupted downloads.
|
||||
static func hasCompleteWeights(at directory: URL) -> Bool {
|
||||
let fm = FileManager.default
|
||||
let single = directory.appendingPathComponent("model.safetensors")
|
||||
if fm.fileExists(atPath: single.path) {
|
||||
return true
|
||||
}
|
||||
|
||||
let indexURL = directory.appendingPathComponent("model.safetensors.index.json")
|
||||
guard fm.fileExists(atPath: indexURL.path),
|
||||
let data = try? Data(contentsOf: indexURL),
|
||||
let json = (try? JSONSerialization.jsonObject(with: data)) as? [String: Any],
|
||||
let weightMap = json["weight_map"] as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
|
||||
let shardNames = Set(weightMap.values.compactMap { $0 as? String })
|
||||
guard !shardNames.isEmpty else { return false }
|
||||
return shardNames.allSatisfy { name in
|
||||
fm.fileExists(atPath: directory.appendingPathComponent(name).path)
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete the local cache for a model so it will be re-downloaded next time.
|
||||
|
||||
@@ -88,18 +88,19 @@ final class ChatViewModel {
|
||||
let thinkingContext: [String: any Sendable]? = generationSettings.thinkingEnabled
|
||||
? nil
|
||||
: ["enable_thinking": false]
|
||||
// NOTE: repetition / presence / frequency penalties are intentionally
|
||||
// not forwarded to GenerateParameters. mlx-swift-lm 3.31.3's
|
||||
// PenaltyProcessor uses TokenRing.loadPrompt, which assumes a 1-D
|
||||
// prompt MLXArray. VLM models (Gemma3, Qwen-VL, …) hand it a 2-D
|
||||
// [1, N] tokens array, so the ring buffer ends up the wrong size and
|
||||
// every later MLX.where in TokenRing.append crashes via fatalError.
|
||||
// Re-enable once upstream fixes TokenRing to flatten the prompt.
|
||||
let generateParameters = GenerateParameters(
|
||||
maxTokens: generationSettings.maxTokens,
|
||||
temperature: Float(generationSettings.temperature),
|
||||
topP: Float(generationSettings.topP),
|
||||
topK: generationSettings.topK,
|
||||
minP: Float(generationSettings.minP),
|
||||
repetitionPenalty: generationSettings.repetitionPenalty.map(Float.init),
|
||||
repetitionContextSize: 128,
|
||||
presencePenalty: generationSettings.presencePenalty.map(Float.init),
|
||||
presenceContextSize: 128,
|
||||
frequencyPenalty: generationSettings.frequencyPenalty.map(Float.init),
|
||||
frequencyContextSize: 128
|
||||
minP: Float(generationSettings.minP)
|
||||
)
|
||||
let history = conversation.messages.compactMap(historyMessage(from:))
|
||||
if history.isEmpty {
|
||||
|
||||
@@ -25,8 +25,8 @@ final class ModelManager {
|
||||
|
||||
// Download-specific state for the modal
|
||||
var isDownloading = false
|
||||
var downloadFilesTotal: Int64 = 0
|
||||
var downloadFilesCompleted: Int64 = 0
|
||||
var downloadBytesTotal: Int64 = 0
|
||||
var downloadBytesCompleted: Int64 = 0
|
||||
var downloadSpeed: Double = 0 // bytes/sec
|
||||
|
||||
private var idleTimer: Timer?
|
||||
@@ -87,8 +87,8 @@ final class ModelManager {
|
||||
isDownloading = false
|
||||
downloadProgress = 0
|
||||
loadingModelName = ""
|
||||
downloadFilesTotal = 0
|
||||
downloadFilesCompleted = 0
|
||||
downloadBytesTotal = 0
|
||||
downloadBytesCompleted = 0
|
||||
downloadSpeed = 0
|
||||
}
|
||||
|
||||
@@ -116,8 +116,8 @@ final class ModelManager {
|
||||
let needsDownload = !effectiveConfig.isLocal
|
||||
if needsDownload {
|
||||
isDownloading = true
|
||||
downloadFilesTotal = 0
|
||||
downloadFilesCompleted = 0
|
||||
downloadBytesTotal = 0
|
||||
downloadBytesCompleted = 0
|
||||
downloadSpeed = 0
|
||||
}
|
||||
|
||||
@@ -126,8 +126,8 @@ final class ModelManager {
|
||||
Task { @MainActor in
|
||||
self.downloadProgress = progress.fractionCompleted
|
||||
if self.isDownloading {
|
||||
self.downloadFilesTotal = progress.totalUnitCount
|
||||
self.downloadFilesCompleted = progress.completedUnitCount
|
||||
self.downloadBytesTotal = progress.totalUnitCount
|
||||
self.downloadBytesCompleted = progress.completedUnitCount
|
||||
if let speed = progress.userInfo[.throughputKey] as? Double {
|
||||
self.downloadSpeed = speed
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ struct DownloadModalView: View {
|
||||
.progressViewStyle(.linear)
|
||||
|
||||
HStack {
|
||||
// Files progress
|
||||
if modelManager.downloadFilesTotal > 0 {
|
||||
Text("File \(modelManager.downloadFilesCompleted)/\(modelManager.downloadFilesTotal)")
|
||||
// Bytes progress
|
||||
if modelManager.downloadBytesTotal > 0 {
|
||||
Text("\(formatBytes(modelManager.downloadBytesCompleted)) / \(formatBytes(modelManager.downloadBytesTotal))")
|
||||
.font(.caption.monospacedDigit())
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
@@ -65,4 +65,17 @@ struct DownloadModalView: View {
|
||||
return String(format: "%.0f B/s", bytesPerSec)
|
||||
}
|
||||
}
|
||||
|
||||
private func formatBytes(_ bytes: Int64) -> String {
|
||||
let value = Double(bytes)
|
||||
if value >= 1_073_741_824 {
|
||||
return String(format: "%.2f GB", value / 1_073_741_824)
|
||||
} else if value >= 1_048_576 {
|
||||
return String(format: "%.0f MB", value / 1_048_576)
|
||||
} else if value >= 1024 {
|
||||
return String(format: "%.0f KB", value / 1024)
|
||||
} else {
|
||||
return "\(bytes) B"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user