feat: added model management

This commit is contained in:
2026-03-21 19:15:13 +01:00
parent 84a6b2229f
commit 6b14d7d46c
17 changed files with 1210 additions and 47 deletions

View File

@@ -23,18 +23,21 @@
2E3A02DF9C6A5109E532D5E2 /* ChatDocumentController.swift in Sources */ = {isa = PBXBuildFile; fileRef = D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */; };
3A9DB84947BBBBED06CF9E1E /* TestImageFixtures.swift in Sources */ = {isa = PBXBuildFile; fileRef = 31BD930DEC051408444C30D4 /* TestImageFixtures.swift */; };
4158FA884D981D73288FB74C /* SaveChatCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2E2FCA55CEBEBCED78D9479A /* SaveChatCommands.swift */; };
4B7449F57226CB48C4F5EEBD /* LocalModelResolverTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */; };
4CB13DC1AC7A500DDBB443EC /* ChatInputView.swift in Sources */ = {isa = PBXBuildFile; fileRef = E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */; };
4DC033E45880B2948B47DEB1 /* FocusedValues.swift in Sources */ = {isa = PBXBuildFile; fileRef = EF518FEBF3A38E830E3CE1A5 /* FocusedValues.swift */; };
50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C67742651DB486871CEF1612 /* MLXServerApp.swift */; };
50DD129CCF2843482DEC3B96 /* APIServer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3D08828E16B17EF02C14243E /* APIServer.swift */; };
5946258F1DE88CE904584E0B /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 944C699FBB76C734C9DF2F2E /* ContentView.swift */; };
5C1E8FE1C521914CEF98D3AA /* ChatMessagesView.swift in Sources */ = {isa = PBXBuildFile; fileRef = DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */; };
5D41C2B260265A32FF42264B /* ModelManagementView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */; };
621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */; };
67262C5E24739F1FE0011439 /* StreamingSSEEncoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */; };
67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */; };
67D0628F148FE3C2200E0AEF /* APIServerResponseResolutionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */; };
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; };
741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; };
75E046B4ABB1E6FEF17C1A60 /* ModelManagementWindow.swift in Sources */ = {isa = PBXBuildFile; fileRef = 721D6F203A10434FE0223042 /* ModelManagementWindow.swift */; };
7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */; };
7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; };
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; };
@@ -100,6 +103,7 @@
3D08828E16B17EF02C14243E /* APIServer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServer.swift; sourceTree = "<group>"; };
4147321383E94E9F17A0154E /* SettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsView.swift; sourceTree = "<group>"; };
4239CFF94B819C35A8D4D617 /* MonitorView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MonitorView.swift; sourceTree = "<group>"; };
43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolverTests.swift; sourceTree = "<group>"; };
49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoderTests.swift; sourceTree = "<group>"; };
57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettingsTests.swift; sourceTree = "<group>"; };
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptBuilderTests.swift; sourceTree = "<group>"; };
@@ -108,9 +112,11 @@
6B3AA91D2C7842D7366F9A41 /* ChatDocumentPackage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentPackage.swift; sourceTree = "<group>"; };
6EE59189918D06B8D2F588FC /* MLXServer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLXServer.app; sourceTree = BUILT_PRODUCTS_DIR; };
6FAF7455BD387CD2061E0CBF /* GenerationSettings.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettings.swift; sourceTree = "<group>"; };
721D6F203A10434FE0223042 /* ModelManagementWindow.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManagementWindow.swift; sourceTree = "<group>"; };
7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettingsEditor.swift; sourceTree = "<group>"; };
7C1A89C076E717F87A60397D /* ImageDecoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageDecoder.swift; sourceTree = "<group>"; };
7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LiveCountersTests.swift; sourceTree = "<group>"; };
8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManagementView.swift; sourceTree = "<group>"; };
922CBDC9206737BD04AF2874 /* ModelManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManager.swift; sourceTree = "<group>"; };
944C699FBB76C734C9DF2F2E /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
A4B359324B5FD8D106C74338 /* ChatMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessage.swift; sourceTree = "<group>"; };
@@ -199,6 +205,7 @@
57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */,
E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */,
7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */,
43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */,
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */,
F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */,
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */,
@@ -246,6 +253,8 @@
DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */,
2DC8C86D397B1FCA08E07CBD /* DownloadModalView.swift */,
7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */,
8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */,
721D6F203A10434FE0223042 /* ModelManagementWindow.swift */,
C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */,
4239CFF94B819C35A8D4D617 /* MonitorView.swift */,
37FEB592E5E717F817B03151 /* SceneManagementView.swift */,
@@ -416,6 +425,7 @@
847B445654860396AF5A8280 /* GenerationSettingsTests.swift in Sources */,
E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */,
67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */,
4B7449F57226CB48C4F5EEBD /* LocalModelResolverTests.swift in Sources */,
8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */,
7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */,
1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */,
@@ -455,6 +465,8 @@
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */,
50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */,
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */,
5D41C2B260265A32FF42264B /* ModelManagementView.swift in Sources */,
75E046B4ABB1E6FEF17C1A60 /* ModelManagementWindow.swift in Sources */,
0168AEE16009097901363E16 /* ModelManager.swift in Sources */,
2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */,
B1D9BC407DB7DB1489230C20 /* MonitorView.swift in Sources */,

View File

@@ -12,3 +12,15 @@ struct SceneCommands: Commands {
}
}
}
/// Menu-bar commands for model management: a "Models" menu with a single entry
/// that opens the model management window.
struct ModelCommands: Commands {
    @Environment(\.openWindow) private var openWindow

    var body: some Commands {
        CommandMenu("Models") {
            Button("Manage Models…", action: showModelManagement)
        }
    }

    /// Opens the window registered under `ModelManagementWindow.windowID`.
    private func showModelManagement() {
        openWindow(id: ModelManagementWindow.windowID)
    }
}

View File

@@ -29,6 +29,7 @@ struct ContentView: View {
AnyView(mainContent)
.navigationTitle(navigationTitleText)
.onAppear {
modelManager.refreshAvailableModels()
if chatVM == nil {
let vm = ChatViewModel(modelManager: modelManager)
chatVM = vm
@@ -68,7 +69,7 @@ struct ContentView: View {
AnyView(lifecycleContent)
.alert("Model Error", isPresented: $showLoadError) {
Button("Retry") {
if let config = modelManager.currentModel ?? ModelConfig.availableModels.first {
if let config = modelManager.currentModel ?? modelManager.availableModels.first {
Task { await modelManager.loadModel(config) }
}
}
@@ -228,7 +229,7 @@ struct ContentView: View {
@ViewBuilder
private var modelSwitchShortcuts: some View {
ForEach(Array(ModelConfig.availableModels.enumerated()), id: \.element.id) { index, config in
ForEach(Array(ModelConfig.curatedModels.enumerated()), id: \.element.id) { index, config in
if index < 9 {
Button("") {
Task { await modelManager.loadModel(config) }
@@ -420,7 +421,7 @@ struct ContentView: View {
guard modelManager.currentModel == nil else { return }
let modelId = Preferences.defaultModelId ?? Preferences.lastModelId ?? ModelConfig.default.id
if let config = ModelConfig.availableModels.first(where: { $0.id == modelId }) {
if let config = ModelConfig.resolve(modelId) {
await modelManager.loadModel(config)
}
}

View File

@@ -52,17 +52,26 @@ struct MLXServerApp: App {
.commands {
SaveChatCommands()
SceneCommands()
ModelCommands()
}
Window("Scenes", id: SceneManagementWindow.windowID) {
SceneManagementView()
.environment(modelManager)
.environment(sceneStore)
}
.defaultSize(width: 900, height: 560)
Window("Models", id: ModelManagementWindow.windowID) {
ModelManagementView()
.environment(modelManager)
}
.defaultSize(width: 900, height: 620)
#if os(macOS)
Settings {
SettingsView()
.environment(modelManager)
.environment(sceneStore)
}
#endif

View File

@@ -53,7 +53,7 @@ struct ChatScene: Codable, Identifiable, Hashable {
var resolvedModel: ModelConfig? {
guard let modelId else { return nil }
return ModelConfig.availableModels.first(where: { $0.id == modelId })
return ModelConfig.resolve(modelId)
}
static let empty = ChatScene(name: "New Scene")

View File

@@ -1,30 +1,81 @@
import Foundation
import MLXLMCommon
/// User-supplied metadata that replaces the values otherwise used for a model.
struct ModelMetadataOverride: Codable, Hashable, Sendable {
    var contextLength: Int
    var primaryLoaderKind: ModelConfig.LoaderKind
    var supportsImages: Bool
    var supportsTools: Bool

    /// Returns a copy whose `contextLength` is clamped to be non-negative.
    /// All other fields are carried over unchanged.
    func normalized() -> ModelMetadataOverride {
        var sanitized = self
        sanitized.contextLength = max(0, contextLength)
        return sanitized
    }
}
/// Defines a supported model with its metadata.
struct ModelConfig: Identifiable, Hashable {
enum LoaderKind: Hashable {
enum LoaderKind: String, CaseIterable, Codable, Hashable, Sendable {
case llm
case vlm
var displayName: String {
switch self {
case .llm:
return "Text"
case .vlm:
return "Vision"
}
}
}
let id: String // alias: "gemma", "gemma3n", "qwen"
let repoId: String // HuggingFace ID
let displayName: String
let contextLength: Int
let loaderKind: LoaderKind
let loaderKinds: [LoaderKind]
let supportsImages: Bool
let supportsTools: Bool
let defaultGenerationSettings: GenerationSettings
let isCurated: Bool
let localSizeBytes: Int64?
/// All models supported by the app.
static let availableModels: [ModelConfig] = [
/// Memberwise initializer.
///
/// Every argument is stored as-is in the property of the same name.
/// `isCurated` defaults to `true` and `localSizeBytes` defaults to `nil`,
/// so call sites that omit them describe a curated model without a known on-disk size.
init(
id: String,
repoId: String,
displayName: String,
contextLength: Int,
loaderKinds: [LoaderKind],
supportsImages: Bool,
supportsTools: Bool,
defaultGenerationSettings: GenerationSettings,
isCurated: Bool = true,
localSizeBytes: Int64? = nil
) {
self.id = id
self.repoId = repoId
self.displayName = displayName
self.contextLength = contextLength
self.loaderKinds = loaderKinds
self.supportsImages = supportsImages
self.supportsTools = supportsTools
self.defaultGenerationSettings = defaultGenerationSettings
self.isCurated = isCurated
self.localSizeBytes = localSizeBytes
}
/// Curated models supported and tuned by the app.
static let curatedModels: [ModelConfig] = [
ModelConfig(
id: "gemma",
repoId: "mlx-community/gemma-3-4b-it-4bit",
displayName: "Gemma 3 4B",
contextLength: 128_000,
loaderKind: .vlm,
loaderKinds: [.vlm],
supportsImages: true,
supportsTools: true,
defaultGenerationSettings: .technicalDefault
@@ -34,7 +85,7 @@ struct ModelConfig: Identifiable, Hashable {
repoId: "mlx-community/Qwen3.5-4B-MLX-4bit",
displayName: "Qwen3.5 4B",
contextLength: 256_000,
loaderKind: .vlm,
loaderKinds: [.vlm],
supportsImages: true,
supportsTools: true,
defaultGenerationSettings: .technicalDefault
@@ -44,7 +95,7 @@ struct ModelConfig: Identifiable, Hashable {
repoId: "mlx-community/Qwen3.5-0.8B-4bit",
displayName: "Qwen3.5 0.8B",
contextLength: 256_000,
loaderKind: .vlm,
loaderKinds: [.vlm],
supportsImages: true,
supportsTools: true,
defaultGenerationSettings: .technicalDefault
@@ -54,7 +105,7 @@ struct ModelConfig: Identifiable, Hashable {
repoId: "mlx-community/Qwen3.5-9B-4bit",
displayName: "Qwen3.5 9B",
contextLength: 256_000,
loaderKind: .vlm,
loaderKinds: [.vlm],
supportsImages: true,
supportsTools: true,
defaultGenerationSettings: .technicalDefault
@@ -64,7 +115,7 @@ struct ModelConfig: Identifiable, Hashable {
repoId: "synk/L3-8B-Stheno-v3.2-MLX",
displayName: "Stheno L3 8B",
contextLength: 8_192,
loaderKind: .llm,
loaderKinds: [.llm],
supportsImages: false,
supportsTools: false,
defaultGenerationSettings: .roleplayDefault
@@ -74,18 +125,35 @@ struct ModelConfig: Identifiable, Hashable {
repoId: "hobaratio/MN-Violet-Lotus-12B-mlx-4Bit",
displayName: "Violet Lotus 12B",
contextLength: 32_768,
loaderKind: .llm,
loaderKinds: [.llm],
supportsImages: false,
supportsTools: false,
defaultGenerationSettings: .roleplayDefault
),
]
static let `default` = availableModels[0]
/// Curated models merged with whatever `LocalModelResolver.discoveredLocalModels()`
/// reports at the moment of access. The list is recomputed on every read.
static var availableModels: [ModelConfig] {
    let localModels = LocalModelResolver.discoveredLocalModels()
    return mergedModels(localModels: localModels)
}
static let `default` = curatedModels[0]
/// Whether this model is cached locally (no download needed).
var isLocal: Bool {
LocalModelResolver.isAvailable(repoId: repoId)
localSizeBytes != nil || LocalModelResolver.isAvailable(repoId: repoId)
}
/// The first entry of `loaderKinds`, or `.llm` when the list is empty.
var primaryLoaderKind: LoaderKind {
    guard let preferred = loaderKinds.first else { return .llm }
    return preferred
}
/// This config's current metadata expressed as a `ModelMetadataOverride`
/// (context length, primary loader kind, image support, tool support).
var metadataOverrideValue: ModelMetadataOverride {
    let snapshot = ModelMetadataOverride(
        contextLength: contextLength,
        primaryLoaderKind: primaryLoaderKind,
        supportsImages: supportsImages,
        supportsTools: supportsTools
    )
    return snapshot
}
/// Build a ModelConfiguration for mlx-swift-lm from this config.
@@ -96,6 +164,9 @@ struct ModelConfig: Identifiable, Hashable {
/// Resolve a model string (alias, full repo ID, or partial match) to a ModelConfig.
/// Mirrors the Python server's `ModelManager.resolve_model()`.
static func resolve(_ requested: String) -> ModelConfig? {
let requested = requested.trimmingCharacters(in: .whitespacesAndNewlines)
guard !requested.isEmpty else { return nil }
// Exact alias match
if let config = availableModels.first(where: { $0.id == requested }) {
return config
@@ -108,6 +179,129 @@ struct ModelConfig: Identifiable, Hashable {
if let config = availableModels.first(where: { requested.contains($0.id) || $0.repoId.contains(requested) || requested.contains($0.repoId) }) {
return config
}
if requested.contains("/") {
return remoteCustom(repoId: requested)
}
return nil
}
/// Combines the curated catalogue with the models discovered on disk.
///
/// - Parameters:
///   - localModels: Models found in the local cache. If several entries share a
///     `repoId`, the first one wins and the rest are ignored.
///   - applyingOverrides: When `true` (the default) stored metadata overrides are applied
///     to every returned config; when `false` no override is applied to any of them.
/// - Returns: Curated models in catalogue order (carrying their on-disk size when found
///   locally), followed by non-curated local models sorted case-insensitively by display name.
static func mergedModels(
    localModels: [LocalModelResolver.LocalModelInfo],
    applyingOverrides: Bool = true
) -> [ModelConfig] {
    // `Dictionary(uniqueKeysWithValues:)` traps on duplicate keys; keep the first entry instead.
    let localByRepo = Dictionary(
        localModels.map { ($0.repoId, $0) },
        uniquingKeysWith: { first, _ in first }
    )

    let curated = curatedModels.map { config -> ModelConfig in
        let sized = localByRepo[config.repoId].map { config.withLocalSize($0.sizeBytes) } ?? config
        return applyingOverrides ? applyMetadataOverrideIfNeeded(to: sized) : sized
    }

    // Seeded with the curated repo IDs so the filter drops both curated entries and duplicates.
    var claimedRepoIds = Set(curatedModels.map(\.repoId))
    let discoveredCustom = localModels
        .filter { claimedRepoIds.insert($0.repoId).inserted }
        .map { local -> ModelConfig in
            // Built inline so that `applyingOverrides == false` really yields the detected values.
            let detected = ModelConfig(
                id: local.repoId,
                repoId: local.repoId,
                displayName: displayName(for: local.repoId),
                contextLength: local.contextLength,
                loaderKinds: local.loaderKinds,
                supportsImages: local.supportsImages,
                supportsTools: inferredToolSupport(repoId: local.repoId),
                defaultGenerationSettings: .generalDefault,
                isCurated: false,
                localSizeBytes: local.sizeBytes
            )
            return applyingOverrides ? applyMetadataOverrideIfNeeded(to: detected) : detected
        }
        .sorted { lhs, rhs in
            lhs.displayName.localizedCaseInsensitiveCompare(rhs.displayName) == .orderedAscending
        }

    return curated + discoveredCustom
}
/// Looks up the config for `repoId` among the merged models requested without overrides.
///
/// Matches on either `repoId` or `id`. When nothing matches and `repoId` contains "/",
/// falls back to `remoteCustom(repoId:)`; otherwise returns `nil`.
static func baselineModel(
    forRepoId repoId: String,
    localModels: [LocalModelResolver.LocalModelInfo]
) -> ModelConfig? {
    let candidates = mergedModels(localModels: localModels, applyingOverrides: false)
    if let match = candidates.first(where: { $0.repoId == repoId || $0.id == repoId }) {
        return match
    }
    guard repoId.contains("/") else { return nil }
    return remoteCustom(repoId: repoId)
}
/// Builds a non-curated config for a repository that is not known locally.
///
/// Context length is set to 0, image and tool support are guessed from the repo name,
/// and both loader kinds are listed with the guessed one first. Any stored metadata
/// override for the repo is applied to the result.
static func remoteCustom(repoId: String) -> ModelConfig {
    let looksLikeVisionModel = inferredVisionSupport(repoId: repoId)
    let loaderOrder: [LoaderKind] = looksLikeVisionModel ? [.vlm, .llm] : [.llm, .vlm]
    let placeholder = ModelConfig(
        id: repoId,
        repoId: repoId,
        displayName: displayName(for: repoId),
        contextLength: 0,
        loaderKinds: loaderOrder,
        supportsImages: looksLikeVisionModel,
        supportsTools: inferredToolSupport(repoId: repoId),
        defaultGenerationSettings: .generalDefault,
        isCurated: false
    )
    return applyMetadataOverrideIfNeeded(to: placeholder)
}
/// Derives a human-readable name from a repository ID: takes the part after the
/// last "/" (or the whole string if there is none) and turns "-" and "_" into spaces.
static func displayName(for repoId: String) -> String {
    let lastComponent: String
    if let tail = repoId.split(separator: "/").last {
        lastComponent = String(tail)
    } else {
        lastComponent = repoId
    }
    return ["-", "_"].reduce(lastComponent) { partial, separator in
        partial.replacingOccurrences(of: separator, with: " ")
    }
}
/// Converts a discovered local model into a non-curated config, using the repo ID as
/// both `id` and `repoId`, and applies any stored metadata override for that repo.
private static func customLocal(_ local: LocalModelResolver.LocalModelInfo) -> ModelConfig {
    let identifier = local.repoId
    let detected = ModelConfig(
        id: identifier,
        repoId: identifier,
        displayName: displayName(for: identifier),
        contextLength: local.contextLength,
        loaderKinds: local.loaderKinds,
        supportsImages: local.supportsImages,
        supportsTools: inferredToolSupport(repoId: identifier),
        defaultGenerationSettings: .generalDefault,
        isCurated: false,
        localSizeBytes: local.sizeBytes
    )
    return applyMetadataOverrideIfNeeded(to: detected)
}
/// Name-based guess for tool support: `true` when the lowercased repo ID
/// contains "qwen" or "gemma".
private static func inferredToolSupport(repoId: String) -> Bool {
    let haystack = repoId.lowercased()
    let markers = ["qwen", "gemma"]
    return markers.contains(where: { haystack.contains($0) })
}
/// Name-based guess for image support: `true` when the lowercased repo ID
/// contains "vision", "vl", "gemma-3" or "qwen".
private static func inferredVisionSupport(repoId: String) -> Bool {
    let haystack = repoId.lowercased()
    let markers = ["vision", "vl", "gemma-3", "qwen"]
    return markers.contains(where: { haystack.contains($0) })
}
/// Returns a copy of this config with `localSizeBytes` set to `sizeBytes`.
/// Every other field is carried over unchanged.
private func withLocalSize(_ sizeBytes: Int64) -> ModelConfig {
ModelConfig(
id: id,
repoId: repoId,
displayName: displayName,
contextLength: contextLength,
loaderKinds: loaderKinds,
supportsImages: supportsImages,
supportsTools: supportsTools,
defaultGenerationSettings: defaultGenerationSettings,
isCurated: isCurated,
// The only field that differs from `self`.
localSizeBytes: sizeBytes
)
}
/// Returns a copy of this config with context length, image support and tool support
/// taken from the normalized `override`. The loader list is rebuilt so the override's
/// primary kind comes first, followed by every other `LoaderKind` case.
private func applyingMetadataOverride(_ override: ModelMetadataOverride) -> ModelConfig {
    let sanitized = override.normalized()
    let preferredKind = sanitized.primaryLoaderKind

    var loaderOrder: [LoaderKind] = [preferredKind]
    for kind in LoaderKind.allCases where kind != preferredKind {
        loaderOrder.append(kind)
    }

    return ModelConfig(
        id: id,
        repoId: repoId,
        displayName: displayName,
        contextLength: sanitized.contextLength,
        loaderKinds: loaderOrder,
        supportsImages: sanitized.supportsImages,
        supportsTools: sanitized.supportsTools,
        defaultGenerationSettings: defaultGenerationSettings,
        isCurated: isCurated,
        localSizeBytes: localSizeBytes
    )
}
/// Applies the metadata override stored in `Preferences` for `config.repoId`, if any;
/// otherwise returns `config` untouched.
private static func applyMetadataOverrideIfNeeded(to config: ModelConfig) -> ModelConfig {
    if let stored = Preferences.modelMetadataOverride(forRepoId: config.repoId) {
        return config.applyingMetadataOverride(stored)
    }
    return config
}
}

View File

@@ -6,6 +6,17 @@ import Foundation
/// ~/Library/Containers/de.rfc1437.mlxserver/Data/Library/Caches/models/{org}/{name}/
enum LocalModelResolver {
/// Describes one model directory found in the local cache.
struct LocalModelInfo: Identifiable, Hashable {
// Repository identifier; doubles as the `Identifiable` id (see `id` below).
let repoId: String
// Directory that holds the model's files.
let directory: URL
// Size of the model on disk, in bytes.
let sizeBytes: Int64
// Context length reported for the model.
let contextLength: Int
// Loader kinds for this model, in the order they are stored.
let loaderKinds: [ModelConfig.LoaderKind]
// Whether the model is considered able to take image input.
let supportsImages: Bool
var id: String { repoId }
}
/// Base directory where HubApi stores downloaded models.
private static let modelsBase: URL? = {
FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first?
@@ -31,6 +42,46 @@ enum LocalModelResolver {
resolve(repoId: repoId) != nil
}
/// Scans `modelsBase` for model directories; returns an empty list when the base
/// directory is unavailable.
static func discoveredLocalModels() -> [LocalModelInfo] {
    modelsBase.map { discoverModels(in: $0) } ?? []
}
/// Walks the two-level `{owner}/{repo}` layout below `base` and returns one entry per
/// repo directory that `localModelInfo` accepts, sorted case-insensitively by repo ID.
///
/// Hidden entries and non-directories are ignored at both levels. A level that cannot
/// be listed contributes nothing (an unreadable `base` yields an empty result).
static func discoverModels(in base: URL) -> [LocalModelInfo] {
    let manager = FileManager.default
    let prefetchedKeys: [URLResourceKey] = [.isDirectoryKey]

    // Lists the visible sub-directories of `parent`, or nil when the listing fails.
    func visibleSubdirectories(of parent: URL) -> [URL]? {
        guard let entries = try? manager.contentsOfDirectory(
            at: parent,
            includingPropertiesForKeys: prefetchedKeys,
            options: [.skipsHiddenFiles]
        ) else {
            return nil
        }
        return entries.filter { isDirectory($0) }
    }

    guard let owners = visibleSubdirectories(of: base) else { return [] }

    let found = owners.flatMap { owner -> [LocalModelInfo] in
        let repos = visibleSubdirectories(of: owner) ?? []
        return repos.compactMap { repo -> LocalModelInfo? in
            localModelInfo(ownerDirectory: owner, repoDirectory: repo)
        }
    }

    return found.sorted { lhs, rhs in
        lhs.repoId.localizedCaseInsensitiveCompare(rhs.repoId) == .orderedAscending
    }
}
/// Delete the local cache for a model so it will be re-downloaded next time.
@discardableResult
static func deleteLocal(repoId: String) -> Bool {
@@ -46,4 +97,122 @@ enum LocalModelResolver {
return false
}
}
/// Builds the `LocalModelInfo` for one repo directory, or returns `nil` when the
/// directory does not pass `containsModelArtifacts`.
///
/// The repo ID is `"{owner}/{repo}"` taken from the two directory names. Image support
/// and context length are inferred from `config.json` / `tokenizer_config.json`; both
/// loader kinds are listed, with `.vlm` first when image support was detected.
private static func localModelInfo(ownerDirectory: URL, repoDirectory: URL) -> LocalModelInfo? {
    guard containsModelArtifacts(at: repoDirectory) else { return nil }

    let modelConfig = readJSONObject(at: repoDirectory.appendingPathComponent("config.json"))
    let tokenizerConfig = readJSONObject(at: repoDirectory.appendingPathComponent("tokenizer_config.json"))

    let handlesImages = inferredSupportsImages(
        repoDirectory: repoDirectory,
        config: modelConfig,
        tokenizerConfig: tokenizerConfig
    )
    let loaderOrder: [ModelConfig.LoaderKind] = handlesImages ? [.vlm, .llm] : [.llm, .vlm]

    return LocalModelInfo(
        repoId: "\(ownerDirectory.lastPathComponent)/\(repoDirectory.lastPathComponent)",
        directory: repoDirectory,
        sizeBytes: directorySize(at: repoDirectory),
        contextLength: inferredContextLength(config: modelConfig, tokenizerConfig: tokenizerConfig),
        loaderKinds: loaderOrder,
        supportsImages: handlesImages
    )
}
/// Returns whether `directory` looks like a complete enough model snapshot to list.
///
/// Requires `config.json` plus at least one `*.safetensors` file directly inside the
/// directory (single-file or sharded weights). A directory holding only `config.json`
/// or only a weights index — for example after an interrupted download — is rejected,
/// so it is not reported as an installed model.
private static func containsModelArtifacts(at directory: URL) -> Bool {
    let fileManager = FileManager.default

    let configPath = directory.appendingPathComponent("config.json").path
    guard fileManager.fileExists(atPath: configPath) else { return false }

    // An unreadable directory is treated as containing no weights.
    guard let entries = try? fileManager.contentsOfDirectory(atPath: directory.path) else {
        return false
    }
    return entries.contains { $0.lowercased().hasSuffix(".safetensors") }
}
/// `true` only when the URL's `isDirectory` resource value can be read and is `true`.
private static func isDirectory(_ url: URL) -> Bool {
    guard let values = try? url.resourceValues(forKeys: [.isDirectoryKey]) else {
        return false
    }
    return values.isDirectory ?? false
}
/// Reads the file at `url` and returns its top-level JSON object as a dictionary.
/// Returns `nil` when the file cannot be read, is not valid JSON, or its top level
/// is not an object.
private static func readJSONObject(at url: URL) -> [String: Any]? {
    guard
        let contents = try? Data(contentsOf: url),
        let parsed = try? JSONSerialization.jsonObject(with: contents)
    else {
        return nil
    }
    return parsed as? [String: Any]
}
/// Decides whether a local model takes image input.
///
/// Returns `true` when `config` has a `vision_config` key, when `tokenizerConfig` has an
/// `image_token` key, or when any of the processor/preprocessor config files exists in
/// `repoDirectory`.
private static func inferredSupportsImages(
    repoDirectory: URL,
    config: [String: Any]?,
    tokenizerConfig: [String: Any]?
) -> Bool {
    let declaredInJSON = config?["vision_config"] != nil || tokenizerConfig?["image_token"] != nil
    if declaredInJSON {
        return true
    }

    let processorFiles = [
        "processor_config.json",
        "preprocessor_config.json",
        "video_preprocessor_config.json",
    ]
    for fileName in processorFiles {
        let candidate = repoDirectory.appendingPathComponent(fileName).path
        if FileManager.default.fileExists(atPath: candidate) {
            return true
        }
    }
    return false
}
/// Picks the context length from the first source that provides an integer, in order:
/// `config.text_config.max_position_embeddings`, `config.max_position_embeddings`,
/// `tokenizerConfig.model_max_length`. Returns 0 when none is present.
private static func inferredContextLength(
    config: [String: Any]?,
    tokenizerConfig: [String: Any]?
) -> Int {
    let sources: [(path: [String], json: [String: Any]?)] = [
        (["text_config", "max_position_embeddings"], config),
        (["max_position_embeddings"], config),
        (["model_max_length"], tokenizerConfig),
    ]
    for source in sources {
        if let length = integerValue(at: source.path, in: source.json) {
            return length
        }
    }
    return 0
}
/// Follows `path` through nested JSON dictionaries and returns the integer found there.
///
/// - Parameters:
///   - path: Keys to follow, outermost first.
///   - json: Root dictionary; `nil` yields `nil`.
/// - Returns: The value at `path` when it is exactly representable as `Int`.
///   `nil` when a key is missing, an intermediate value is not a dictionary, the
///   leaf is not numeric, or the number is fractional or out of `Int` range
///   (such as a ~1e30 `model_max_length` placeholder).
private static func integerValue(at path: [String], in json: [String: Any]?) -> Int? {
    guard let json else { return nil }

    var current: Any = json
    for component in path {
        guard let dictionary = current as? [String: Any], let next = dictionary[component] else {
            return nil
        }
        current = next
    }

    if let value = current as? Int {
        return value
    }
    // `NSNumber.intValue` converts without range checking and would turn huge or
    // fractional numbers into a meaningless Int; only accept exact conversions.
    if let number = current as? NSNumber {
        return Int(exactly: number)
    }
    return nil
}
/// Sums the `fileSize` of every regular file found by enumerating `directory`
/// (hidden files skipped). Returns 0 when the enumerator cannot be created; entries
/// whose resource values cannot be read are ignored.
private static func directorySize(at directory: URL) -> Int64 {
    let keys: [URLResourceKey] = [.isRegularFileKey, .fileSizeKey]
    let keySet = Set(keys)

    guard let walker = FileManager.default.enumerator(
        at: directory,
        includingPropertiesForKeys: keys,
        options: [.skipsHiddenFiles]
    ) else {
        return 0
    }

    var byteCount: Int64 = 0
    while let entry = walker.nextObject() {
        guard
            let fileURL = entry as? URL,
            let values = try? fileURL.resourceValues(forKeys: keySet),
            values.isRegularFile == true
        else {
            continue
        }
        byteCount += Int64(values.fileSize ?? 0)
    }
    return byteCount
}
}

View File

@@ -7,6 +7,7 @@ enum Preferences {
private static let jsonEncoder = JSONEncoder()
private static let jsonDecoder = JSONDecoder()
private static let legacyThinkingDefault = true
private static let modelMetadataOverridesKey = "modelMetadataOverrides"
// MARK: - Last used model
@@ -118,6 +119,26 @@ enum Preferences {
modelGenerationSettingsMap[modelId] != nil
}
/// Returns the stored metadata override for `repoId` in normalized form,
/// or `nil` when none is stored.
static func modelMetadataOverride(forRepoId repoId: String) -> ModelMetadataOverride? {
    guard let stored = modelMetadataOverridesMap[repoId] else { return nil }
    return stored.normalized()
}
/// Stores the normalized `override` for `repoId`, replacing any previous entry.
static func setModelMetadataOverride(_ override: ModelMetadataOverride, forRepoId repoId: String) {
    var updated = modelMetadataOverridesMap
    updated.updateValue(override.normalized(), forKey: repoId)
    modelMetadataOverridesMap = updated
}
/// Deletes the stored override for `repoId` (if any) and writes the map back.
static func removeModelMetadataOverride(forRepoId repoId: String) {
    var remaining = modelMetadataOverridesMap
    remaining[repoId] = nil
    modelMetadataOverridesMap = remaining
}
/// Whether a metadata override is stored for `repoId`.
static func hasModelMetadataOverride(forRepoId repoId: String) -> Bool {
    let overrides = modelMetadataOverridesMap
    return overrides.keys.contains(repoId)
}
private static var modelGenerationSettingsMap: [String: GenerationSettings] {
get {
guard let data = defaults.data(forKey: modelGenerationSettingsKey) else { return [:] }
@@ -129,6 +150,17 @@ enum Preferences {
}
}
/// Metadata overrides keyed by repo ID, persisted as JSON data under
/// `modelMetadataOverridesKey`. Reading yields an empty map when nothing is stored or
/// decoding fails; writing is skipped when encoding fails.
private static var modelMetadataOverridesMap: [String: ModelMetadataOverride] {
    get {
        guard
            let stored = defaults.data(forKey: modelMetadataOverridesKey),
            let decoded = try? jsonDecoder.decode([String: ModelMetadataOverride].self, from: stored)
        else {
            return [:]
        }
        return decoded
    }
    set {
        if let encoded = try? jsonEncoder.encode(newValue) {
            defaults.set(encoded, forKey: modelMetadataOverridesKey)
        }
    }
}
// MARK: - Idle unload
private static let idleUnloadMinutesKey = "idleUnloadMinutes"

View File

@@ -559,7 +559,7 @@ final class ChatViewModel {
if modelManager.currentModel == nil {
let modelId = Preferences.defaultModelId ?? Preferences.lastModelId ?? ModelConfig.default.id
if let config = ModelConfig.availableModels.first(where: { $0.id == modelId }) {
if let config = ModelConfig.resolve(modelId) {
await modelManager.loadModel(config)
}
}

View File

@@ -19,7 +19,10 @@ final class ModelManager {
let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first
return HubApi(downloadBase: cachesDir, cache: nil)
}()
var currentModel: ModelConfig?
var availableModels: [ModelConfig]
private(set) var discoveredLocalModels: [LocalModelResolver.LocalModelInfo] = []
var modelContainer: ModelContainer?
var isLoading = false
var downloadProgress: Double = 0
@@ -36,6 +39,50 @@ final class ModelManager {
private(set) var lastUsed: Date?
private var latestLoadRequestID = UUID()
/// Creates the manager and populates the model list.
init() {
// Give `availableModels` an initial value before calling an instance method.
availableModels = []
refreshAvailableModels()
}
/// The subset of `availableModels` flagged as curated, in their existing order.
var curatedModels: [ModelConfig] {
    availableModels.filter { $0.isCurated }
}
/// Models from `availableModels` whose `isLocal` is true, sorted case-insensitively
/// by display name.
var localModelsOnDisk: [ModelConfig] {
    let onDisk = availableModels.filter { $0.isLocal }
    return onDisk.sorted { lhs, rhs in
        lhs.displayName.localizedCaseInsensitiveCompare(rhs.displayName) == .orderedAscending
    }
}
/// Rescans the local cache and rebuilds `availableModels`.
///
/// If a model is currently selected, `currentModel` is replaced by the refreshed entry
/// with the same `repoId` so it picks up new metadata; when no such entry exists the
/// previous value is kept.
func refreshAvailableModels() {
    let scanned = LocalModelResolver.discoveredLocalModels()
    discoveredLocalModels = scanned
    availableModels = ModelConfig.mergedModels(localModels: scanned)

    guard let active = currentModel else { return }
    let refreshed = availableModels.first(where: { $0.repoId == active.repoId })
    currentModel = refreshed ?? active
}
/// Returns the most recently scanned on-disk info for `repoId`, if the model was found.
func discoveredLocalModelInfo(repoId: String) -> LocalModelResolver.LocalModelInfo? {
    for info in discoveredLocalModels where info.repoId == repoId {
        return info
    }
    return nil
}
/// Convenience wrapper around `ModelConfig.baselineModel(forRepoId:localModels:)`
/// that supplies the most recently scanned local models.
func baselineModel(repoId: String) -> ModelConfig? {
    let scanned = discoveredLocalModels
    return ModelConfig.baselineModel(forRepoId: repoId, localModels: scanned)
}
/// Persists `override` for the model's repo and then refreshes the model list
/// so the change becomes visible.
func saveMetadataOverride(_ override: ModelMetadataOverride, for config: ModelConfig) {
    defer { refreshAvailableModels() }
    Preferences.setModelMetadataOverride(override, forRepoId: config.repoId)
}
/// Removes the stored override for the model's repo and then refreshes the model list.
func clearMetadataOverride(for config: ModelConfig) {
    defer { refreshAvailableModels() }
    Preferences.removeModelMetadataOverride(forRepoId: config.repoId)
}
private func clearLoadedState() {
idleTimer?.invalidate()
idleTimer = nil
@@ -55,7 +102,11 @@ final class ModelManager {
/// Prefers the local snapshot from ~/.cache/huggingface/hub/ (shared with the Python server).
/// Only downloads if the model isn't cached locally.
func loadModel(_ config: ModelConfig) async {
if currentModel?.id == config.id && modelContainer != nil {
refreshAvailableModels()
let effectiveConfig = availableModels.first(where: { $0.repoId == config.repoId }) ?? config
if currentModel?.repoId == effectiveConfig.repoId && modelContainer != nil {
currentModel = effectiveConfig
return // already loaded
}
@@ -65,10 +116,10 @@ final class ModelManager {
MLX.GPU.clearCache()
isLoading = true
downloadProgress = 0
loadingModelName = config.displayName
loadingModelName = effectiveConfig.displayName
errorMessage = nil
let needsDownload = !config.isLocal
let needsDownload = !effectiveConfig.isLocal
if needsDownload {
isDownloading = true
downloadFilesTotal = 0
@@ -91,32 +142,23 @@ final class ModelManager {
}
let configuration: ModelConfiguration
if let localDir = LocalModelResolver.resolve(repoId: config.repoId) {
if let localDir = LocalModelResolver.resolve(repoId: effectiveConfig.repoId) {
configuration = ModelConfiguration(directory: localDir)
} else {
configuration = config.modelConfiguration
configuration = effectiveConfig.modelConfiguration
}
let container: ModelContainer
switch config.loaderKind {
case .llm:
container = try await LLMModelFactory.shared.loadContainer(
hub: Self.hub,
let container = try await Self.loadContainer(
for: effectiveConfig,
configuration: configuration,
progressHandler: progressHandler
)
case .vlm:
container = try await VLMModelFactory.shared.loadContainer(
hub: Self.hub,
configuration: configuration,
progressHandler: progressHandler
)
}
guard latestLoadRequestID == requestID else { return }
refreshAvailableModels()
self.isDownloading = false
self.modelContainer = container
self.currentModel = config
self.currentModel = self.availableModels.first(where: { $0.repoId == effectiveConfig.repoId }) ?? effectiveConfig
touchActivity()
} catch {
guard latestLoadRequestID == requestID else { return }
@@ -135,6 +177,25 @@ final class ModelManager {
await loadModel(config)
}
/// Adds a model by user-supplied identifier and loads it.
///
/// The input is trimmed; an empty string sets `errorMessage` and returns. An identifier
/// containing "/" is treated as a full HuggingFace repository ID and matched exactly
/// against `availableModels`, falling back to a fresh custom config. Only short names
/// without "/" go through the fuzzy alias resolution of `ModelConfig.resolve`.
///
/// - Parameter repoId: HuggingFace repository ID (`owner/name`) or a known alias.
func addModel(repoId: String) async {
    let repoId = repoId.trimmingCharacters(in: .whitespacesAndNewlines)
    guard !repoId.isEmpty else {
        errorMessage = "Enter a HuggingFace model ID."
        return
    }

    let config: ModelConfig
    if repoId.contains("/") {
        // `ModelConfig.resolve` also matches when the request merely contains a known alias,
        // which would map e.g. "org/gemma-2-9b" onto the curated "gemma" entry and load
        // the wrong model. Full repository IDs therefore require an exact match.
        config = availableModels.first(where: { $0.repoId == repoId })
            ?? ModelConfig.remoteCustom(repoId: repoId)
    } else {
        config = ModelConfig.resolve(repoId) ?? ModelConfig.remoteCustom(repoId: repoId)
    }
    await loadModel(config)
}
/// Deletes the local copy of `config` via `LocalModelResolver.deleteLocal` and refreshes
/// the list of available models.
///
/// If the model being deleted is the one currently loaded, it is unloaded first.
/// - Parameter config: The model whose local copy should be removed.
func deleteModel(_ config: ModelConfig) {
    let targetRepoId = config.repoId
    let isCurrentlyLoaded = currentModel?.repoId == targetRepoId
    if isCurrentlyLoaded {
        unloadModel()
    }
    // The result of the delete call is deliberately discarded here.
    _ = LocalModelResolver.deleteLocal(repoId: targetRepoId)
    refreshAvailableModels()
}
/// Unload the current model and free GPU memory.
func unloadModel() {
latestLoadRequestID = UUID()
@@ -161,4 +222,35 @@ final class ModelManager {
/// `true` when a model container is present and no load is in progress.
var isReady: Bool {
    guard modelContainer != nil else { return false }
    return !isLoading
}
/// Loads a model container by trying each loader kind in `config.loaderKinds`, in order.
///
/// The first loader that succeeds wins. When a loader fails with an ordinary error the
/// next kind is tried. Cancellation is never treated as a reason to fall back: a
/// `CancellationError` is rethrown immediately, and cancellation is checked before each
/// attempt so a cancelled task does not start another load.
///
/// - Parameters:
///   - config: Supplies the ordered list of loader kinds to attempt.
///   - configuration: The model configuration handed to the factory.
///   - progressHandler: Forwarded to the factory to report progress.
/// - Returns: The container produced by the first loader that succeeds.
/// - Throws: `CancellationError` when the task was cancelled; otherwise the error from the
///   last loader attempted, or an `NSError` ("Unsupported model configuration") when
///   `config.loaderKinds` is empty.
private static func loadContainer(
    for config: ModelConfig,
    configuration: ModelConfiguration,
    progressHandler: @escaping @Sendable (Progress) -> Void
) async throws -> ModelContainer {
    var lastError: Error?
    for loaderKind in config.loaderKinds {
        // Do not begin another attempt once the surrounding task has been cancelled.
        try Task.checkCancellation()
        do {
            switch loaderKind {
            case .llm:
                return try await LLMModelFactory.shared.loadContainer(
                    hub: Self.hub,
                    configuration: configuration,
                    progressHandler: progressHandler
                )
            case .vlm:
                return try await VLMModelFactory.shared.loadContainer(
                    hub: Self.hub,
                    configuration: configuration,
                    progressHandler: progressHandler
                )
            }
        } catch let cancellation as CancellationError {
            // Cancellation is not a loader failure; surface it instead of falling back.
            throw cancellation
        } catch {
            // Remember the failure and fall through to the next loader kind.
            lastError = error
        }
    }
    throw lastError ?? NSError(domain: "ModelManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "Unsupported model configuration"])
}
}

View File

@@ -0,0 +1,431 @@
import SwiftUI
/// Content of the "Models" window: add a model by HuggingFace ID, load or download the
/// curated defaults, and load, edit metadata for, or delete models that are on disk.
struct ModelManagementView: View {
    @Environment(ModelManager.self) private var modelManager

    /// Text currently typed into the "owner/repo" field.
    @State private var newRepoId = ""
    /// Model awaiting delete confirmation; a non-nil value presents the alert.
    @State private var pendingDelete: ModelConfig?
    /// Model whose metadata editor sheet is presented.
    @State private var editingMetadataModel: ModelConfig?
    @FocusState private var isRepoIdFieldFocused: Bool

    /// Formats on-disk sizes using KB/MB/GB units with file-style counting.
    private let sizeFormatter: ByteCountFormatter = {
        let formatter = ByteCountFormatter()
        formatter.allowedUnits = [.useGB, .useMB, .useKB]
        formatter.countStyle = .file
        formatter.includesUnit = true
        formatter.isAdaptive = true
        return formatter
    }()

    var body: some View {
        ScrollView {
            VStack(alignment: .leading, spacing: 18) {
                GroupBox("Add Model") {
                    VStack(alignment: .leading, spacing: 10) {
                        Text("Enter a HuggingFace model ID. The app will download it, load it once, and then keep it available in the regular model picker.")
                            .font(.caption)
                            .foregroundStyle(.secondary)
                        HStack {
                            TextField("owner/repo", text: $newRepoId)
                                .textFieldStyle(.roundedBorder)
                                .focused($isRepoIdFieldFocused)
                                .onSubmit {
                                    downloadEnteredModel()
                                }
                            Button("Download & Select") {
                                downloadEnteredModel()
                            }
                            .disabled(modelManager.isLoading || newRepoId.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
                        }
                    }
                }
                GroupBox("Recommended Defaults") {
                    VStack(spacing: 0) {
                        ForEach(modelManager.curatedModels) { model in
                            curatedRow(model)
                            // Separator between rows, but not after the last one.
                            if model.id != modelManager.curatedModels.last?.id {
                                Divider()
                            }
                        }
                    }
                }
                GroupBox("Models On Disk") {
                    if modelManager.localModelsOnDisk.isEmpty {
                        ContentUnavailableView(
                            "No Local Models",
                            systemImage: "externaldrive",
                            description: Text("Downloaded models will appear here with their summed file sizes.")
                        )
                        .frame(maxWidth: .infinity)
                        .padding(.vertical, 20)
                    } else {
                        VStack(spacing: 0) {
                            ForEach(modelManager.localModelsOnDisk) { model in
                                localRow(model)
                                // Separator between rows, but not after the last one.
                                if model.id != modelManager.localModelsOnDisk.last?.id {
                                    Divider()
                                }
                            }
                        }
                    }
                }
            }
            .padding(20)
        }
        .navigationTitle("Models")
        .frame(minWidth: 760, minHeight: 520)
        .sheet(item: $editingMetadataModel) { model in
            ModelMetadataEditorView(
                model: model,
                baselineModel: modelManager.baselineModel(repoId: model.repoId) ?? model,
                detectedLocalModel: modelManager.discoveredLocalModelInfo(repoId: model.repoId),
                hasSavedOverride: Preferences.hasModelMetadataOverride(forRepoId: model.repoId),
                onSave: { override in
                    modelManager.saveMetadataOverride(override, for: model)
                },
                onReset: {
                    modelManager.clearMetadataOverride(for: model)
                }
            )
        }
        .alert(
            "Delete Local Model?",
            // Presented while a delete is pending; dismissing the alert clears the pending model.
            isPresented: Binding(
                get: { pendingDelete != nil },
                set: { if !$0 { pendingDelete = nil } }
            )
        ) {
            Button("Delete", role: .destructive) {
                if let pendingDelete {
                    modelManager.deleteModel(pendingDelete)
                }
                self.pendingDelete = nil
            }
            Button("Cancel", role: .cancel) {
                pendingDelete = nil
            }
        } message: {
            if let pendingDelete {
                Text("This removes the local files for \(pendingDelete.repoId).")
            }
        }
        .onAppear {
            modelManager.refreshAvailableModels()
            if newRepoId.isEmpty {
                isRepoIdFieldFocused = true
            }
        }
    }

    /// Row for a curated model: name, optional "Loaded" badge, on-disk status, and
    /// Load/Download and Metadata buttons.
    @ViewBuilder
    private func curatedRow(_ model: ModelConfig) -> some View {
        HStack(alignment: .top, spacing: 14) {
            VStack(alignment: .leading, spacing: 4) {
                HStack(spacing: 8) {
                    Text(model.displayName)
                        .font(.headline)
                    if modelManager.currentModel?.repoId == model.repoId {
                        Text("Loaded")
                            .font(.caption.weight(.semibold))
                            .padding(.horizontal, 8)
                            .padding(.vertical, 3)
                            .background(.green.opacity(0.15), in: Capsule())
                    }
                }
                Text(model.repoId)
                    .font(.caption)
                    .foregroundStyle(.secondary)
            }
            Spacer()
            Label(
                model.isLocal ? "On Disk" : "Not Downloaded",
                systemImage: model.isLocal ? "checkmark.circle.fill" : "arrow.down.circle"
            )
            .font(.caption)
            .foregroundStyle(model.isLocal ? .green : .secondary)
            Button(model.isLocal ? "Load" : "Download") {
                Task {
                    await modelManager.loadModel(model)
                }
            }
            .disabled(modelManager.isLoading)
            Button("Metadata…") {
                editingMetadataModel = model
            }
        }
        .padding(.vertical, 10)
    }

    /// Row for a model found on disk: name, "Custom"/"Loaded" badges, size when known, and
    /// Load, Metadata, and Delete buttons.
    @ViewBuilder
    private func localRow(_ model: ModelConfig) -> some View {
        HStack(alignment: .top, spacing: 14) {
            VStack(alignment: .leading, spacing: 4) {
                HStack(spacing: 8) {
                    Text(model.displayName)
                        .font(.headline)
                    if !model.isCurated {
                        Text("Custom")
                            .font(.caption.weight(.semibold))
                            .padding(.horizontal, 8)
                            .padding(.vertical, 3)
                            .background(.secondary.opacity(0.14), in: Capsule())
                    }
                    if modelManager.currentModel?.repoId == model.repoId {
                        Text("Loaded")
                            .font(.caption.weight(.semibold))
                            .padding(.horizontal, 8)
                            .padding(.vertical, 3)
                            .background(.green.opacity(0.15), in: Capsule())
                    }
                }
                Text(model.repoId)
                    .font(.caption)
                    .foregroundStyle(.secondary)
            }
            Spacer()
            if let localSizeBytes = model.localSizeBytes {
                Text(sizeFormatter.string(fromByteCount: localSizeBytes))
                    .font(.caption.monospacedDigit())
                    .foregroundStyle(.secondary)
                    .frame(width: 90, alignment: .trailing)
            }
            Button("Load") {
                Task {
                    await modelManager.loadModel(model)
                }
            }
            .disabled(modelManager.isLoading)
            Button("Metadata…") {
                editingMetadataModel = model
            }
            Button("Delete", role: .destructive) {
                pendingDelete = model
            }
            .disabled(modelManager.isLoading)
        }
        .padding(.vertical, 10)
    }

    /// Submits the repo ID typed into the text field to `ModelManager.addModel`.
    ///
    /// Does nothing while a load is already in progress or when the trimmed field is empty.
    /// On success the field is cleared, unless the user changed it in the meantime.
    private func downloadEnteredModel() {
        // The button is disabled while loading, but the text field's `onSubmit` (Return key)
        // is not, so the same condition has to be enforced here.
        guard !modelManager.isLoading else { return }
        let repoId = newRepoId.trimmingCharacters(in: .whitespacesAndNewlines)
        guard !repoId.isEmpty else { return }
        Task {
            await modelManager.addModel(repoId: repoId)
            // Only clear what was actually submitted; keep anything typed during the download.
            let stillShowsSubmittedId = newRepoId.trimmingCharacters(in: .whitespacesAndNewlines) == repoId
            if modelManager.errorMessage == nil && stillShowsSubmittedId {
                newRepoId = ""
            }
        }
    }
}
/// Sheet for editing a model's metadata override (context length, primary loader, image
/// and tool support) next to a read-only comparison against the baseline values.
private struct ModelMetadataEditorView: View {
    @Environment(\.dismiss) private var dismiss

    /// The model being edited; its current values seed the editable fields.
    let model: ModelConfig
    /// Values shown in the comparison column.
    let baselineModel: ModelConfig
    /// Info discovered from local model files, when available.
    let detectedLocalModel: LocalModelResolver.LocalModelInfo?
    /// Whether an override is already stored for this repo ID.
    let hasSavedOverride: Bool
    /// Called with the edited values when the user saves.
    let onSave: (ModelMetadataOverride) -> Void
    /// Called when the user resets the stored override.
    let onReset: () -> Void

    @State private var contextLengthText: String
    @State private var primaryLoaderKind: ModelConfig.LoaderKind
    @State private var supportsImages: Bool
    @State private var supportsTools: Bool

    init(
        model: ModelConfig,
        baselineModel: ModelConfig,
        detectedLocalModel: LocalModelResolver.LocalModelInfo?,
        hasSavedOverride: Bool,
        onSave: @escaping (ModelMetadataOverride) -> Void,
        onReset: @escaping () -> Void
    ) {
        self.model = model
        self.baselineModel = baselineModel
        self.detectedLocalModel = detectedLocalModel
        self.hasSavedOverride = hasSavedOverride
        self.onSave = onSave
        self.onReset = onReset
        // Editable state starts from the model's current values.
        _contextLengthText = State(initialValue: String(model.contextLength))
        _primaryLoaderKind = State(initialValue: model.primaryLoaderKind)
        _supportsImages = State(initialValue: model.supportsImages)
        _supportsTools = State(initialValue: model.supportsTools)
    }

    var body: some View {
        NavigationStack {
            Form {
                Section("Metadata") {
                    TextField("Context length", text: $contextLengthText)
                        .textFieldStyle(.roundedBorder)
                    Picker("Primary loader", selection: $primaryLoaderKind) {
                        ForEach(ModelConfig.LoaderKind.allCases, id: \.self) { loaderKind in
                            Text(loaderKind.displayName).tag(loaderKind)
                        }
                    }
                    Toggle("Supports images", isOn: $supportsImages)
                    Toggle("Supports tools", isOn: $supportsTools)
                }
                Section("Comparison") {
                    Text(defaultsSummary)
                        .foregroundStyle(.secondary)
                    Grid(alignment: .leading, horizontalSpacing: 16, verticalSpacing: 8) {
                        GridRow {
                            Text("")
                            Text("Effective")
                                .font(.caption.weight(.semibold))
                                .foregroundStyle(.secondary)
                            Text(baselineHeading)
                                .font(.caption.weight(.semibold))
                                .foregroundStyle(.secondary)
                        }
                        comparisonRow(
                            label: "Context",
                            // Same rendering as the baseline cell, so 0 reads "Unknown" in both columns.
                            effective: currentOverride.map { contextDescription($0.contextLength) } ?? "Invalid",
                            baseline: contextDescription(baselineModel.contextLength)
                        )
                        comparisonRow(
                            label: "Loader",
                            effective: primaryLoaderKind.displayName,
                            baseline: baselineModel.primaryLoaderKind.displayName
                        )
                        comparisonRow(
                            label: "Images",
                            effective: yesNo(supportsImages),
                            baseline: yesNo(baselineModel.supportsImages)
                        )
                        comparisonRow(
                            label: "Tools",
                            effective: yesNo(supportsTools),
                            baseline: yesNo(baselineModel.supportsTools)
                        )
                    }
                }
                if let detectedLocalModel {
                    Section("Discovered Source") {
                        LabeledContent("Detected context") {
                            Text(contextDescription(detectedLocalModel.contextLength))
                        }
                        LabeledContent("Detected loader order") {
                            Text(detectedLocalModel.loaderKinds.map(\.displayName).joined(separator: ", "))
                        }
                        LabeledContent("Detected vision") {
                            Text(yesNo(detectedLocalModel.supportsImages))
                        }
                    }
                }
            }
            .formStyle(.grouped)
            .navigationTitle(model.displayName)
            .frame(minWidth: 520, minHeight: 380)
            .toolbar {
                ToolbarItem(placement: .cancellationAction) {
                    Button("Cancel") {
                        dismiss()
                    }
                }
                ToolbarItem(placement: .primaryAction) {
                    Button("Save") {
                        guard let currentOverride else { return }
                        onSave(currentOverride)
                        dismiss()
                    }
                    // Saving is only possible while the context length field parses.
                    .disabled(currentOverride == nil)
                }
                if hasSavedOverride {
                    ToolbarItem(placement: .automatic) {
                        Button(resetButtonTitle) {
                            onReset()
                            dismiss()
                        }
                    }
                }
            }
        }
    }

    /// The override described by the current field values, or `nil` when the context
    /// length text is not a non-negative integer.
    private var currentOverride: ModelMetadataOverride? {
        guard let contextLength = Int(contextLengthText.trimmingCharacters(in: .whitespacesAndNewlines)), contextLength >= 0 else {
            return nil
        }
        return ModelMetadataOverride(
            contextLength: contextLength,
            primaryLoaderKind: primaryLoaderKind,
            supportsImages: supportsImages,
            supportsTools: supportsTools
        )
    }

    /// Explanatory text for the comparison section; wording depends on where the baseline
    /// comes from (local files, curated entry, or inferred) and whether an override is saved.
    private var defaultsSummary: String {
        if detectedLocalModel != nil {
            if hasSavedOverride {
                return "The editable fields show the effective overridden metadata. The comparison column shows the discovered baseline from the local model files."
            }
            return "The editable fields currently match the discovered baseline from the local model files. Save to store an override for this repo ID."
        }
        if model.isCurated {
            return hasSavedOverride
                ? "The editable fields show the effective overridden metadata. The comparison column shows the curated built-in baseline."
                : "The editable fields currently match the curated built-in baseline. Save to store an override for this repo ID."
        }
        if hasSavedOverride {
            return "The editable fields show the effective overridden metadata. The comparison column shows the inferred baseline for this repo ID."
        }
        return "The editable fields currently match the inferred baseline for this repo ID. Save to store an override."
    }

    /// Column heading naming the source of the baseline values.
    private var baselineHeading: String {
        if detectedLocalModel != nil {
            return "Detected"
        }
        if model.isCurated {
            return "Built-in"
        }
        return "Inferred"
    }

    /// Title of the reset button; names the same baseline source as the comparison column
    /// so the button does not promise "Detected" values when nothing was detected.
    private var resetButtonTitle: String {
        "Reset to \(baselineHeading)"
    }

    /// One grid row comparing the effective value with the baseline value.
    @ViewBuilder
    private func comparisonRow(label: String, effective: String, baseline: String) -> some View {
        GridRow {
            Text(label)
            Text(effective)
                .monospaced()
            Text(baseline)
                .foregroundStyle(.secondary)
                .monospaced()
        }
    }

    /// Renders a context length for display; non-positive values are shown as "Unknown".
    private func contextDescription(_ contextLength: Int) -> String {
        contextLength > 0 ? "\(contextLength)" : "Unknown"
    }

    /// Renders a Boolean as "Yes" or "No".
    private func yesNo(_ value: Bool) -> String {
        value ? "Yes" : "No"
    }
}

View File

@@ -0,0 +1,5 @@
import Foundation
/// Namespace for constants belonging to the model management window.
enum ModelManagementWindow {
    /// Identifier string for the model management window.
    static let windowID: String = "model-manager"
}

View File

@@ -7,14 +7,14 @@ struct ModelPickerView: View {
var body: some View {
HStack(spacing: 8) {
Picker("Model", selection: selectedModelBinding) {
ForEach(ModelConfig.availableModels) { config in
ForEach(modelManager.availableModels) { config in
Label(
config.displayName,
config.isCurated ? config.displayName : config.repoId,
systemImage: config.isLocal ? "checkmark.circle.fill" : "arrow.down.circle"
).tag(config.id)
}
}
.frame(width: 160)
.frame(width: 260)
.disabled(modelManager.isLoading)
// Re-download button (visible when a model is loaded)
@@ -50,9 +50,20 @@ struct ModelPickerView: View {
private var selectedModelBinding: Binding<String> {
Binding(
get: { modelManager.currentModel?.id ?? ModelConfig.default.id },
get: {
if let currentId = modelManager.currentModel?.id {
return currentId
}
if let defaultId = Preferences.defaultModelId,
let config = modelManager.availableModels.first(where: { $0.id == defaultId || $0.repoId == defaultId }) {
return config.id
}
return ModelConfig.default.id
},
set: { newId in
guard let config = ModelConfig.availableModels.first(where: { $0.id == newId }) else { return }
guard let config = modelManager.availableModels.first(where: { $0.id == newId }) ?? ModelConfig.resolve(newId) else {
return
}
Task {
await modelManager.loadModel(config)
}

View File

@@ -210,6 +210,7 @@ struct SceneManagementView: View {
}
private struct SceneEditorView: View {
@Environment(ModelManager.self) private var modelManager
@Environment(SceneStore.self) private var sceneStore
let scene: ChatScene
@@ -221,7 +222,7 @@ private struct SceneEditorView: View {
Picker("Model", selection: modelBinding) {
Text("Current model").tag(Optional<String>.none)
ForEach(ModelConfig.availableModels) { model in
ForEach(modelManager.availableModels) { model in
Text(model.displayName).tag(Optional(model.id))
}
}
@@ -257,6 +258,9 @@ private struct SceneEditorView: View {
}
.formStyle(.grouped)
.navigationTitle(scene.displayName)
.onAppear {
modelManager.refreshAvailableModels()
}
}
private var modelBinding: Binding<String?> {

View File

@@ -2,6 +2,7 @@ import SwiftUI
struct SettingsView: View {
@Environment(\.openWindow) private var openWindow
@Environment(ModelManager.self) private var modelManager
@Environment(SceneStore.self) private var sceneStore
@State private var systemPrompt: String = Preferences.systemPrompt
@State private var apiPort: String = String(Preferences.apiPort)
@@ -29,7 +30,7 @@ struct SettingsView: View {
Form {
Section("Startup") {
Picker("Default model", selection: $defaultModelId) {
ForEach(ModelConfig.availableModels) { model in
ForEach(modelManager.availableModels) { model in
Text(model.displayName).tag(model.id)
}
}
@@ -44,7 +45,7 @@ struct SettingsView: View {
Section("Generation Defaults") {
Picker("Defaults for model", selection: $generationDefaultsModelId) {
ForEach(ModelConfig.availableModels) { model in
ForEach(modelManager.availableModels) { model in
Text(model.displayName).tag(model.id)
}
}
@@ -164,6 +165,15 @@ struct SettingsView: View {
}
.formStyle(.grouped)
.frame(width: 450, height: 650)
.onAppear {
modelManager.refreshAvailableModels()
if !modelManager.availableModels.contains(where: { $0.id == defaultModelId }) {
defaultModelId = modelManager.availableModels.first?.id ?? ModelConfig.default.id
}
if !modelManager.availableModels.contains(where: { $0.id == generationDefaultsModelId }) {
generationDefaultsModelId = defaultModelId
}
}
}
private var generationDefaultsBinding: Binding<GenerationSettings> {

View File

@@ -0,0 +1,180 @@
import Foundation
import XCTest
@testable import MLX_Server
/// Tests for local model discovery, catalog merging, repo ID resolution, and saved
/// metadata overrides.
final class LocalModelResolverTests: XCTestCase {
    // MARK: - Discovery

    /// A repo with a plain causal-LM config is reported as text-only, with the LLM loader
    /// first and a size equal to the byte total of the files written.
    func testDiscoverModelsInfersTextOnlyMetadataAndDirectorySize() throws {
        let modelsRoot = try makeTempModelsRoot()
        let repoDir = try makeRepoDirectory(base: modelsRoot, owner: "example", repo: "text-only")
        let configURL = repoDir.appendingPathComponent("config.json")
        let weightsURL = repoDir.appendingPathComponent("model.safetensors")
        let tokenizerURL = repoDir.appendingPathComponent("tokenizer.json")

        try writeJSON(
            [
                "architectures": ["LlamaForCausalLM"],
                "max_position_embeddings": 32768,
            ],
            to: configURL
        )
        try Data(repeating: 0x11, count: 64).write(to: weightsURL)
        try Data(repeating: 0x22, count: 19).write(to: tokenizerURL)

        // Sum the sizes of the files as they actually landed on disk.
        let writtenFiles = [configURL, weightsURL, tokenizerURL]
        let expectedSize = try writtenFiles.reduce(Int64(0)) { runningTotal, fileURL in
            try runningTotal + Int64(Data(contentsOf: fileURL).count)
        }

        let discoveredModels = LocalModelResolver.discoverModels(in: modelsRoot)
        let firstModel = try XCTUnwrap(discoveredModels.first)

        XCTAssertEqual(firstModel.repoId, "example/text-only")
        XCTAssertEqual(firstModel.contextLength, 32768)
        XCTAssertFalse(firstModel.supportsImages)
        XCTAssertEqual(firstModel.loaderKinds, [.llm, .vlm])
        XCTAssertEqual(firstModel.sizeBytes, expectedSize)
    }

    /// A repo with a vision config and processor files is reported as image-capable, with
    /// the VLM loader first and the context length taken from `text_config`.
    func testDiscoverModelsInfersVisionMetadataFromProcessorFiles() throws {
        let modelsRoot = try makeTempModelsRoot()
        let repoDir = try makeRepoDirectory(base: modelsRoot, owner: "example", repo: "vision-model")

        try writeJSON(
            [
                "text_config": ["max_position_embeddings": 262144],
                "vision_config": ["hidden_size": 768],
            ],
            to: repoDir.appendingPathComponent("config.json")
        )
        try writeJSON(["processor_class": "Qwen3VLProcessor"], to: repoDir.appendingPathComponent("tokenizer_config.json"))
        try Data(repeating: 0x33, count: 12).write(to: repoDir.appendingPathComponent("processor_config.json"))
        try Data(repeating: 0x44, count: 8).write(to: repoDir.appendingPathComponent("model.safetensors.index.json"))

        let discoveredModels = LocalModelResolver.discoverModels(in: modelsRoot)
        let firstModel = try XCTUnwrap(discoveredModels.first)

        XCTAssertEqual(firstModel.repoId, "example/vision-model")
        XCTAssertEqual(firstModel.contextLength, 262144)
        XCTAssertTrue(firstModel.supportsImages)
        XCTAssertEqual(firstModel.loaderKinds, [.vlm, .llm])
    }

    // MARK: - Catalog merging and resolution

    /// Merging keeps the curated entry (picking up its local size) and adds the custom
    /// local model as a non-curated entry keyed by its repo ID.
    func testMergedCatalogKeepsCuratedModelsAndAddsCustomLocalModels() {
        let curatedOnDisk = LocalModelResolver.LocalModelInfo(
            repoId: "mlx-community/gemma-3-4b-it-4bit",
            directory: URL(fileURLWithPath: "/tmp/gemma"),
            sizeBytes: 1024,
            contextLength: 128000,
            loaderKinds: [.vlm, .llm],
            supportsImages: true
        )
        let customOnDisk = LocalModelResolver.LocalModelInfo(
            repoId: "custom-org/custom-model",
            directory: URL(fileURLWithPath: "/tmp/custom"),
            sizeBytes: 2048,
            contextLength: 65536,
            loaderKinds: [.llm, .vlm],
            supportsImages: false
        )

        let catalog = ModelConfig.mergedModels(localModels: [curatedOnDisk, customOnDisk])
        let gemmaEntry = catalog.first { $0.id == "gemma" }
        let customEntry = catalog.first { $0.repoId == "custom-org/custom-model" }

        XCTAssertEqual(gemmaEntry?.localSizeBytes, 1024)
        XCTAssertEqual(customEntry?.id, "custom-org/custom-model")
        XCTAssertEqual(customEntry?.contextLength, 65536)
        XCTAssertFalse(customEntry?.isCurated ?? true)
    }

    /// Resolving an unknown repo ID yields a non-curated config whose id and repo ID are
    /// both the given string.
    func testResolveUnknownRepoIdCreatesRemoteCustomConfig() throws {
        let unknownRepoId = "custom-owner/custom-repo"
        let resolved = try XCTUnwrap(ModelConfig.resolve(unknownRepoId))

        XCTAssertEqual(resolved.id, "custom-owner/custom-repo")
        XCTAssertEqual(resolved.repoId, "custom-owner/custom-repo")
        XCTAssertFalse(resolved.isCurated)
    }

    // MARK: - Metadata overrides

    /// A saved override replaces the discovered metadata in the merged catalog.
    func testMergedCatalogAppliesSavedMetadataOverride() {
        let repoId = "custom-org/override-model"
        let savedOverride = ModelMetadataOverride(
            contextLength: 123456,
            primaryLoaderKind: .vlm,
            supportsImages: true,
            supportsTools: true
        )
        Preferences.setModelMetadataOverride(savedOverride, forRepoId: repoId)
        // Remove the stored override again when the test exits.
        defer {
            Preferences.removeModelMetadataOverride(forRepoId: repoId)
        }

        let discoveredInfo = LocalModelResolver.LocalModelInfo(
            repoId: repoId,
            directory: URL(fileURLWithPath: "/tmp/custom-override"),
            sizeBytes: 2048,
            contextLength: 65536,
            loaderKinds: [.llm, .vlm],
            supportsImages: false
        )

        let catalog = ModelConfig.mergedModels(localModels: [discoveredInfo])
        let overriddenEntry = catalog.first { $0.repoId == repoId }

        XCTAssertEqual(overriddenEntry?.contextLength, 123456)
        XCTAssertEqual(overriddenEntry?.primaryLoaderKind, .vlm)
        XCTAssertTrue(overriddenEntry?.supportsImages ?? false)
        XCTAssertTrue(overriddenEntry?.supportsTools ?? false)
    }

    /// A saved override is also applied when resolving a repo ID that is otherwise unknown.
    func testResolveUnknownRepoIdUsesSavedMetadataOverride() throws {
        let repoId = "custom-owner/custom-repo-with-override"
        let savedOverride = ModelMetadataOverride(
            contextLength: 8192,
            primaryLoaderKind: .llm,
            supportsImages: false,
            supportsTools: true
        )
        Preferences.setModelMetadataOverride(savedOverride, forRepoId: repoId)
        // Remove the stored override again when the test exits.
        defer {
            Preferences.removeModelMetadataOverride(forRepoId: repoId)
        }

        let resolved = try XCTUnwrap(ModelConfig.resolve(repoId))

        XCTAssertEqual(resolved.contextLength, 8192)
        XCTAssertEqual(resolved.primaryLoaderKind, .llm)
        XCTAssertFalse(resolved.supportsImages)
        XCTAssertTrue(resolved.supportsTools)
    }

    // MARK: - Helpers

    /// Creates a uniquely named directory under the system temporary directory and
    /// registers a teardown block that removes it.
    private func makeTempModelsRoot() throws -> URL {
        let uniqueName = UUID().uuidString
        let root = FileManager.default.temporaryDirectory
            .appendingPathComponent(uniqueName, isDirectory: true)
        try FileManager.default.createDirectory(at: root, withIntermediateDirectories: true)
        addTeardownBlock {
            try? FileManager.default.removeItem(at: root)
        }
        return root
    }

    /// Creates and returns `base/owner/repo`, including intermediate directories.
    private func makeRepoDirectory(base: URL, owner: String, repo: String) throws -> URL {
        let ownerDirectory = base.appendingPathComponent(owner, isDirectory: true)
        let repoDirectory = ownerDirectory.appendingPathComponent(repo, isDirectory: true)
        try FileManager.default.createDirectory(at: repoDirectory, withIntermediateDirectories: true)
        return repoDirectory
    }

    /// Serializes `object` as pretty-printed JSON with sorted keys and writes it to `url`.
    private func writeJSON(_ object: Any, to url: URL) throws {
        let serialized = try JSONSerialization.data(withJSONObject: object, options: [.prettyPrinted, .sortedKeys])
        try serialized.write(to: url)
    }
}

View File

@@ -44,7 +44,8 @@ This is intended for targeted validation while keeping the normal default as the
- **Chat interface** with markdown rendering and model-aware image attachments (file picker, drag & drop, clipboard paste, Finder copy-paste on vision-capable models)
- **Scene-based chat starts** — New Chat opens a scene picker with Neutral plus saved scenes, each with an optional model override, a scene prompt layered onto the base system prompt, an auto-sent starter prompt, and optional generation-setting overrides for chat-specific behavior
- **Model picker** in toolbar with local/download status indicators and re-download button
- **Model picker** in toolbar with curated defaults plus any locally discovered MLX models on disk
- **Models window** in the menu for downloading a model by HuggingFace ID, inspecting on-disk model sizes, and deleting local model folders
- **Download progress modal** — shows file progress, percentage, and speed when downloading a new model
- **Thinking mode** — models like Qwen3.5 can reason internally before responding; thinking content appears in a collapsible box. Toggle on/off in Settings.
- **Streaming responses** with live token display