diff --git a/MLXServer.xcodeproj/project.pbxproj b/MLXServer.xcodeproj/project.pbxproj index ec2b104..fb284b8 100644 --- a/MLXServer.xcodeproj/project.pbxproj +++ b/MLXServer.xcodeproj/project.pbxproj @@ -23,18 +23,21 @@ 2E3A02DF9C6A5109E532D5E2 /* ChatDocumentController.swift in Sources */ = {isa = PBXBuildFile; fileRef = D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */; }; 3A9DB84947BBBBED06CF9E1E /* TestImageFixtures.swift in Sources */ = {isa = PBXBuildFile; fileRef = 31BD930DEC051408444C30D4 /* TestImageFixtures.swift */; }; 4158FA884D981D73288FB74C /* SaveChatCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2E2FCA55CEBEBCED78D9479A /* SaveChatCommands.swift */; }; + 4B7449F57226CB48C4F5EEBD /* LocalModelResolverTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */; }; 4CB13DC1AC7A500DDBB443EC /* ChatInputView.swift in Sources */ = {isa = PBXBuildFile; fileRef = E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */; }; 4DC033E45880B2948B47DEB1 /* FocusedValues.swift in Sources */ = {isa = PBXBuildFile; fileRef = EF518FEBF3A38E830E3CE1A5 /* FocusedValues.swift */; }; 50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C67742651DB486871CEF1612 /* MLXServerApp.swift */; }; 50DD129CCF2843482DEC3B96 /* APIServer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3D08828E16B17EF02C14243E /* APIServer.swift */; }; 5946258F1DE88CE904584E0B /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 944C699FBB76C734C9DF2F2E /* ContentView.swift */; }; 5C1E8FE1C521914CEF98D3AA /* ChatMessagesView.swift in Sources */ = {isa = PBXBuildFile; fileRef = DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */; }; + 5D41C2B260265A32FF42264B /* ModelManagementView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */; }; 621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */; }; 67262C5E24739F1FE0011439 /* StreamingSSEEncoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */; }; 67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */; }; 67D0628F148FE3C2200E0AEF /* APIServerResponseResolutionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */; }; 6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; }; 741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; }; + 75E046B4ABB1E6FEF17C1A60 /* ModelManagementWindow.swift in Sources */ = {isa = PBXBuildFile; fileRef = 721D6F203A10434FE0223042 /* ModelManagementWindow.swift */; }; 7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */; }; 7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; }; 80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; }; @@ -100,6 +103,7 @@ 3D08828E16B17EF02C14243E /* APIServer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServer.swift; sourceTree = ""; }; 4147321383E94E9F17A0154E /* SettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsView.swift; sourceTree = ""; }; 4239CFF94B819C35A8D4D617 /* MonitorView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MonitorView.swift; sourceTree = ""; }; + 43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolverTests.swift; sourceTree = ""; }; 49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoderTests.swift; sourceTree = ""; }; 57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettingsTests.swift; sourceTree = ""; }; 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptBuilderTests.swift; sourceTree = ""; }; @@ -108,9 +112,11 @@ 6B3AA91D2C7842D7366F9A41 /* ChatDocumentPackage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentPackage.swift; sourceTree = ""; }; 6EE59189918D06B8D2F588FC /* MLXServer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLXServer.app; sourceTree = BUILT_PRODUCTS_DIR; }; 6FAF7455BD387CD2061E0CBF /* GenerationSettings.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettings.swift; sourceTree = ""; }; + 721D6F203A10434FE0223042 /* ModelManagementWindow.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManagementWindow.swift; sourceTree = ""; }; 7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettingsEditor.swift; sourceTree = ""; }; 7C1A89C076E717F87A60397D /* ImageDecoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageDecoder.swift; sourceTree = ""; }; 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LiveCountersTests.swift; sourceTree = ""; }; + 8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManagementView.swift; sourceTree = ""; }; 922CBDC9206737BD04AF2874 /* ModelManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManager.swift; sourceTree = ""; }; 944C699FBB76C734C9DF2F2E /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; A4B359324B5FD8D106C74338 /* ChatMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessage.swift; sourceTree = ""; }; @@ -199,6 +205,7 @@ 57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */, E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */, 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */, + 43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */, D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */, F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */, 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */, @@ -246,6 +253,8 @@ DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */, 2DC8C86D397B1FCA08E07CBD /* DownloadModalView.swift */, 7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */, + 8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */, + 721D6F203A10434FE0223042 /* ModelManagementWindow.swift */, C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */, 4239CFF94B819C35A8D4D617 /* MonitorView.swift */, 37FEB592E5E717F817B03151 /* SceneManagementView.swift */, @@ -416,6 +425,7 @@ 847B445654860396AF5A8280 /* GenerationSettingsTests.swift in Sources */, E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */, 67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */, + 4B7449F57226CB48C4F5EEBD /* LocalModelResolverTests.swift in Sources */, 8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */, 7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */, 1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */, @@ -455,6 +465,8 @@ 6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */, 50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */, 80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */, + 5D41C2B260265A32FF42264B /* ModelManagementView.swift in Sources */, + 75E046B4ABB1E6FEF17C1A60 /* ModelManagementWindow.swift in Sources */, 0168AEE16009097901363E16 /* ModelManager.swift in Sources */, 2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */, B1D9BC407DB7DB1489230C20 /* MonitorView.swift in Sources */, diff --git a/MLXServer/Commands/SceneCommands.swift b/MLXServer/Commands/SceneCommands.swift index 2bd8e82..57886c8 100644 --- a/MLXServer/Commands/SceneCommands.swift +++ b/MLXServer/Commands/SceneCommands.swift @@ -11,4 +11,16 @@ struct SceneCommands: Commands { .keyboardShortcut(",", modifiers: [.command, .shift]) } } +} + +struct ModelCommands: Commands { + @Environment(\.openWindow) private var openWindow + + var body: some Commands { + CommandMenu("Models") { + Button("Manage Models…") { + openWindow(id: ModelManagementWindow.windowID) + } + } + } } \ No newline at end of file diff --git a/MLXServer/ContentView.swift b/MLXServer/ContentView.swift index c929c1d..e112fe6 100644 --- a/MLXServer/ContentView.swift +++ b/MLXServer/ContentView.swift @@ -29,6 +29,7 @@ struct ContentView: View { AnyView(mainContent) .navigationTitle(navigationTitleText) .onAppear { + modelManager.refreshAvailableModels() if chatVM == nil { let vm = ChatViewModel(modelManager: modelManager) chatVM = vm @@ -68,7 +69,7 @@ struct ContentView: View { AnyView(lifecycleContent) .alert("Model Error", isPresented: $showLoadError) { Button("Retry") { - if let config = modelManager.currentModel ?? ModelConfig.availableModels.first { + if let config = modelManager.currentModel ?? modelManager.availableModels.first { Task { await modelManager.loadModel(config) } } } @@ -228,7 +229,7 @@ struct ContentView: View { @ViewBuilder private var modelSwitchShortcuts: some View { - ForEach(Array(ModelConfig.availableModels.enumerated()), id: \.element.id) { index, config in + ForEach(Array(ModelConfig.curatedModels.enumerated()), id: \.element.id) { index, config in if index < 9 { Button("") { Task { await modelManager.loadModel(config) } @@ -420,7 +421,7 @@ struct ContentView: View { guard modelManager.currentModel == nil else { return } let modelId = Preferences.defaultModelId ?? Preferences.lastModelId ?? ModelConfig.default.id - if let config = ModelConfig.availableModels.first(where: { $0.id == modelId }) { + if let config = ModelConfig.resolve(modelId) { await modelManager.loadModel(config) } } diff --git a/MLXServer/MLXServerApp.swift b/MLXServer/MLXServerApp.swift index f7abf88..56bc884 100644 --- a/MLXServer/MLXServerApp.swift +++ b/MLXServer/MLXServerApp.swift @@ -52,17 +52,26 @@ struct MLXServerApp: App { .commands { SaveChatCommands() SceneCommands() + ModelCommands() } Window("Scenes", id: SceneManagementWindow.windowID) { SceneManagementView() + .environment(modelManager) .environment(sceneStore) } .defaultSize(width: 900, height: 560) + Window("Models", id: ModelManagementWindow.windowID) { + ModelManagementView() + .environment(modelManager) + } + .defaultSize(width: 900, height: 620) + #if os(macOS) Settings { SettingsView() + .environment(modelManager) .environment(sceneStore) } #endif diff --git a/MLXServer/Models/ChatScene.swift b/MLXServer/Models/ChatScene.swift index 59bde57..02af11a 100644 --- a/MLXServer/Models/ChatScene.swift +++ b/MLXServer/Models/ChatScene.swift @@ -53,7 +53,7 @@ struct ChatScene: Codable, Identifiable, Hashable { var resolvedModel: ModelConfig? { guard let modelId else { return nil } - return ModelConfig.availableModels.first(where: { $0.id == modelId }) + return ModelConfig.resolve(modelId) } static let empty = ChatScene(name: "New Scene") diff --git a/MLXServer/Models/ModelConfig.swift b/MLXServer/Models/ModelConfig.swift index 49da622..5f3cff1 100644 --- a/MLXServer/Models/ModelConfig.swift +++ b/MLXServer/Models/ModelConfig.swift @@ -1,30 +1,81 @@ import Foundation import MLXLMCommon +struct ModelMetadataOverride: Codable, Hashable, Sendable { + var contextLength: Int + var primaryLoaderKind: ModelConfig.LoaderKind + var supportsImages: Bool + var supportsTools: Bool + + func normalized() -> ModelMetadataOverride { + ModelMetadataOverride( + contextLength: max(0, contextLength), + primaryLoaderKind: primaryLoaderKind, + supportsImages: supportsImages, + supportsTools: supportsTools + ) + } +} + /// Defines a supported model with its metadata. struct ModelConfig: Identifiable, Hashable { - enum LoaderKind: Hashable { + enum LoaderKind: String, CaseIterable, Codable, Hashable, Sendable { case llm case vlm + + var displayName: String { + switch self { + case .llm: + return "Text" + case .vlm: + return "Vision" + } + } } let id: String // alias: "gemma", "gemma3n", "qwen" let repoId: String // HuggingFace ID let displayName: String let contextLength: Int - let loaderKind: LoaderKind + let loaderKinds: [LoaderKind] let supportsImages: Bool let supportsTools: Bool let defaultGenerationSettings: GenerationSettings + let isCurated: Bool + let localSizeBytes: Int64? - /// All models supported by the app. - static let availableModels: [ModelConfig] = [ + init( + id: String, + repoId: String, + displayName: String, + contextLength: Int, + loaderKinds: [LoaderKind], + supportsImages: Bool, + supportsTools: Bool, + defaultGenerationSettings: GenerationSettings, + isCurated: Bool = true, + localSizeBytes: Int64? = nil + ) { + self.id = id + self.repoId = repoId + self.displayName = displayName + self.contextLength = contextLength + self.loaderKinds = loaderKinds + self.supportsImages = supportsImages + self.supportsTools = supportsTools + self.defaultGenerationSettings = defaultGenerationSettings + self.isCurated = isCurated + self.localSizeBytes = localSizeBytes + } + + /// Curated models supported and tuned by the app. + static let curatedModels: [ModelConfig] = [ ModelConfig( id: "gemma", repoId: "mlx-community/gemma-3-4b-it-4bit", displayName: "Gemma 3 4B", contextLength: 128_000, - loaderKind: .vlm, + loaderKinds: [.vlm], supportsImages: true, supportsTools: true, defaultGenerationSettings: .technicalDefault @@ -34,7 +85,7 @@ struct ModelConfig: Identifiable, Hashable { repoId: "mlx-community/Qwen3.5-4B-MLX-4bit", displayName: "Qwen3.5 4B", contextLength: 256_000, - loaderKind: .vlm, + loaderKinds: [.vlm], supportsImages: true, supportsTools: true, defaultGenerationSettings: .technicalDefault @@ -44,7 +95,7 @@ struct ModelConfig: Identifiable, Hashable { repoId: "mlx-community/Qwen3.5-0.8B-4bit", displayName: "Qwen3.5 0.8B", contextLength: 256_000, - loaderKind: .vlm, + loaderKinds: [.vlm], supportsImages: true, supportsTools: true, defaultGenerationSettings: .technicalDefault @@ -54,7 +105,7 @@ struct ModelConfig: Identifiable, Hashable { repoId: "mlx-community/Qwen3.5-9B-4bit", displayName: "Qwen3.5 9B", contextLength: 256_000, - loaderKind: .vlm, + loaderKinds: [.vlm], supportsImages: true, supportsTools: true, defaultGenerationSettings: .technicalDefault @@ -64,7 +115,7 @@ struct ModelConfig: Identifiable, Hashable { repoId: "synk/L3-8B-Stheno-v3.2-MLX", displayName: "Stheno L3 8B", contextLength: 8_192, - loaderKind: .llm, + loaderKinds: [.llm], supportsImages: false, supportsTools: false, defaultGenerationSettings: .roleplayDefault @@ -74,18 +125,35 @@ struct ModelConfig: Identifiable, Hashable { repoId: "hobaratio/MN-Violet-Lotus-12B-mlx-4Bit", displayName: "Violet Lotus 12B", contextLength: 32_768, - loaderKind: .llm, + loaderKinds: [.llm], supportsImages: false, supportsTools: false, defaultGenerationSettings: .roleplayDefault ), ] - static let `default` = availableModels[0] + static var availableModels: [ModelConfig] { + mergedModels(localModels: LocalModelResolver.discoveredLocalModels()) + } + + static let `default` = curatedModels[0] /// Whether this model is cached locally (no download needed). var isLocal: Bool { - LocalModelResolver.isAvailable(repoId: repoId) + localSizeBytes != nil || LocalModelResolver.isAvailable(repoId: repoId) + } + + var primaryLoaderKind: LoaderKind { + loaderKinds.first ?? .llm + } + + var metadataOverrideValue: ModelMetadataOverride { + ModelMetadataOverride( + contextLength: contextLength, + primaryLoaderKind: primaryLoaderKind, + supportsImages: supportsImages, + supportsTools: supportsTools + ) } /// Build a ModelConfiguration for mlx-swift-lm from this config. @@ -96,6 +164,9 @@ struct ModelConfig: Identifiable, Hashable { /// Resolve a model string (alias, full repo ID, or partial match) to a ModelConfig. /// Mirrors the Python server's `ModelManager.resolve_model()`. static func resolve(_ requested: String) -> ModelConfig? { + let requested = requested.trimmingCharacters(in: .whitespacesAndNewlines) + guard !requested.isEmpty else { return nil } + // Exact alias match if let config = availableModels.first(where: { $0.id == requested }) { return config @@ -108,6 +179,129 @@ struct ModelConfig: Identifiable, Hashable { if let config = availableModels.first(where: { requested.contains($0.id) || $0.repoId.contains(requested) || requested.contains($0.repoId) }) { return config } + if requested.contains("/") { + return remoteCustom(repoId: requested) + } return nil } + + static func mergedModels( + localModels: [LocalModelResolver.LocalModelInfo], + applyingOverrides: Bool = true + ) -> [ModelConfig] { + let localByRepo = Dictionary(uniqueKeysWithValues: localModels.map { ($0.repoId, $0) }) + let curatedRepoIds = Set(curatedModels.map(\.repoId)) + + let curated = curatedModels.map { config in + if let local = localByRepo[config.repoId] { + return applyingOverrides ? applyMetadataOverrideIfNeeded(to: config.withLocalSize(local.sizeBytes)) : config.withLocalSize(local.sizeBytes) + } + return applyingOverrides ? applyMetadataOverrideIfNeeded(to: config) : config + } + + let discoveredCustom = localModels + .filter { !curatedRepoIds.contains($0.repoId) } + .map(customLocal) + .sorted { lhs, rhs in + lhs.displayName.localizedCaseInsensitiveCompare(rhs.displayName) == .orderedAscending + } + + return curated + discoveredCustom + } + + static func baselineModel( + forRepoId repoId: String, + localModels: [LocalModelResolver.LocalModelInfo] + ) -> ModelConfig? { + mergedModels(localModels: localModels, applyingOverrides: false) + .first(where: { $0.repoId == repoId || $0.id == repoId }) + ?? (repoId.contains("/") ? remoteCustom(repoId: repoId) : nil) + } + + static func remoteCustom(repoId: String) -> ModelConfig { + let supportsImages = inferredVisionSupport(repoId: repoId) + return applyMetadataOverrideIfNeeded(to: ModelConfig( + id: repoId, + repoId: repoId, + displayName: displayName(for: repoId), + contextLength: 0, + loaderKinds: supportsImages ? [.vlm, .llm] : [.llm, .vlm], + supportsImages: supportsImages, + supportsTools: inferredToolSupport(repoId: repoId), + defaultGenerationSettings: .generalDefault, + isCurated: false + )) + } + + static func displayName(for repoId: String) -> String { + let raw = repoId.split(separator: "/").last.map(String.init) ?? repoId + return raw + .replacingOccurrences(of: "-", with: " ") + .replacingOccurrences(of: "_", with: " ") + } + + private static func customLocal(_ local: LocalModelResolver.LocalModelInfo) -> ModelConfig { + applyMetadataOverrideIfNeeded(to: ModelConfig( + id: local.repoId, + repoId: local.repoId, + displayName: displayName(for: local.repoId), + contextLength: local.contextLength, + loaderKinds: local.loaderKinds, + supportsImages: local.supportsImages, + supportsTools: inferredToolSupport(repoId: local.repoId), + defaultGenerationSettings: .generalDefault, + isCurated: false, + localSizeBytes: local.sizeBytes + )) + } + + private static func inferredToolSupport(repoId: String) -> Bool { + let normalized = repoId.lowercased() + return normalized.contains("qwen") || normalized.contains("gemma") + } + + private static func inferredVisionSupport(repoId: String) -> Bool { + let normalized = repoId.lowercased() + return normalized.contains("vision") || normalized.contains("vl") || normalized.contains("gemma-3") || normalized.contains("qwen") + } + + private func withLocalSize(_ sizeBytes: Int64) -> ModelConfig { + ModelConfig( + id: id, + repoId: repoId, + displayName: displayName, + contextLength: contextLength, + loaderKinds: loaderKinds, + supportsImages: supportsImages, + supportsTools: supportsTools, + defaultGenerationSettings: defaultGenerationSettings, + isCurated: isCurated, + localSizeBytes: sizeBytes + ) + } + + private func applyingMetadataOverride(_ override: ModelMetadataOverride) -> ModelConfig { + let normalized = override.normalized() + let reorderedLoaderKinds = [normalized.primaryLoaderKind] + LoaderKind.allCases.filter { $0 != normalized.primaryLoaderKind } + + return ModelConfig( + id: id, + repoId: repoId, + displayName: displayName, + contextLength: normalized.contextLength, + loaderKinds: reorderedLoaderKinds, + supportsImages: normalized.supportsImages, + supportsTools: normalized.supportsTools, + defaultGenerationSettings: defaultGenerationSettings, + isCurated: isCurated, + localSizeBytes: localSizeBytes + ) + } + + private static func applyMetadataOverrideIfNeeded(to config: ModelConfig) -> ModelConfig { + guard let override = Preferences.modelMetadataOverride(forRepoId: config.repoId) else { + return config + } + return config.applyingMetadataOverride(override) + } } diff --git a/MLXServer/Utilities/LocalModelResolver.swift b/MLXServer/Utilities/LocalModelResolver.swift index 2ce904b..cdcb6b8 100644 --- a/MLXServer/Utilities/LocalModelResolver.swift +++ b/MLXServer/Utilities/LocalModelResolver.swift @@ -6,6 +6,17 @@ import Foundation /// ~/Library/Containers/de.rfc1437.mlxserver/Data/Library/Caches/models/{org}/{name}/ enum LocalModelResolver { + struct LocalModelInfo: Identifiable, Hashable { + let repoId: String + let directory: URL + let sizeBytes: Int64 + let contextLength: Int + let loaderKinds: [ModelConfig.LoaderKind] + let supportsImages: Bool + + var id: String { repoId } + } + /// Base directory where HubApi stores downloaded models. private static let modelsBase: URL? = { FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first? @@ -31,6 +42,46 @@ enum LocalModelResolver { resolve(repoId: repoId) != nil } + static func discoveredLocalModels() -> [LocalModelInfo] { + guard let base = modelsBase else { return [] } + return discoverModels(in: base) + } + + static func discoverModels(in base: URL) -> [LocalModelInfo] { + let fileManager = FileManager.default + let directoryKeys: Set = [.isDirectoryKey] + guard let ownerDirectories = try? fileManager.contentsOfDirectory( + at: base, + includingPropertiesForKeys: Array(directoryKeys), + options: [.skipsHiddenFiles] + ) else { + return [] + } + + var discovered: [LocalModelInfo] = [] + + for ownerDirectory in ownerDirectories { + guard isDirectory(ownerDirectory) else { continue } + guard let repoDirectories = try? fileManager.contentsOfDirectory( + at: ownerDirectory, + includingPropertiesForKeys: Array(directoryKeys), + options: [.skipsHiddenFiles] + ) else { + continue + } + + for repoDirectory in repoDirectories where isDirectory(repoDirectory) { + if let info = localModelInfo(ownerDirectory: ownerDirectory, repoDirectory: repoDirectory) { + discovered.append(info) + } + } + } + + return discovered.sorted { + $0.repoId.localizedCaseInsensitiveCompare($1.repoId) == .orderedAscending + } + } + /// Delete the local cache for a model so it will be re-downloaded next time. @discardableResult static func deleteLocal(repoId: String) -> Bool { @@ -46,4 +97,122 @@ enum LocalModelResolver { return false } } + + private static func localModelInfo(ownerDirectory: URL, repoDirectory: URL) -> LocalModelInfo? { + let repoId = "\(ownerDirectory.lastPathComponent)/\(repoDirectory.lastPathComponent)" + guard containsModelArtifacts(at: repoDirectory) else { return nil } + + let config = readJSONObject(at: repoDirectory.appendingPathComponent("config.json")) + let tokenizerConfig = readJSONObject(at: repoDirectory.appendingPathComponent("tokenizer_config.json")) + let supportsImages = inferredSupportsImages( + repoDirectory: repoDirectory, + config: config, + tokenizerConfig: tokenizerConfig + ) + let sizeBytes = directorySize(at: repoDirectory) + let contextLength = inferredContextLength(config: config, tokenizerConfig: tokenizerConfig) + let loaderKinds: [ModelConfig.LoaderKind] = supportsImages ? [.vlm, .llm] : [.llm, .vlm] + + return LocalModelInfo( + repoId: repoId, + directory: repoDirectory, + sizeBytes: sizeBytes, + contextLength: contextLength, + loaderKinds: loaderKinds, + supportsImages: supportsImages + ) + } + + private static func containsModelArtifacts(at directory: URL) -> Bool { + let requiredPaths = [ + directory.appendingPathComponent("config.json").path, + directory.appendingPathComponent("model.safetensors").path, + directory.appendingPathComponent("model.safetensors.index.json").path, + ] + return requiredPaths.contains { FileManager.default.fileExists(atPath: $0) } + } + + private static func isDirectory(_ url: URL) -> Bool { + (try? url.resourceValues(forKeys: [.isDirectoryKey]).isDirectory) == true + } + + private static func readJSONObject(at url: URL) -> [String: Any]? { + guard let data = try? Data(contentsOf: url) else { return nil } + return (try? JSONSerialization.jsonObject(with: data)) as? [String: Any] + } + + private static func inferredSupportsImages( + repoDirectory: URL, + config: [String: Any]?, + tokenizerConfig: [String: Any]? + ) -> Bool { + if config?["vision_config"] != nil { + return true + } + if tokenizerConfig?["image_token"] != nil { + return true + } + + let metadataFiles = [ + "processor_config.json", + "preprocessor_config.json", + "video_preprocessor_config.json", + ] + return metadataFiles.contains { + FileManager.default.fileExists(atPath: repoDirectory.appendingPathComponent($0).path) + } + } + + private static func inferredContextLength( + config: [String: Any]?, + tokenizerConfig: [String: Any]? + ) -> Int { + if let value = integerValue(at: ["text_config", "max_position_embeddings"], in: config) { + return value + } + if let value = integerValue(at: ["max_position_embeddings"], in: config) { + return value + } + if let value = integerValue(at: ["model_max_length"], in: tokenizerConfig) { + return value + } + return 0 + } + + private static func integerValue(at path: [String], in json: [String: Any]?) -> Int? { + guard let json else { return nil } + + var current: Any = json + for component in path { + guard let dictionary = current as? [String: Any], let next = dictionary[component] else { + return nil + } + current = next + } + + if let number = current as? NSNumber { + return number.intValue + } + return current as? Int + } + + private static func directorySize(at directory: URL) -> Int64 { + let keys: [URLResourceKey] = [.isRegularFileKey, .fileSizeKey] + guard let enumerator = FileManager.default.enumerator( + at: directory, + includingPropertiesForKeys: keys, + options: [.skipsHiddenFiles] + ) else { + return 0 + } + + var total: Int64 = 0 + for case let fileURL as URL in enumerator { + guard let values = try? fileURL.resourceValues(forKeys: Set(keys)), values.isRegularFile == true else { + continue + } + total += Int64(values.fileSize ?? 0) + } + return total + } } diff --git a/MLXServer/Utilities/Preferences.swift b/MLXServer/Utilities/Preferences.swift index e894788..05b2346 100644 --- a/MLXServer/Utilities/Preferences.swift +++ b/MLXServer/Utilities/Preferences.swift @@ -7,6 +7,7 @@ enum Preferences { private static let jsonEncoder = JSONEncoder() private static let jsonDecoder = JSONDecoder() private static let legacyThinkingDefault = true + private static let modelMetadataOverridesKey = "modelMetadataOverrides" // MARK: - Last used model @@ -118,6 +119,26 @@ enum Preferences { modelGenerationSettingsMap[modelId] != nil } + static func modelMetadataOverride(forRepoId repoId: String) -> ModelMetadataOverride? { + modelMetadataOverridesMap[repoId]?.normalized() + } + + static func setModelMetadataOverride(_ override: ModelMetadataOverride, forRepoId repoId: String) { + var map = modelMetadataOverridesMap + map[repoId] = override.normalized() + modelMetadataOverridesMap = map + } + + static func removeModelMetadataOverride(forRepoId repoId: String) { + var map = modelMetadataOverridesMap + map.removeValue(forKey: repoId) + modelMetadataOverridesMap = map + } + + static func hasModelMetadataOverride(forRepoId repoId: String) -> Bool { + modelMetadataOverridesMap[repoId] != nil + } + private static var modelGenerationSettingsMap: [String: GenerationSettings] { get { guard let data = defaults.data(forKey: modelGenerationSettingsKey) else { return [:] } @@ -129,6 +150,17 @@ enum Preferences { } } + private static var modelMetadataOverridesMap: [String: ModelMetadataOverride] { + get { + guard let data = defaults.data(forKey: modelMetadataOverridesKey) else { return [:] } + return (try? jsonDecoder.decode([String: ModelMetadataOverride].self, from: data)) ?? [:] + } + set { + guard let data = try? jsonEncoder.encode(newValue) else { return } + defaults.set(data, forKey: modelMetadataOverridesKey) + } + } + // MARK: - Idle unload private static let idleUnloadMinutesKey = "idleUnloadMinutes" diff --git a/MLXServer/ViewModels/ChatViewModel.swift b/MLXServer/ViewModels/ChatViewModel.swift index 3e307f6..48ff9f9 100644 --- a/MLXServer/ViewModels/ChatViewModel.swift +++ b/MLXServer/ViewModels/ChatViewModel.swift @@ -559,7 +559,7 @@ final class ChatViewModel { if modelManager.currentModel == nil { let modelId = Preferences.defaultModelId ?? Preferences.lastModelId ?? ModelConfig.default.id - if let config = ModelConfig.availableModels.first(where: { $0.id == modelId }) { + if let config = ModelConfig.resolve(modelId) { await modelManager.loadModel(config) } } diff --git a/MLXServer/ViewModels/ModelManager.swift b/MLXServer/ViewModels/ModelManager.swift index 1950237..db0ca38 100644 --- a/MLXServer/ViewModels/ModelManager.swift +++ b/MLXServer/ViewModels/ModelManager.swift @@ -19,7 +19,10 @@ final class ModelManager { let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first return HubApi(downloadBase: cachesDir, cache: nil) }() + var currentModel: ModelConfig? + var availableModels: [ModelConfig] + private(set) var discoveredLocalModels: [LocalModelResolver.LocalModelInfo] = [] var modelContainer: ModelContainer? var isLoading = false var downloadProgress: Double = 0 @@ -36,6 +39,50 @@ final class ModelManager { private(set) var lastUsed: Date? private var latestLoadRequestID = UUID() + init() { + availableModels = [] + refreshAvailableModels() + } + + var curatedModels: [ModelConfig] { + availableModels.filter(\.isCurated) + } + + var localModelsOnDisk: [ModelConfig] { + availableModels + .filter(\.isLocal) + .sorted { + $0.displayName.localizedCaseInsensitiveCompare($1.displayName) == .orderedAscending + } + } + + func refreshAvailableModels() { + discoveredLocalModels = LocalModelResolver.discoveredLocalModels() + availableModels = ModelConfig.mergedModels(localModels: discoveredLocalModels) + + if let currentModel { + self.currentModel = availableModels.first(where: { $0.repoId == currentModel.repoId }) ?? currentModel + } + } + + func discoveredLocalModelInfo(repoId: String) -> LocalModelResolver.LocalModelInfo? { + discoveredLocalModels.first(where: { $0.repoId == repoId }) + } + + func baselineModel(repoId: String) -> ModelConfig? { + ModelConfig.baselineModel(forRepoId: repoId, localModels: discoveredLocalModels) + } + + func saveMetadataOverride(_ override: ModelMetadataOverride, for config: ModelConfig) { + Preferences.setModelMetadataOverride(override, forRepoId: config.repoId) + refreshAvailableModels() + } + + func clearMetadataOverride(for config: ModelConfig) { + Preferences.removeModelMetadataOverride(forRepoId: config.repoId) + refreshAvailableModels() + } + private func clearLoadedState() { idleTimer?.invalidate() idleTimer = nil @@ -55,7 +102,11 @@ final class ModelManager { /// Prefers the local snapshot from ~/.cache/huggingface/hub/ (shared with the Python server). /// Only downloads if the model isn't cached locally. func loadModel(_ config: ModelConfig) async { - if currentModel?.id == config.id && modelContainer != nil { + refreshAvailableModels() + let effectiveConfig = availableModels.first(where: { $0.repoId == config.repoId }) ?? config + + if currentModel?.repoId == effectiveConfig.repoId && modelContainer != nil { + currentModel = effectiveConfig return // already loaded } @@ -65,10 +116,10 @@ final class ModelManager { MLX.GPU.clearCache() isLoading = true downloadProgress = 0 - loadingModelName = config.displayName + loadingModelName = effectiveConfig.displayName errorMessage = nil - let needsDownload = !config.isLocal + let needsDownload = !effectiveConfig.isLocal if needsDownload { isDownloading = true downloadFilesTotal = 0 @@ -91,32 +142,23 @@ final class ModelManager { } let configuration: ModelConfiguration - if let localDir = LocalModelResolver.resolve(repoId: config.repoId) { + if let localDir = LocalModelResolver.resolve(repoId: effectiveConfig.repoId) { configuration = ModelConfiguration(directory: localDir) } else { - configuration = config.modelConfiguration + configuration = effectiveConfig.modelConfiguration } - let container: ModelContainer - switch config.loaderKind { - case .llm: - container = try await LLMModelFactory.shared.loadContainer( - hub: Self.hub, - configuration: configuration, - progressHandler: progressHandler - ) - case .vlm: - container = try await VLMModelFactory.shared.loadContainer( - hub: Self.hub, - configuration: configuration, - progressHandler: progressHandler - ) - } + let container = try await Self.loadContainer( + for: effectiveConfig, + configuration: configuration, + progressHandler: progressHandler + ) guard latestLoadRequestID == requestID else { return } + refreshAvailableModels() self.isDownloading = false self.modelContainer = container - self.currentModel = config + self.currentModel = self.availableModels.first(where: { $0.repoId == effectiveConfig.repoId }) ?? effectiveConfig touchActivity() } catch { guard latestLoadRequestID == requestID else { return } @@ -135,6 +177,25 @@ final class ModelManager { await loadModel(config) } + func addModel(repoId: String) async { + let repoId = repoId.trimmingCharacters(in: .whitespacesAndNewlines) + guard !repoId.isEmpty else { + errorMessage = "Enter a HuggingFace model ID." + return + } + + let config = ModelConfig.resolve(repoId) ?? ModelConfig.remoteCustom(repoId: repoId) + await loadModel(config) + } + + func deleteModel(_ config: ModelConfig) { + if currentModel?.repoId == config.repoId { + unloadModel() + } + _ = LocalModelResolver.deleteLocal(repoId: config.repoId) + refreshAvailableModels() + } + /// Unload the current model and free GPU memory. func unloadModel() { latestLoadRequestID = UUID() @@ -161,4 +222,35 @@ final class ModelManager { var isReady: Bool { modelContainer != nil && !isLoading } + + private static func loadContainer( + for config: ModelConfig, + configuration: ModelConfiguration, + progressHandler: @escaping @Sendable (Progress) -> Void + ) async throws -> ModelContainer { + var lastError: Error? + + for loaderKind in config.loaderKinds { + do { + switch loaderKind { + case .llm: + return try await LLMModelFactory.shared.loadContainer( + hub: Self.hub, + configuration: configuration, + progressHandler: progressHandler + ) + case .vlm: + return try await VLMModelFactory.shared.loadContainer( + hub: Self.hub, + configuration: configuration, + progressHandler: progressHandler + ) + } + } catch { + lastError = error + } + } + + throw lastError ?? NSError(domain: "ModelManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "Unsupported model configuration"]) + } } diff --git a/MLXServer/Views/ModelManagementView.swift b/MLXServer/Views/ModelManagementView.swift new file mode 100644 index 0000000..dd651d3 --- /dev/null +++ b/MLXServer/Views/ModelManagementView.swift @@ -0,0 +1,431 @@ +import SwiftUI + +struct ModelManagementView: View { + @Environment(ModelManager.self) private var modelManager + + @State private var newRepoId = "" + @State private var pendingDelete: ModelConfig? + @State private var editingMetadataModel: ModelConfig? + @FocusState private var isRepoIdFieldFocused: Bool + + private let sizeFormatter: ByteCountFormatter = { + let formatter = ByteCountFormatter() + formatter.allowedUnits = [.useGB, .useMB, .useKB] + formatter.countStyle = .file + formatter.includesUnit = true + formatter.isAdaptive = true + return formatter + }() + + var body: some View { + ScrollView { + VStack(alignment: .leading, spacing: 18) { + GroupBox("Add Model") { + VStack(alignment: .leading, spacing: 10) { + Text("Enter a HuggingFace model ID. The app will download it, load it once, and then keep it available in the regular model picker.") + .font(.caption) + .foregroundStyle(.secondary) + + HStack { + TextField("owner/repo", text: $newRepoId) + .textFieldStyle(.roundedBorder) + .focused($isRepoIdFieldFocused) + .onSubmit { + downloadEnteredModel() + } + + Button("Download & Select") { + downloadEnteredModel() + } + .disabled(modelManager.isLoading || newRepoId.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty) + } + } + } + + GroupBox("Recommended Defaults") { + VStack(spacing: 0) { + ForEach(modelManager.curatedModels) { model in + curatedRow(model) + if model.id != modelManager.curatedModels.last?.id { + Divider() + } + } + } + } + + GroupBox("Models On Disk") { + if modelManager.localModelsOnDisk.isEmpty { + ContentUnavailableView( + "No Local Models", + systemImage: "externaldrive", + description: Text("Downloaded models will appear here with their summed file sizes.") + ) + .frame(maxWidth: .infinity) + .padding(.vertical, 20) + } else { + VStack(spacing: 0) { + ForEach(modelManager.localModelsOnDisk) { model in + localRow(model) + if model.id != modelManager.localModelsOnDisk.last?.id { + Divider() + } + } + } + } + } + } + .padding(20) + } + .navigationTitle("Models") + .frame(minWidth: 760, minHeight: 520) + .sheet(item: $editingMetadataModel) { model in + ModelMetadataEditorView( + model: model, + baselineModel: modelManager.baselineModel(repoId: model.repoId) ?? model, + detectedLocalModel: modelManager.discoveredLocalModelInfo(repoId: model.repoId), + hasSavedOverride: Preferences.hasModelMetadataOverride(forRepoId: model.repoId), + onSave: { override in + modelManager.saveMetadataOverride(override, for: model) + }, + onReset: { + modelManager.clearMetadataOverride(for: model) + } + ) + } + .alert( + "Delete Local Model?", + isPresented: Binding( + get: { pendingDelete != nil }, + set: { if !$0 { pendingDelete = nil } } + ) + ) { + Button("Delete", role: .destructive) { + if let pendingDelete { + modelManager.deleteModel(pendingDelete) + } + self.pendingDelete = nil + } + Button("Cancel", role: .cancel) { + pendingDelete = nil + } + } message: { + if let pendingDelete { + Text("This removes the local files for \(pendingDelete.repoId).") + } + } + .onAppear { + modelManager.refreshAvailableModels() + if newRepoId.isEmpty { + isRepoIdFieldFocused = true + } + } + } + + @ViewBuilder + private func curatedRow(_ model: ModelConfig) -> some View { + HStack(alignment: .top, spacing: 14) { + VStack(alignment: .leading, spacing: 4) { + HStack(spacing: 8) { + Text(model.displayName) + .font(.headline) + if modelManager.currentModel?.repoId == model.repoId { + Text("Loaded") + .font(.caption.weight(.semibold)) + .padding(.horizontal, 8) + .padding(.vertical, 3) + .background(.green.opacity(0.15), in: Capsule()) + } + } + + Text(model.repoId) + .font(.caption) + .foregroundStyle(.secondary) + } + + Spacer() + + Label( + model.isLocal ? "On Disk" : "Not Downloaded", + systemImage: model.isLocal ? "checkmark.circle.fill" : "arrow.down.circle" + ) + .font(.caption) + .foregroundStyle(model.isLocal ? .green : .secondary) + + Button(model.isLocal ? "Load" : "Download") { + Task { + await modelManager.loadModel(model) + } + } + .disabled(modelManager.isLoading) + + Button("Metadata…") { + editingMetadataModel = model + } + } + .padding(.vertical, 10) + } + + @ViewBuilder + private func localRow(_ model: ModelConfig) -> some View { + HStack(alignment: .top, spacing: 14) { + VStack(alignment: .leading, spacing: 4) { + HStack(spacing: 8) { + Text(model.displayName) + .font(.headline) + if !model.isCurated { + Text("Custom") + .font(.caption.weight(.semibold)) + .padding(.horizontal, 8) + .padding(.vertical, 3) + .background(.secondary.opacity(0.14), in: Capsule()) + } + if modelManager.currentModel?.repoId == model.repoId { + Text("Loaded") + .font(.caption.weight(.semibold)) + .padding(.horizontal, 8) + .padding(.vertical, 3) + .background(.green.opacity(0.15), in: Capsule()) + } + } + + Text(model.repoId) + .font(.caption) + .foregroundStyle(.secondary) + } + + Spacer() + + if let localSizeBytes = model.localSizeBytes { + Text(sizeFormatter.string(fromByteCount: localSizeBytes)) + .font(.caption.monospacedDigit()) + .foregroundStyle(.secondary) + .frame(width: 90, alignment: .trailing) + } + + Button("Load") { + Task { + await modelManager.loadModel(model) + } + } + .disabled(modelManager.isLoading) + + Button("Metadata…") { + editingMetadataModel = model + } + + Button("Delete", role: .destructive) { + pendingDelete = model + } + .disabled(modelManager.isLoading) + } + .padding(.vertical, 10) + } + + private func downloadEnteredModel() { + let repoId = newRepoId.trimmingCharacters(in: .whitespacesAndNewlines) + guard !repoId.isEmpty else { return } + + Task { + await modelManager.addModel(repoId: repoId) + if modelManager.errorMessage == nil { + newRepoId = "" + } + } + } +} + +private struct ModelMetadataEditorView: View { + @Environment(\.dismiss) private var dismiss + + let model: ModelConfig + let baselineModel: ModelConfig + let detectedLocalModel: LocalModelResolver.LocalModelInfo? + let hasSavedOverride: Bool + let onSave: (ModelMetadataOverride) -> Void + let onReset: () -> Void + + @State private var contextLengthText: String + @State private var primaryLoaderKind: ModelConfig.LoaderKind + @State private var supportsImages: Bool + @State private var supportsTools: Bool + + init( + model: ModelConfig, + baselineModel: ModelConfig, + detectedLocalModel: LocalModelResolver.LocalModelInfo?, + hasSavedOverride: Bool, + onSave: @escaping (ModelMetadataOverride) -> Void, + onReset: @escaping () -> Void + ) { + self.model = model + self.baselineModel = baselineModel + self.detectedLocalModel = detectedLocalModel + self.hasSavedOverride = hasSavedOverride + self.onSave = onSave + self.onReset = onReset + _contextLengthText = State(initialValue: String(model.contextLength)) + _primaryLoaderKind = State(initialValue: model.primaryLoaderKind) + _supportsImages = State(initialValue: model.supportsImages) + _supportsTools = State(initialValue: model.supportsTools) + } + + var body: some View { + NavigationStack { + Form { + Section("Metadata") { + TextField("Context length", text: $contextLengthText) + .textFieldStyle(.roundedBorder) + + Picker("Primary loader", selection: $primaryLoaderKind) { + ForEach(ModelConfig.LoaderKind.allCases, id: \.self) { loaderKind in + Text(loaderKind.displayName).tag(loaderKind) + } + } + + Toggle("Supports images", isOn: $supportsImages) + Toggle("Supports tools", isOn: $supportsTools) + } + + Section("Comparison") { + Text(defaultsSummary) + .foregroundStyle(.secondary) + + Grid(alignment: .leading, horizontalSpacing: 16, verticalSpacing: 8) { + GridRow { + Text("") + Text("Effective") + .font(.caption.weight(.semibold)) + .foregroundStyle(.secondary) + Text(baselineHeading) + .font(.caption.weight(.semibold)) + .foregroundStyle(.secondary) + } + + comparisonRow( + label: "Context", + effective: currentOverride?.contextLength.description ?? "Invalid", + baseline: baselineModel.contextLength > 0 ? "\(baselineModel.contextLength)" : "Unknown" + ) + comparisonRow( + label: "Loader", + effective: primaryLoaderKind.displayName, + baseline: baselineModel.primaryLoaderKind.displayName + ) + comparisonRow( + label: "Images", + effective: yesNo(supportsImages), + baseline: yesNo(baselineModel.supportsImages) + ) + comparisonRow( + label: "Tools", + effective: yesNo(supportsTools), + baseline: yesNo(baselineModel.supportsTools) + ) + } + } + + if let detectedLocalModel { + Section("Discovered Source") { + LabeledContent("Detected context") { + Text(detectedLocalModel.contextLength > 0 ? "\(detectedLocalModel.contextLength)" : "Unknown") + } + LabeledContent("Detected loader order") { + Text(detectedLocalModel.loaderKinds.map(\.displayName).joined(separator: ", ")) + } + LabeledContent("Detected vision") { + Text(yesNo(detectedLocalModel.supportsImages)) + } + } + } + } + .formStyle(.grouped) + .navigationTitle(model.displayName) + .frame(minWidth: 520, minHeight: 380) + .toolbar { + ToolbarItem(placement: .cancellationAction) { + Button("Cancel") { + dismiss() + } + } + + ToolbarItem(placement: .primaryAction) { + Button("Save") { + guard let currentOverride else { return } + onSave(currentOverride) + dismiss() + } + .disabled(currentOverride == nil) + } + + if hasSavedOverride { + ToolbarItem(placement: .automatic) { + Button("Reset to Detected") { + onReset() + dismiss() + } + } + } + } + } + } + + private var currentOverride: ModelMetadataOverride? { + guard let contextLength = Int(contextLengthText.trimmingCharacters(in: .whitespacesAndNewlines)), contextLength >= 0 else { + return nil + } + + return ModelMetadataOverride( + contextLength: contextLength, + primaryLoaderKind: primaryLoaderKind, + supportsImages: supportsImages, + supportsTools: supportsTools + ) + } + + private var defaultsSummary: String { + if detectedLocalModel != nil { + if hasSavedOverride { + return "The editable fields show the effective overridden metadata. The comparison column shows the discovered baseline from the local model files." + } + return "The editable fields currently match the discovered baseline from the local model files. Save to store an override for this repo ID." + } + + if model.isCurated { + return hasSavedOverride + ? "The editable fields show the effective overridden metadata. The comparison column shows the curated built-in baseline." + : "The editable fields currently match the curated built-in baseline. Save to store an override for this repo ID." + } + + if hasSavedOverride { + return "The editable fields show the effective overridden metadata. The comparison column shows the inferred baseline for this repo ID." + } + + return "The editable fields currently match the inferred baseline for this repo ID. Save to store an override." + } + + private var baselineHeading: String { + if detectedLocalModel != nil { + return "Detected" + } + if model.isCurated { + return "Built-in" + } + return "Inferred" + } + + @ViewBuilder + private func comparisonRow(label: String, effective: String, baseline: String) -> some View { + GridRow { + Text(label) + Text(effective) + .monospaced() + Text(baseline) + .foregroundStyle(.secondary) + .monospaced() + } + } + + private func yesNo(_ value: Bool) -> String { + value ? "Yes" : "No" + } +} \ No newline at end of file diff --git a/MLXServer/Views/ModelManagementWindow.swift b/MLXServer/Views/ModelManagementWindow.swift new file mode 100644 index 0000000..18d8bd6 --- /dev/null +++ b/MLXServer/Views/ModelManagementWindow.swift @@ -0,0 +1,5 @@ +import Foundation + +enum ModelManagementWindow { + static let windowID = "model-manager" +} \ No newline at end of file diff --git a/MLXServer/Views/ModelPickerView.swift b/MLXServer/Views/ModelPickerView.swift index bcc8ea4..f0c8771 100644 --- a/MLXServer/Views/ModelPickerView.swift +++ b/MLXServer/Views/ModelPickerView.swift @@ -7,14 +7,14 @@ struct ModelPickerView: View { var body: some View { HStack(spacing: 8) { Picker("Model", selection: selectedModelBinding) { - ForEach(ModelConfig.availableModels) { config in + ForEach(modelManager.availableModels) { config in Label( - config.displayName, + config.isCurated ? config.displayName : config.repoId, systemImage: config.isLocal ? "checkmark.circle.fill" : "arrow.down.circle" ).tag(config.id) } } - .frame(width: 160) + .frame(width: 260) .disabled(modelManager.isLoading) // Re-download button (visible when a model is loaded) @@ -50,9 +50,20 @@ struct ModelPickerView: View { private var selectedModelBinding: Binding { Binding( - get: { modelManager.currentModel?.id ?? ModelConfig.default.id }, + get: { + if let currentId = modelManager.currentModel?.id { + return currentId + } + if let defaultId = Preferences.defaultModelId, + let config = modelManager.availableModels.first(where: { $0.id == defaultId || $0.repoId == defaultId }) { + return config.id + } + return ModelConfig.default.id + }, set: { newId in - guard let config = ModelConfig.availableModels.first(where: { $0.id == newId }) else { return } + guard let config = modelManager.availableModels.first(where: { $0.id == newId }) ?? ModelConfig.resolve(newId) else { + return + } Task { await modelManager.loadModel(config) } diff --git a/MLXServer/Views/SceneManagementView.swift b/MLXServer/Views/SceneManagementView.swift index 1ea7747..8b65393 100644 --- a/MLXServer/Views/SceneManagementView.swift +++ b/MLXServer/Views/SceneManagementView.swift @@ -210,6 +210,7 @@ struct SceneManagementView: View { } private struct SceneEditorView: View { + @Environment(ModelManager.self) private var modelManager @Environment(SceneStore.self) private var sceneStore let scene: ChatScene @@ -221,7 +222,7 @@ private struct SceneEditorView: View { Picker("Model", selection: modelBinding) { Text("Current model").tag(Optional.none) - ForEach(ModelConfig.availableModels) { model in + ForEach(modelManager.availableModels) { model in Text(model.displayName).tag(Optional(model.id)) } } @@ -257,6 +258,9 @@ private struct SceneEditorView: View { } .formStyle(.grouped) .navigationTitle(scene.displayName) + .onAppear { + modelManager.refreshAvailableModels() + } } private var modelBinding: Binding { diff --git a/MLXServer/Views/SettingsView.swift b/MLXServer/Views/SettingsView.swift index f79609a..cb71930 100644 --- a/MLXServer/Views/SettingsView.swift +++ b/MLXServer/Views/SettingsView.swift @@ -2,6 +2,7 @@ import SwiftUI struct SettingsView: View { @Environment(\.openWindow) private var openWindow + @Environment(ModelManager.self) private var modelManager @Environment(SceneStore.self) private var sceneStore @State private var systemPrompt: String = Preferences.systemPrompt @State private var apiPort: String = String(Preferences.apiPort) @@ -29,7 +30,7 @@ struct SettingsView: View { Form { Section("Startup") { Picker("Default model", selection: $defaultModelId) { - ForEach(ModelConfig.availableModels) { model in + ForEach(modelManager.availableModels) { model in Text(model.displayName).tag(model.id) } } @@ -44,7 +45,7 @@ struct SettingsView: View { Section("Generation Defaults") { Picker("Defaults for model", selection: $generationDefaultsModelId) { - ForEach(ModelConfig.availableModels) { model in + ForEach(modelManager.availableModels) { model in Text(model.displayName).tag(model.id) } } @@ -164,6 +165,15 @@ struct SettingsView: View { } .formStyle(.grouped) .frame(width: 450, height: 650) + .onAppear { + modelManager.refreshAvailableModels() + if !modelManager.availableModels.contains(where: { $0.id == defaultModelId }) { + defaultModelId = modelManager.availableModels.first?.id ?? ModelConfig.default.id + } + if !modelManager.availableModels.contains(where: { $0.id == generationDefaultsModelId }) { + generationDefaultsModelId = defaultModelId + } + } } private var generationDefaultsBinding: Binding { diff --git a/MLXServerTests/Server/LocalModelResolverTests.swift b/MLXServerTests/Server/LocalModelResolverTests.swift new file mode 100644 index 0000000..44c3333 --- /dev/null +++ b/MLXServerTests/Server/LocalModelResolverTests.swift @@ -0,0 +1,180 @@ +import Foundation +import XCTest +@testable import MLX_Server + +final class LocalModelResolverTests: XCTestCase { + func testDiscoverModelsInfersTextOnlyMetadataAndDirectorySize() throws { + let base = try makeTempModelsRoot() + let repoDirectory = try makeRepoDirectory(base: base, owner: "example", repo: "text-only") + let configURL = repoDirectory.appendingPathComponent("config.json") + let modelURL = repoDirectory.appendingPathComponent("model.safetensors") + let tokenizerURL = repoDirectory.appendingPathComponent("tokenizer.json") + + try writeJSON( + [ + "architectures": ["LlamaForCausalLM"], + "max_position_embeddings": 32768, + ], + to: configURL + ) + try Data(repeating: 0x11, count: 64).write(to: modelURL) + try Data(repeating: 0x22, count: 19).write(to: tokenizerURL) + + let expectedSize = Int64( + try Data(contentsOf: configURL).count + + Data(contentsOf: modelURL).count + + Data(contentsOf: tokenizerURL).count + ) + + let discovered = LocalModelResolver.discoverModels(in: base) + let model = try XCTUnwrap(discovered.first) + + XCTAssertEqual(model.repoId, "example/text-only") + XCTAssertEqual(model.contextLength, 32768) + XCTAssertFalse(model.supportsImages) + XCTAssertEqual(model.loaderKinds, [.llm, .vlm]) + XCTAssertEqual(model.sizeBytes, expectedSize) + } + + func testDiscoverModelsInfersVisionMetadataFromProcessorFiles() throws { + let base = try makeTempModelsRoot() + let repoDirectory = try makeRepoDirectory(base: base, owner: "example", repo: "vision-model") + try writeJSON( + [ + "text_config": ["max_position_embeddings": 262144], + "vision_config": ["hidden_size": 768], + ], + to: repoDirectory.appendingPathComponent("config.json") + ) + try writeJSON(["processor_class": "Qwen3VLProcessor"], to: repoDirectory.appendingPathComponent("tokenizer_config.json")) + try Data(repeating: 0x33, count: 12).write(to: repoDirectory.appendingPathComponent("processor_config.json")) + try Data(repeating: 0x44, count: 8).write(to: repoDirectory.appendingPathComponent("model.safetensors.index.json")) + + let discovered = LocalModelResolver.discoverModels(in: base) + let model = try XCTUnwrap(discovered.first) + + XCTAssertEqual(model.repoId, "example/vision-model") + XCTAssertEqual(model.contextLength, 262144) + XCTAssertTrue(model.supportsImages) + XCTAssertEqual(model.loaderKinds, [.vlm, .llm]) + } + + func testMergedCatalogKeepsCuratedModelsAndAddsCustomLocalModels() { + let localModels = [ + LocalModelResolver.LocalModelInfo( + repoId: "mlx-community/gemma-3-4b-it-4bit", + directory: URL(fileURLWithPath: "/tmp/gemma"), + sizeBytes: 1024, + contextLength: 128000, + loaderKinds: [.vlm, .llm], + supportsImages: true + ), + LocalModelResolver.LocalModelInfo( + repoId: "custom-org/custom-model", + directory: URL(fileURLWithPath: "/tmp/custom"), + sizeBytes: 2048, + contextLength: 65536, + loaderKinds: [.llm, .vlm], + supportsImages: false + ), + ] + + let merged = ModelConfig.mergedModels(localModels: localModels) + let gemma = merged.first(where: { $0.id == "gemma" }) + let custom = merged.first(where: { $0.repoId == "custom-org/custom-model" }) + + XCTAssertEqual(gemma?.localSizeBytes, 1024) + XCTAssertEqual(custom?.id, "custom-org/custom-model") + XCTAssertEqual(custom?.contextLength, 65536) + XCTAssertFalse(custom?.isCurated ?? true) + } + + func testResolveUnknownRepoIdCreatesRemoteCustomConfig() throws { + let config = try XCTUnwrap(ModelConfig.resolve("custom-owner/custom-repo")) + + XCTAssertEqual(config.id, "custom-owner/custom-repo") + XCTAssertEqual(config.repoId, "custom-owner/custom-repo") + XCTAssertFalse(config.isCurated) + } + + func testMergedCatalogAppliesSavedMetadataOverride() { + let repoId = "custom-org/override-model" + Preferences.setModelMetadataOverride( + ModelMetadataOverride( + contextLength: 123456, + primaryLoaderKind: .vlm, + supportsImages: true, + supportsTools: true + ), + forRepoId: repoId + ) + defer { + Preferences.removeModelMetadataOverride(forRepoId: repoId) + } + + let localModels = [ + LocalModelResolver.LocalModelInfo( + repoId: repoId, + directory: URL(fileURLWithPath: "/tmp/custom-override"), + sizeBytes: 2048, + contextLength: 65536, + loaderKinds: [.llm, .vlm], + supportsImages: false + ), + ] + + let merged = ModelConfig.mergedModels(localModels: localModels) + let overridden = merged.first(where: { $0.repoId == repoId }) + + XCTAssertEqual(overridden?.contextLength, 123456) + XCTAssertEqual(overridden?.primaryLoaderKind, .vlm) + XCTAssertTrue(overridden?.supportsImages ?? false) + XCTAssertTrue(overridden?.supportsTools ?? false) + } + + func testResolveUnknownRepoIdUsesSavedMetadataOverride() throws { + let repoId = "custom-owner/custom-repo-with-override" + Preferences.setModelMetadataOverride( + ModelMetadataOverride( + contextLength: 8192, + primaryLoaderKind: .llm, + supportsImages: false, + supportsTools: true + ), + forRepoId: repoId + ) + defer { + Preferences.removeModelMetadataOverride(forRepoId: repoId) + } + + let config = try XCTUnwrap(ModelConfig.resolve(repoId)) + + XCTAssertEqual(config.contextLength, 8192) + XCTAssertEqual(config.primaryLoaderKind, .llm) + XCTAssertFalse(config.supportsImages) + XCTAssertTrue(config.supportsTools) + } + + private func makeTempModelsRoot() throws -> URL { + let root = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + try FileManager.default.createDirectory(at: root, withIntermediateDirectories: true) + addTeardownBlock { + try? FileManager.default.removeItem(at: root) + } + return root + } + + private func makeRepoDirectory(base: URL, owner: String, repo: String) throws -> URL { + let directory = base + .appendingPathComponent(owner, isDirectory: true) + .appendingPathComponent(repo, isDirectory: true) + try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true) + return directory + } + + private func writeJSON(_ object: Any, to url: URL) throws { + let data = try JSONSerialization.data(withJSONObject: object, options: [.prettyPrinted, .sortedKeys]) + try data.write(to: url) + } +} \ No newline at end of file diff --git a/README.md b/README.md index 2cf06ed..5fe6b51 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,8 @@ This is intended for targeted validation while keeping the normal default as the - **Chat interface** with markdown rendering and model-aware image attachments (file picker, drag & drop, clipboard paste, Finder copy-paste on vision-capable models) - **Scene-based chat starts** — New Chat opens a scene picker with Neutral plus saved scenes, each with an optional model override, a scene prompt layered onto the base system prompt, an auto-sent starter prompt, and optional generation-setting overrides for chat-specific behavior -- **Model picker** in toolbar with local/download status indicators and re-download button +- **Model picker** in toolbar with curated defaults plus any locally discovered MLX models on disk +- **Models window** in the menu for downloading a model by HuggingFace ID, inspecting on-disk model sizes, and deleting local model folders - **Download progress modal** — shows file progress, percentage, and speed when downloading a new model - **Thinking mode** — models like Qwen3.5 can reason internally before responding; thinking content appears in a collapsible box. Toggle on/off in Settings. - **Streaming responses** with live token display