feat: added model management
This commit is contained in:
@@ -23,18 +23,21 @@
|
||||
2E3A02DF9C6A5109E532D5E2 /* ChatDocumentController.swift in Sources */ = {isa = PBXBuildFile; fileRef = D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */; };
|
||||
3A9DB84947BBBBED06CF9E1E /* TestImageFixtures.swift in Sources */ = {isa = PBXBuildFile; fileRef = 31BD930DEC051408444C30D4 /* TestImageFixtures.swift */; };
|
||||
4158FA884D981D73288FB74C /* SaveChatCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2E2FCA55CEBEBCED78D9479A /* SaveChatCommands.swift */; };
|
||||
4B7449F57226CB48C4F5EEBD /* LocalModelResolverTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */; };
|
||||
4CB13DC1AC7A500DDBB443EC /* ChatInputView.swift in Sources */ = {isa = PBXBuildFile; fileRef = E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */; };
|
||||
4DC033E45880B2948B47DEB1 /* FocusedValues.swift in Sources */ = {isa = PBXBuildFile; fileRef = EF518FEBF3A38E830E3CE1A5 /* FocusedValues.swift */; };
|
||||
50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C67742651DB486871CEF1612 /* MLXServerApp.swift */; };
|
||||
50DD129CCF2843482DEC3B96 /* APIServer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3D08828E16B17EF02C14243E /* APIServer.swift */; };
|
||||
5946258F1DE88CE904584E0B /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 944C699FBB76C734C9DF2F2E /* ContentView.swift */; };
|
||||
5C1E8FE1C521914CEF98D3AA /* ChatMessagesView.swift in Sources */ = {isa = PBXBuildFile; fileRef = DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */; };
|
||||
5D41C2B260265A32FF42264B /* ModelManagementView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */; };
|
||||
621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */; };
|
||||
67262C5E24739F1FE0011439 /* StreamingSSEEncoder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */; };
|
||||
67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */; };
|
||||
67D0628F148FE3C2200E0AEF /* APIServerResponseResolutionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */; };
|
||||
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; };
|
||||
741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; };
|
||||
75E046B4ABB1E6FEF17C1A60 /* ModelManagementWindow.swift in Sources */ = {isa = PBXBuildFile; fileRef = 721D6F203A10434FE0223042 /* ModelManagementWindow.swift */; };
|
||||
7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */; };
|
||||
7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; };
|
||||
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; };
|
||||
@@ -100,6 +103,7 @@
|
||||
3D08828E16B17EF02C14243E /* APIServer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServer.swift; sourceTree = "<group>"; };
|
||||
4147321383E94E9F17A0154E /* SettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsView.swift; sourceTree = "<group>"; };
|
||||
4239CFF94B819C35A8D4D617 /* MonitorView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MonitorView.swift; sourceTree = "<group>"; };
|
||||
43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolverTests.swift; sourceTree = "<group>"; };
|
||||
49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoderTests.swift; sourceTree = "<group>"; };
|
||||
57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettingsTests.swift; sourceTree = "<group>"; };
|
||||
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptBuilderTests.swift; sourceTree = "<group>"; };
|
||||
@@ -108,9 +112,11 @@
|
||||
6B3AA91D2C7842D7366F9A41 /* ChatDocumentPackage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentPackage.swift; sourceTree = "<group>"; };
|
||||
6EE59189918D06B8D2F588FC /* MLXServer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLXServer.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
6FAF7455BD387CD2061E0CBF /* GenerationSettings.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettings.swift; sourceTree = "<group>"; };
|
||||
721D6F203A10434FE0223042 /* ModelManagementWindow.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManagementWindow.swift; sourceTree = "<group>"; };
|
||||
7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettingsEditor.swift; sourceTree = "<group>"; };
|
||||
7C1A89C076E717F87A60397D /* ImageDecoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageDecoder.swift; sourceTree = "<group>"; };
|
||||
7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LiveCountersTests.swift; sourceTree = "<group>"; };
|
||||
8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManagementView.swift; sourceTree = "<group>"; };
|
||||
922CBDC9206737BD04AF2874 /* ModelManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManager.swift; sourceTree = "<group>"; };
|
||||
944C699FBB76C734C9DF2F2E /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
|
||||
A4B359324B5FD8D106C74338 /* ChatMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessage.swift; sourceTree = "<group>"; };
|
||||
@@ -199,6 +205,7 @@
|
||||
57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */,
|
||||
E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */,
|
||||
7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */,
|
||||
43315501A5AFC0EA014F44F5 /* LocalModelResolverTests.swift */,
|
||||
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */,
|
||||
F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */,
|
||||
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */,
|
||||
@@ -246,6 +253,8 @@
|
||||
DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */,
|
||||
2DC8C86D397B1FCA08E07CBD /* DownloadModalView.swift */,
|
||||
7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */,
|
||||
8A1B8F9568F95E07D212A2B7 /* ModelManagementView.swift */,
|
||||
721D6F203A10434FE0223042 /* ModelManagementWindow.swift */,
|
||||
C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */,
|
||||
4239CFF94B819C35A8D4D617 /* MonitorView.swift */,
|
||||
37FEB592E5E717F817B03151 /* SceneManagementView.swift */,
|
||||
@@ -416,6 +425,7 @@
|
||||
847B445654860396AF5A8280 /* GenerationSettingsTests.swift in Sources */,
|
||||
E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */,
|
||||
67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */,
|
||||
4B7449F57226CB48C4F5EEBD /* LocalModelResolverTests.swift in Sources */,
|
||||
8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */,
|
||||
7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */,
|
||||
1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */,
|
||||
@@ -455,6 +465,8 @@
|
||||
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */,
|
||||
50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */,
|
||||
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */,
|
||||
5D41C2B260265A32FF42264B /* ModelManagementView.swift in Sources */,
|
||||
75E046B4ABB1E6FEF17C1A60 /* ModelManagementWindow.swift in Sources */,
|
||||
0168AEE16009097901363E16 /* ModelManager.swift in Sources */,
|
||||
2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */,
|
||||
B1D9BC407DB7DB1489230C20 /* MonitorView.swift in Sources */,
|
||||
|
||||
@@ -12,3 +12,15 @@ struct SceneCommands: Commands {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ModelCommands: Commands {
|
||||
@Environment(\.openWindow) private var openWindow
|
||||
|
||||
var body: some Commands {
|
||||
CommandMenu("Models") {
|
||||
Button("Manage Models…") {
|
||||
openWindow(id: ModelManagementWindow.windowID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,7 @@ struct ContentView: View {
|
||||
AnyView(mainContent)
|
||||
.navigationTitle(navigationTitleText)
|
||||
.onAppear {
|
||||
modelManager.refreshAvailableModels()
|
||||
if chatVM == nil {
|
||||
let vm = ChatViewModel(modelManager: modelManager)
|
||||
chatVM = vm
|
||||
@@ -68,7 +69,7 @@ struct ContentView: View {
|
||||
AnyView(lifecycleContent)
|
||||
.alert("Model Error", isPresented: $showLoadError) {
|
||||
Button("Retry") {
|
||||
if let config = modelManager.currentModel ?? ModelConfig.availableModels.first {
|
||||
if let config = modelManager.currentModel ?? modelManager.availableModels.first {
|
||||
Task { await modelManager.loadModel(config) }
|
||||
}
|
||||
}
|
||||
@@ -228,7 +229,7 @@ struct ContentView: View {
|
||||
|
||||
@ViewBuilder
|
||||
private var modelSwitchShortcuts: some View {
|
||||
ForEach(Array(ModelConfig.availableModels.enumerated()), id: \.element.id) { index, config in
|
||||
ForEach(Array(ModelConfig.curatedModels.enumerated()), id: \.element.id) { index, config in
|
||||
if index < 9 {
|
||||
Button("") {
|
||||
Task { await modelManager.loadModel(config) }
|
||||
@@ -420,7 +421,7 @@ struct ContentView: View {
|
||||
guard modelManager.currentModel == nil else { return }
|
||||
|
||||
let modelId = Preferences.defaultModelId ?? Preferences.lastModelId ?? ModelConfig.default.id
|
||||
if let config = ModelConfig.availableModels.first(where: { $0.id == modelId }) {
|
||||
if let config = ModelConfig.resolve(modelId) {
|
||||
await modelManager.loadModel(config)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,17 +52,26 @@ struct MLXServerApp: App {
|
||||
.commands {
|
||||
SaveChatCommands()
|
||||
SceneCommands()
|
||||
ModelCommands()
|
||||
}
|
||||
|
||||
Window("Scenes", id: SceneManagementWindow.windowID) {
|
||||
SceneManagementView()
|
||||
.environment(modelManager)
|
||||
.environment(sceneStore)
|
||||
}
|
||||
.defaultSize(width: 900, height: 560)
|
||||
|
||||
Window("Models", id: ModelManagementWindow.windowID) {
|
||||
ModelManagementView()
|
||||
.environment(modelManager)
|
||||
}
|
||||
.defaultSize(width: 900, height: 620)
|
||||
|
||||
#if os(macOS)
|
||||
Settings {
|
||||
SettingsView()
|
||||
.environment(modelManager)
|
||||
.environment(sceneStore)
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -53,7 +53,7 @@ struct ChatScene: Codable, Identifiable, Hashable {
|
||||
|
||||
var resolvedModel: ModelConfig? {
|
||||
guard let modelId else { return nil }
|
||||
return ModelConfig.availableModels.first(where: { $0.id == modelId })
|
||||
return ModelConfig.resolve(modelId)
|
||||
}
|
||||
|
||||
static let empty = ChatScene(name: "New Scene")
|
||||
|
||||
@@ -1,30 +1,81 @@
|
||||
import Foundation
|
||||
import MLXLMCommon
|
||||
|
||||
struct ModelMetadataOverride: Codable, Hashable, Sendable {
|
||||
var contextLength: Int
|
||||
var primaryLoaderKind: ModelConfig.LoaderKind
|
||||
var supportsImages: Bool
|
||||
var supportsTools: Bool
|
||||
|
||||
func normalized() -> ModelMetadataOverride {
|
||||
ModelMetadataOverride(
|
||||
contextLength: max(0, contextLength),
|
||||
primaryLoaderKind: primaryLoaderKind,
|
||||
supportsImages: supportsImages,
|
||||
supportsTools: supportsTools
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Defines a supported model with its metadata.
|
||||
struct ModelConfig: Identifiable, Hashable {
|
||||
enum LoaderKind: Hashable {
|
||||
enum LoaderKind: String, CaseIterable, Codable, Hashable, Sendable {
|
||||
case llm
|
||||
case vlm
|
||||
|
||||
var displayName: String {
|
||||
switch self {
|
||||
case .llm:
|
||||
return "Text"
|
||||
case .vlm:
|
||||
return "Vision"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let id: String // alias: "gemma", "gemma3n", "qwen"
|
||||
let repoId: String // HuggingFace ID
|
||||
let displayName: String
|
||||
let contextLength: Int
|
||||
let loaderKind: LoaderKind
|
||||
let loaderKinds: [LoaderKind]
|
||||
let supportsImages: Bool
|
||||
let supportsTools: Bool
|
||||
let defaultGenerationSettings: GenerationSettings
|
||||
let isCurated: Bool
|
||||
let localSizeBytes: Int64?
|
||||
|
||||
/// All models supported by the app.
|
||||
static let availableModels: [ModelConfig] = [
|
||||
init(
|
||||
id: String,
|
||||
repoId: String,
|
||||
displayName: String,
|
||||
contextLength: Int,
|
||||
loaderKinds: [LoaderKind],
|
||||
supportsImages: Bool,
|
||||
supportsTools: Bool,
|
||||
defaultGenerationSettings: GenerationSettings,
|
||||
isCurated: Bool = true,
|
||||
localSizeBytes: Int64? = nil
|
||||
) {
|
||||
self.id = id
|
||||
self.repoId = repoId
|
||||
self.displayName = displayName
|
||||
self.contextLength = contextLength
|
||||
self.loaderKinds = loaderKinds
|
||||
self.supportsImages = supportsImages
|
||||
self.supportsTools = supportsTools
|
||||
self.defaultGenerationSettings = defaultGenerationSettings
|
||||
self.isCurated = isCurated
|
||||
self.localSizeBytes = localSizeBytes
|
||||
}
|
||||
|
||||
/// Curated models supported and tuned by the app.
|
||||
static let curatedModels: [ModelConfig] = [
|
||||
ModelConfig(
|
||||
id: "gemma",
|
||||
repoId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
displayName: "Gemma 3 4B",
|
||||
contextLength: 128_000,
|
||||
loaderKind: .vlm,
|
||||
loaderKinds: [.vlm],
|
||||
supportsImages: true,
|
||||
supportsTools: true,
|
||||
defaultGenerationSettings: .technicalDefault
|
||||
@@ -34,7 +85,7 @@ struct ModelConfig: Identifiable, Hashable {
|
||||
repoId: "mlx-community/Qwen3.5-4B-MLX-4bit",
|
||||
displayName: "Qwen3.5 4B",
|
||||
contextLength: 256_000,
|
||||
loaderKind: .vlm,
|
||||
loaderKinds: [.vlm],
|
||||
supportsImages: true,
|
||||
supportsTools: true,
|
||||
defaultGenerationSettings: .technicalDefault
|
||||
@@ -44,7 +95,7 @@ struct ModelConfig: Identifiable, Hashable {
|
||||
repoId: "mlx-community/Qwen3.5-0.8B-4bit",
|
||||
displayName: "Qwen3.5 0.8B",
|
||||
contextLength: 256_000,
|
||||
loaderKind: .vlm,
|
||||
loaderKinds: [.vlm],
|
||||
supportsImages: true,
|
||||
supportsTools: true,
|
||||
defaultGenerationSettings: .technicalDefault
|
||||
@@ -54,7 +105,7 @@ struct ModelConfig: Identifiable, Hashable {
|
||||
repoId: "mlx-community/Qwen3.5-9B-4bit",
|
||||
displayName: "Qwen3.5 9B",
|
||||
contextLength: 256_000,
|
||||
loaderKind: .vlm,
|
||||
loaderKinds: [.vlm],
|
||||
supportsImages: true,
|
||||
supportsTools: true,
|
||||
defaultGenerationSettings: .technicalDefault
|
||||
@@ -64,7 +115,7 @@ struct ModelConfig: Identifiable, Hashable {
|
||||
repoId: "synk/L3-8B-Stheno-v3.2-MLX",
|
||||
displayName: "Stheno L3 8B",
|
||||
contextLength: 8_192,
|
||||
loaderKind: .llm,
|
||||
loaderKinds: [.llm],
|
||||
supportsImages: false,
|
||||
supportsTools: false,
|
||||
defaultGenerationSettings: .roleplayDefault
|
||||
@@ -74,18 +125,35 @@ struct ModelConfig: Identifiable, Hashable {
|
||||
repoId: "hobaratio/MN-Violet-Lotus-12B-mlx-4Bit",
|
||||
displayName: "Violet Lotus 12B",
|
||||
contextLength: 32_768,
|
||||
loaderKind: .llm,
|
||||
loaderKinds: [.llm],
|
||||
supportsImages: false,
|
||||
supportsTools: false,
|
||||
defaultGenerationSettings: .roleplayDefault
|
||||
),
|
||||
]
|
||||
|
||||
static let `default` = availableModels[0]
|
||||
static var availableModels: [ModelConfig] {
|
||||
mergedModels(localModels: LocalModelResolver.discoveredLocalModels())
|
||||
}
|
||||
|
||||
static let `default` = curatedModels[0]
|
||||
|
||||
/// Whether this model is cached locally (no download needed).
|
||||
var isLocal: Bool {
|
||||
LocalModelResolver.isAvailable(repoId: repoId)
|
||||
localSizeBytes != nil || LocalModelResolver.isAvailable(repoId: repoId)
|
||||
}
|
||||
|
||||
var primaryLoaderKind: LoaderKind {
|
||||
loaderKinds.first ?? .llm
|
||||
}
|
||||
|
||||
var metadataOverrideValue: ModelMetadataOverride {
|
||||
ModelMetadataOverride(
|
||||
contextLength: contextLength,
|
||||
primaryLoaderKind: primaryLoaderKind,
|
||||
supportsImages: supportsImages,
|
||||
supportsTools: supportsTools
|
||||
)
|
||||
}
|
||||
|
||||
/// Build a ModelConfiguration for mlx-swift-lm from this config.
|
||||
@@ -96,6 +164,9 @@ struct ModelConfig: Identifiable, Hashable {
|
||||
/// Resolve a model string (alias, full repo ID, or partial match) to a ModelConfig.
|
||||
/// Mirrors the Python server's `ModelManager.resolve_model()`.
|
||||
static func resolve(_ requested: String) -> ModelConfig? {
|
||||
let requested = requested.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !requested.isEmpty else { return nil }
|
||||
|
||||
// Exact alias match
|
||||
if let config = availableModels.first(where: { $0.id == requested }) {
|
||||
return config
|
||||
@@ -108,6 +179,129 @@ struct ModelConfig: Identifiable, Hashable {
|
||||
if let config = availableModels.first(where: { requested.contains($0.id) || $0.repoId.contains(requested) || requested.contains($0.repoId) }) {
|
||||
return config
|
||||
}
|
||||
if requested.contains("/") {
|
||||
return remoteCustom(repoId: requested)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
static func mergedModels(
|
||||
localModels: [LocalModelResolver.LocalModelInfo],
|
||||
applyingOverrides: Bool = true
|
||||
) -> [ModelConfig] {
|
||||
let localByRepo = Dictionary(uniqueKeysWithValues: localModels.map { ($0.repoId, $0) })
|
||||
let curatedRepoIds = Set(curatedModels.map(\.repoId))
|
||||
|
||||
let curated = curatedModels.map { config in
|
||||
if let local = localByRepo[config.repoId] {
|
||||
return applyingOverrides ? applyMetadataOverrideIfNeeded(to: config.withLocalSize(local.sizeBytes)) : config.withLocalSize(local.sizeBytes)
|
||||
}
|
||||
return applyingOverrides ? applyMetadataOverrideIfNeeded(to: config) : config
|
||||
}
|
||||
|
||||
let discoveredCustom = localModels
|
||||
.filter { !curatedRepoIds.contains($0.repoId) }
|
||||
.map(customLocal)
|
||||
.sorted { lhs, rhs in
|
||||
lhs.displayName.localizedCaseInsensitiveCompare(rhs.displayName) == .orderedAscending
|
||||
}
|
||||
|
||||
return curated + discoveredCustom
|
||||
}
|
||||
|
||||
static func baselineModel(
|
||||
forRepoId repoId: String,
|
||||
localModels: [LocalModelResolver.LocalModelInfo]
|
||||
) -> ModelConfig? {
|
||||
mergedModels(localModels: localModels, applyingOverrides: false)
|
||||
.first(where: { $0.repoId == repoId || $0.id == repoId })
|
||||
?? (repoId.contains("/") ? remoteCustom(repoId: repoId) : nil)
|
||||
}
|
||||
|
||||
static func remoteCustom(repoId: String) -> ModelConfig {
|
||||
let supportsImages = inferredVisionSupport(repoId: repoId)
|
||||
return applyMetadataOverrideIfNeeded(to: ModelConfig(
|
||||
id: repoId,
|
||||
repoId: repoId,
|
||||
displayName: displayName(for: repoId),
|
||||
contextLength: 0,
|
||||
loaderKinds: supportsImages ? [.vlm, .llm] : [.llm, .vlm],
|
||||
supportsImages: supportsImages,
|
||||
supportsTools: inferredToolSupport(repoId: repoId),
|
||||
defaultGenerationSettings: .generalDefault,
|
||||
isCurated: false
|
||||
))
|
||||
}
|
||||
|
||||
static func displayName(for repoId: String) -> String {
|
||||
let raw = repoId.split(separator: "/").last.map(String.init) ?? repoId
|
||||
return raw
|
||||
.replacingOccurrences(of: "-", with: " ")
|
||||
.replacingOccurrences(of: "_", with: " ")
|
||||
}
|
||||
|
||||
private static func customLocal(_ local: LocalModelResolver.LocalModelInfo) -> ModelConfig {
|
||||
applyMetadataOverrideIfNeeded(to: ModelConfig(
|
||||
id: local.repoId,
|
||||
repoId: local.repoId,
|
||||
displayName: displayName(for: local.repoId),
|
||||
contextLength: local.contextLength,
|
||||
loaderKinds: local.loaderKinds,
|
||||
supportsImages: local.supportsImages,
|
||||
supportsTools: inferredToolSupport(repoId: local.repoId),
|
||||
defaultGenerationSettings: .generalDefault,
|
||||
isCurated: false,
|
||||
localSizeBytes: local.sizeBytes
|
||||
))
|
||||
}
|
||||
|
||||
private static func inferredToolSupport(repoId: String) -> Bool {
|
||||
let normalized = repoId.lowercased()
|
||||
return normalized.contains("qwen") || normalized.contains("gemma")
|
||||
}
|
||||
|
||||
private static func inferredVisionSupport(repoId: String) -> Bool {
|
||||
let normalized = repoId.lowercased()
|
||||
return normalized.contains("vision") || normalized.contains("vl") || normalized.contains("gemma-3") || normalized.contains("qwen")
|
||||
}
|
||||
|
||||
private func withLocalSize(_ sizeBytes: Int64) -> ModelConfig {
|
||||
ModelConfig(
|
||||
id: id,
|
||||
repoId: repoId,
|
||||
displayName: displayName,
|
||||
contextLength: contextLength,
|
||||
loaderKinds: loaderKinds,
|
||||
supportsImages: supportsImages,
|
||||
supportsTools: supportsTools,
|
||||
defaultGenerationSettings: defaultGenerationSettings,
|
||||
isCurated: isCurated,
|
||||
localSizeBytes: sizeBytes
|
||||
)
|
||||
}
|
||||
|
||||
private func applyingMetadataOverride(_ override: ModelMetadataOverride) -> ModelConfig {
|
||||
let normalized = override.normalized()
|
||||
let reorderedLoaderKinds = [normalized.primaryLoaderKind] + LoaderKind.allCases.filter { $0 != normalized.primaryLoaderKind }
|
||||
|
||||
return ModelConfig(
|
||||
id: id,
|
||||
repoId: repoId,
|
||||
displayName: displayName,
|
||||
contextLength: normalized.contextLength,
|
||||
loaderKinds: reorderedLoaderKinds,
|
||||
supportsImages: normalized.supportsImages,
|
||||
supportsTools: normalized.supportsTools,
|
||||
defaultGenerationSettings: defaultGenerationSettings,
|
||||
isCurated: isCurated,
|
||||
localSizeBytes: localSizeBytes
|
||||
)
|
||||
}
|
||||
|
||||
private static func applyMetadataOverrideIfNeeded(to config: ModelConfig) -> ModelConfig {
|
||||
guard let override = Preferences.modelMetadataOverride(forRepoId: config.repoId) else {
|
||||
return config
|
||||
}
|
||||
return config.applyingMetadataOverride(override)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,17 @@ import Foundation
|
||||
/// ~/Library/Containers/de.rfc1437.mlxserver/Data/Library/Caches/models/{org}/{name}/
|
||||
enum LocalModelResolver {
|
||||
|
||||
struct LocalModelInfo: Identifiable, Hashable {
|
||||
let repoId: String
|
||||
let directory: URL
|
||||
let sizeBytes: Int64
|
||||
let contextLength: Int
|
||||
let loaderKinds: [ModelConfig.LoaderKind]
|
||||
let supportsImages: Bool
|
||||
|
||||
var id: String { repoId }
|
||||
}
|
||||
|
||||
/// Base directory where HubApi stores downloaded models.
|
||||
private static let modelsBase: URL? = {
|
||||
FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first?
|
||||
@@ -31,6 +42,46 @@ enum LocalModelResolver {
|
||||
resolve(repoId: repoId) != nil
|
||||
}
|
||||
|
||||
static func discoveredLocalModels() -> [LocalModelInfo] {
|
||||
guard let base = modelsBase else { return [] }
|
||||
return discoverModels(in: base)
|
||||
}
|
||||
|
||||
static func discoverModels(in base: URL) -> [LocalModelInfo] {
|
||||
let fileManager = FileManager.default
|
||||
let directoryKeys: Set<URLResourceKey> = [.isDirectoryKey]
|
||||
guard let ownerDirectories = try? fileManager.contentsOfDirectory(
|
||||
at: base,
|
||||
includingPropertiesForKeys: Array(directoryKeys),
|
||||
options: [.skipsHiddenFiles]
|
||||
) else {
|
||||
return []
|
||||
}
|
||||
|
||||
var discovered: [LocalModelInfo] = []
|
||||
|
||||
for ownerDirectory in ownerDirectories {
|
||||
guard isDirectory(ownerDirectory) else { continue }
|
||||
guard let repoDirectories = try? fileManager.contentsOfDirectory(
|
||||
at: ownerDirectory,
|
||||
includingPropertiesForKeys: Array(directoryKeys),
|
||||
options: [.skipsHiddenFiles]
|
||||
) else {
|
||||
continue
|
||||
}
|
||||
|
||||
for repoDirectory in repoDirectories where isDirectory(repoDirectory) {
|
||||
if let info = localModelInfo(ownerDirectory: ownerDirectory, repoDirectory: repoDirectory) {
|
||||
discovered.append(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return discovered.sorted {
|
||||
$0.repoId.localizedCaseInsensitiveCompare($1.repoId) == .orderedAscending
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete the local cache for a model so it will be re-downloaded next time.
|
||||
@discardableResult
|
||||
static func deleteLocal(repoId: String) -> Bool {
|
||||
@@ -46,4 +97,122 @@ enum LocalModelResolver {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private static func localModelInfo(ownerDirectory: URL, repoDirectory: URL) -> LocalModelInfo? {
|
||||
let repoId = "\(ownerDirectory.lastPathComponent)/\(repoDirectory.lastPathComponent)"
|
||||
guard containsModelArtifacts(at: repoDirectory) else { return nil }
|
||||
|
||||
let config = readJSONObject(at: repoDirectory.appendingPathComponent("config.json"))
|
||||
let tokenizerConfig = readJSONObject(at: repoDirectory.appendingPathComponent("tokenizer_config.json"))
|
||||
let supportsImages = inferredSupportsImages(
|
||||
repoDirectory: repoDirectory,
|
||||
config: config,
|
||||
tokenizerConfig: tokenizerConfig
|
||||
)
|
||||
let sizeBytes = directorySize(at: repoDirectory)
|
||||
let contextLength = inferredContextLength(config: config, tokenizerConfig: tokenizerConfig)
|
||||
let loaderKinds: [ModelConfig.LoaderKind] = supportsImages ? [.vlm, .llm] : [.llm, .vlm]
|
||||
|
||||
return LocalModelInfo(
|
||||
repoId: repoId,
|
||||
directory: repoDirectory,
|
||||
sizeBytes: sizeBytes,
|
||||
contextLength: contextLength,
|
||||
loaderKinds: loaderKinds,
|
||||
supportsImages: supportsImages
|
||||
)
|
||||
}
|
||||
|
||||
private static func containsModelArtifacts(at directory: URL) -> Bool {
|
||||
let requiredPaths = [
|
||||
directory.appendingPathComponent("config.json").path,
|
||||
directory.appendingPathComponent("model.safetensors").path,
|
||||
directory.appendingPathComponent("model.safetensors.index.json").path,
|
||||
]
|
||||
return requiredPaths.contains { FileManager.default.fileExists(atPath: $0) }
|
||||
}
|
||||
|
||||
private static func isDirectory(_ url: URL) -> Bool {
|
||||
(try? url.resourceValues(forKeys: [.isDirectoryKey]).isDirectory) == true
|
||||
}
|
||||
|
||||
private static func readJSONObject(at url: URL) -> [String: Any]? {
|
||||
guard let data = try? Data(contentsOf: url) else { return nil }
|
||||
return (try? JSONSerialization.jsonObject(with: data)) as? [String: Any]
|
||||
}
|
||||
|
||||
private static func inferredSupportsImages(
|
||||
repoDirectory: URL,
|
||||
config: [String: Any]?,
|
||||
tokenizerConfig: [String: Any]?
|
||||
) -> Bool {
|
||||
if config?["vision_config"] != nil {
|
||||
return true
|
||||
}
|
||||
if tokenizerConfig?["image_token"] != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
let metadataFiles = [
|
||||
"processor_config.json",
|
||||
"preprocessor_config.json",
|
||||
"video_preprocessor_config.json",
|
||||
]
|
||||
return metadataFiles.contains {
|
||||
FileManager.default.fileExists(atPath: repoDirectory.appendingPathComponent($0).path)
|
||||
}
|
||||
}
|
||||
|
||||
private static func inferredContextLength(
|
||||
config: [String: Any]?,
|
||||
tokenizerConfig: [String: Any]?
|
||||
) -> Int {
|
||||
if let value = integerValue(at: ["text_config", "max_position_embeddings"], in: config) {
|
||||
return value
|
||||
}
|
||||
if let value = integerValue(at: ["max_position_embeddings"], in: config) {
|
||||
return value
|
||||
}
|
||||
if let value = integerValue(at: ["model_max_length"], in: tokenizerConfig) {
|
||||
return value
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
private static func integerValue(at path: [String], in json: [String: Any]?) -> Int? {
|
||||
guard let json else { return nil }
|
||||
|
||||
var current: Any = json
|
||||
for component in path {
|
||||
guard let dictionary = current as? [String: Any], let next = dictionary[component] else {
|
||||
return nil
|
||||
}
|
||||
current = next
|
||||
}
|
||||
|
||||
if let number = current as? NSNumber {
|
||||
return number.intValue
|
||||
}
|
||||
return current as? Int
|
||||
}
|
||||
|
||||
private static func directorySize(at directory: URL) -> Int64 {
|
||||
let keys: [URLResourceKey] = [.isRegularFileKey, .fileSizeKey]
|
||||
guard let enumerator = FileManager.default.enumerator(
|
||||
at: directory,
|
||||
includingPropertiesForKeys: keys,
|
||||
options: [.skipsHiddenFiles]
|
||||
) else {
|
||||
return 0
|
||||
}
|
||||
|
||||
var total: Int64 = 0
|
||||
for case let fileURL as URL in enumerator {
|
||||
guard let values = try? fileURL.resourceValues(forKeys: Set(keys)), values.isRegularFile == true else {
|
||||
continue
|
||||
}
|
||||
total += Int64(values.fileSize ?? 0)
|
||||
}
|
||||
return total
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ enum Preferences {
|
||||
private static let jsonEncoder = JSONEncoder()
|
||||
private static let jsonDecoder = JSONDecoder()
|
||||
private static let legacyThinkingDefault = true
|
||||
private static let modelMetadataOverridesKey = "modelMetadataOverrides"
|
||||
|
||||
// MARK: - Last used model
|
||||
|
||||
@@ -118,6 +119,26 @@ enum Preferences {
|
||||
modelGenerationSettingsMap[modelId] != nil
|
||||
}
|
||||
|
||||
static func modelMetadataOverride(forRepoId repoId: String) -> ModelMetadataOverride? {
|
||||
modelMetadataOverridesMap[repoId]?.normalized()
|
||||
}
|
||||
|
||||
static func setModelMetadataOverride(_ override: ModelMetadataOverride, forRepoId repoId: String) {
|
||||
var map = modelMetadataOverridesMap
|
||||
map[repoId] = override.normalized()
|
||||
modelMetadataOverridesMap = map
|
||||
}
|
||||
|
||||
static func removeModelMetadataOverride(forRepoId repoId: String) {
|
||||
var map = modelMetadataOverridesMap
|
||||
map.removeValue(forKey: repoId)
|
||||
modelMetadataOverridesMap = map
|
||||
}
|
||||
|
||||
static func hasModelMetadataOverride(forRepoId repoId: String) -> Bool {
|
||||
modelMetadataOverridesMap[repoId] != nil
|
||||
}
|
||||
|
||||
private static var modelGenerationSettingsMap: [String: GenerationSettings] {
|
||||
get {
|
||||
guard let data = defaults.data(forKey: modelGenerationSettingsKey) else { return [:] }
|
||||
@@ -129,6 +150,17 @@ enum Preferences {
|
||||
}
|
||||
}
|
||||
|
||||
private static var modelMetadataOverridesMap: [String: ModelMetadataOverride] {
|
||||
get {
|
||||
guard let data = defaults.data(forKey: modelMetadataOverridesKey) else { return [:] }
|
||||
return (try? jsonDecoder.decode([String: ModelMetadataOverride].self, from: data)) ?? [:]
|
||||
}
|
||||
set {
|
||||
guard let data = try? jsonEncoder.encode(newValue) else { return }
|
||||
defaults.set(data, forKey: modelMetadataOverridesKey)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Idle unload
|
||||
|
||||
private static let idleUnloadMinutesKey = "idleUnloadMinutes"
|
||||
|
||||
@@ -559,7 +559,7 @@ final class ChatViewModel {
|
||||
|
||||
if modelManager.currentModel == nil {
|
||||
let modelId = Preferences.defaultModelId ?? Preferences.lastModelId ?? ModelConfig.default.id
|
||||
if let config = ModelConfig.availableModels.first(where: { $0.id == modelId }) {
|
||||
if let config = ModelConfig.resolve(modelId) {
|
||||
await modelManager.loadModel(config)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,10 @@ final class ModelManager {
|
||||
let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first
|
||||
return HubApi(downloadBase: cachesDir, cache: nil)
|
||||
}()
|
||||
|
||||
var currentModel: ModelConfig?
|
||||
var availableModels: [ModelConfig]
|
||||
private(set) var discoveredLocalModels: [LocalModelResolver.LocalModelInfo] = []
|
||||
var modelContainer: ModelContainer?
|
||||
var isLoading = false
|
||||
var downloadProgress: Double = 0
|
||||
@@ -36,6 +39,50 @@ final class ModelManager {
|
||||
private(set) var lastUsed: Date?
|
||||
private var latestLoadRequestID = UUID()
|
||||
|
||||
init() {
|
||||
availableModels = []
|
||||
refreshAvailableModels()
|
||||
}
|
||||
|
||||
var curatedModels: [ModelConfig] {
|
||||
availableModels.filter(\.isCurated)
|
||||
}
|
||||
|
||||
var localModelsOnDisk: [ModelConfig] {
|
||||
availableModels
|
||||
.filter(\.isLocal)
|
||||
.sorted {
|
||||
$0.displayName.localizedCaseInsensitiveCompare($1.displayName) == .orderedAscending
|
||||
}
|
||||
}
|
||||
|
||||
func refreshAvailableModels() {
|
||||
discoveredLocalModels = LocalModelResolver.discoveredLocalModels()
|
||||
availableModels = ModelConfig.mergedModels(localModels: discoveredLocalModels)
|
||||
|
||||
if let currentModel {
|
||||
self.currentModel = availableModels.first(where: { $0.repoId == currentModel.repoId }) ?? currentModel
|
||||
}
|
||||
}
|
||||
|
||||
func discoveredLocalModelInfo(repoId: String) -> LocalModelResolver.LocalModelInfo? {
|
||||
discoveredLocalModels.first(where: { $0.repoId == repoId })
|
||||
}
|
||||
|
||||
func baselineModel(repoId: String) -> ModelConfig? {
|
||||
ModelConfig.baselineModel(forRepoId: repoId, localModels: discoveredLocalModels)
|
||||
}
|
||||
|
||||
func saveMetadataOverride(_ override: ModelMetadataOverride, for config: ModelConfig) {
|
||||
Preferences.setModelMetadataOverride(override, forRepoId: config.repoId)
|
||||
refreshAvailableModels()
|
||||
}
|
||||
|
||||
func clearMetadataOverride(for config: ModelConfig) {
|
||||
Preferences.removeModelMetadataOverride(forRepoId: config.repoId)
|
||||
refreshAvailableModels()
|
||||
}
|
||||
|
||||
private func clearLoadedState() {
|
||||
idleTimer?.invalidate()
|
||||
idleTimer = nil
|
||||
@@ -55,7 +102,11 @@ final class ModelManager {
|
||||
/// Prefers the local snapshot from ~/.cache/huggingface/hub/ (shared with the Python server).
|
||||
/// Only downloads if the model isn't cached locally.
|
||||
func loadModel(_ config: ModelConfig) async {
|
||||
if currentModel?.id == config.id && modelContainer != nil {
|
||||
refreshAvailableModels()
|
||||
let effectiveConfig = availableModels.first(where: { $0.repoId == config.repoId }) ?? config
|
||||
|
||||
if currentModel?.repoId == effectiveConfig.repoId && modelContainer != nil {
|
||||
currentModel = effectiveConfig
|
||||
return // already loaded
|
||||
}
|
||||
|
||||
@@ -65,10 +116,10 @@ final class ModelManager {
|
||||
MLX.GPU.clearCache()
|
||||
isLoading = true
|
||||
downloadProgress = 0
|
||||
loadingModelName = config.displayName
|
||||
loadingModelName = effectiveConfig.displayName
|
||||
errorMessage = nil
|
||||
|
||||
let needsDownload = !config.isLocal
|
||||
let needsDownload = !effectiveConfig.isLocal
|
||||
if needsDownload {
|
||||
isDownloading = true
|
||||
downloadFilesTotal = 0
|
||||
@@ -91,32 +142,23 @@ final class ModelManager {
|
||||
}
|
||||
|
||||
let configuration: ModelConfiguration
|
||||
if let localDir = LocalModelResolver.resolve(repoId: config.repoId) {
|
||||
if let localDir = LocalModelResolver.resolve(repoId: effectiveConfig.repoId) {
|
||||
configuration = ModelConfiguration(directory: localDir)
|
||||
} else {
|
||||
configuration = config.modelConfiguration
|
||||
configuration = effectiveConfig.modelConfiguration
|
||||
}
|
||||
|
||||
let container: ModelContainer
|
||||
switch config.loaderKind {
|
||||
case .llm:
|
||||
container = try await LLMModelFactory.shared.loadContainer(
|
||||
hub: Self.hub,
|
||||
configuration: configuration,
|
||||
progressHandler: progressHandler
|
||||
)
|
||||
case .vlm:
|
||||
container = try await VLMModelFactory.shared.loadContainer(
|
||||
hub: Self.hub,
|
||||
configuration: configuration,
|
||||
progressHandler: progressHandler
|
||||
)
|
||||
}
|
||||
let container = try await Self.loadContainer(
|
||||
for: effectiveConfig,
|
||||
configuration: configuration,
|
||||
progressHandler: progressHandler
|
||||
)
|
||||
|
||||
guard latestLoadRequestID == requestID else { return }
|
||||
refreshAvailableModels()
|
||||
self.isDownloading = false
|
||||
self.modelContainer = container
|
||||
self.currentModel = config
|
||||
self.currentModel = self.availableModels.first(where: { $0.repoId == effectiveConfig.repoId }) ?? effectiveConfig
|
||||
touchActivity()
|
||||
} catch {
|
||||
guard latestLoadRequestID == requestID else { return }
|
||||
@@ -135,6 +177,25 @@ final class ModelManager {
|
||||
await loadModel(config)
|
||||
}
|
||||
|
||||
func addModel(repoId: String) async {
|
||||
let repoId = repoId.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !repoId.isEmpty else {
|
||||
errorMessage = "Enter a HuggingFace model ID."
|
||||
return
|
||||
}
|
||||
|
||||
let config = ModelConfig.resolve(repoId) ?? ModelConfig.remoteCustom(repoId: repoId)
|
||||
await loadModel(config)
|
||||
}
|
||||
|
||||
func deleteModel(_ config: ModelConfig) {
|
||||
if currentModel?.repoId == config.repoId {
|
||||
unloadModel()
|
||||
}
|
||||
_ = LocalModelResolver.deleteLocal(repoId: config.repoId)
|
||||
refreshAvailableModels()
|
||||
}
|
||||
|
||||
/// Unload the current model and free GPU memory.
|
||||
func unloadModel() {
|
||||
latestLoadRequestID = UUID()
|
||||
@@ -161,4 +222,35 @@ final class ModelManager {
|
||||
var isReady: Bool {
|
||||
modelContainer != nil && !isLoading
|
||||
}
|
||||
|
||||
private static func loadContainer(
|
||||
for config: ModelConfig,
|
||||
configuration: ModelConfiguration,
|
||||
progressHandler: @escaping @Sendable (Progress) -> Void
|
||||
) async throws -> ModelContainer {
|
||||
var lastError: Error?
|
||||
|
||||
for loaderKind in config.loaderKinds {
|
||||
do {
|
||||
switch loaderKind {
|
||||
case .llm:
|
||||
return try await LLMModelFactory.shared.loadContainer(
|
||||
hub: Self.hub,
|
||||
configuration: configuration,
|
||||
progressHandler: progressHandler
|
||||
)
|
||||
case .vlm:
|
||||
return try await VLMModelFactory.shared.loadContainer(
|
||||
hub: Self.hub,
|
||||
configuration: configuration,
|
||||
progressHandler: progressHandler
|
||||
)
|
||||
}
|
||||
} catch {
|
||||
lastError = error
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError ?? NSError(domain: "ModelManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "Unsupported model configuration"])
|
||||
}
|
||||
}
|
||||
|
||||
431
MLXServer/Views/ModelManagementView.swift
Normal file
431
MLXServer/Views/ModelManagementView.swift
Normal file
@@ -0,0 +1,431 @@
|
||||
import SwiftUI
|
||||
|
||||
struct ModelManagementView: View {
|
||||
@Environment(ModelManager.self) private var modelManager
|
||||
|
||||
@State private var newRepoId = ""
|
||||
@State private var pendingDelete: ModelConfig?
|
||||
@State private var editingMetadataModel: ModelConfig?
|
||||
@FocusState private var isRepoIdFieldFocused: Bool
|
||||
|
||||
private let sizeFormatter: ByteCountFormatter = {
|
||||
let formatter = ByteCountFormatter()
|
||||
formatter.allowedUnits = [.useGB, .useMB, .useKB]
|
||||
formatter.countStyle = .file
|
||||
formatter.includesUnit = true
|
||||
formatter.isAdaptive = true
|
||||
return formatter
|
||||
}()
|
||||
|
||||
var body: some View {
|
||||
ScrollView {
|
||||
VStack(alignment: .leading, spacing: 18) {
|
||||
GroupBox("Add Model") {
|
||||
VStack(alignment: .leading, spacing: 10) {
|
||||
Text("Enter a HuggingFace model ID. The app will download it, load it once, and then keep it available in the regular model picker.")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
|
||||
HStack {
|
||||
TextField("owner/repo", text: $newRepoId)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.focused($isRepoIdFieldFocused)
|
||||
.onSubmit {
|
||||
downloadEnteredModel()
|
||||
}
|
||||
|
||||
Button("Download & Select") {
|
||||
downloadEnteredModel()
|
||||
}
|
||||
.disabled(modelManager.isLoading || newRepoId.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GroupBox("Recommended Defaults") {
|
||||
VStack(spacing: 0) {
|
||||
ForEach(modelManager.curatedModels) { model in
|
||||
curatedRow(model)
|
||||
if model.id != modelManager.curatedModels.last?.id {
|
||||
Divider()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GroupBox("Models On Disk") {
|
||||
if modelManager.localModelsOnDisk.isEmpty {
|
||||
ContentUnavailableView(
|
||||
"No Local Models",
|
||||
systemImage: "externaldrive",
|
||||
description: Text("Downloaded models will appear here with their summed file sizes.")
|
||||
)
|
||||
.frame(maxWidth: .infinity)
|
||||
.padding(.vertical, 20)
|
||||
} else {
|
||||
VStack(spacing: 0) {
|
||||
ForEach(modelManager.localModelsOnDisk) { model in
|
||||
localRow(model)
|
||||
if model.id != modelManager.localModelsOnDisk.last?.id {
|
||||
Divider()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding(20)
|
||||
}
|
||||
.navigationTitle("Models")
|
||||
.frame(minWidth: 760, minHeight: 520)
|
||||
.sheet(item: $editingMetadataModel) { model in
|
||||
ModelMetadataEditorView(
|
||||
model: model,
|
||||
baselineModel: modelManager.baselineModel(repoId: model.repoId) ?? model,
|
||||
detectedLocalModel: modelManager.discoveredLocalModelInfo(repoId: model.repoId),
|
||||
hasSavedOverride: Preferences.hasModelMetadataOverride(forRepoId: model.repoId),
|
||||
onSave: { override in
|
||||
modelManager.saveMetadataOverride(override, for: model)
|
||||
},
|
||||
onReset: {
|
||||
modelManager.clearMetadataOverride(for: model)
|
||||
}
|
||||
)
|
||||
}
|
||||
.alert(
|
||||
"Delete Local Model?",
|
||||
isPresented: Binding(
|
||||
get: { pendingDelete != nil },
|
||||
set: { if !$0 { pendingDelete = nil } }
|
||||
)
|
||||
) {
|
||||
Button("Delete", role: .destructive) {
|
||||
if let pendingDelete {
|
||||
modelManager.deleteModel(pendingDelete)
|
||||
}
|
||||
self.pendingDelete = nil
|
||||
}
|
||||
Button("Cancel", role: .cancel) {
|
||||
pendingDelete = nil
|
||||
}
|
||||
} message: {
|
||||
if let pendingDelete {
|
||||
Text("This removes the local files for \(pendingDelete.repoId).")
|
||||
}
|
||||
}
|
||||
.onAppear {
|
||||
modelManager.refreshAvailableModels()
|
||||
if newRepoId.isEmpty {
|
||||
isRepoIdFieldFocused = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private func curatedRow(_ model: ModelConfig) -> some View {
|
||||
HStack(alignment: .top, spacing: 14) {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
HStack(spacing: 8) {
|
||||
Text(model.displayName)
|
||||
.font(.headline)
|
||||
if modelManager.currentModel?.repoId == model.repoId {
|
||||
Text("Loaded")
|
||||
.font(.caption.weight(.semibold))
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 3)
|
||||
.background(.green.opacity(0.15), in: Capsule())
|
||||
}
|
||||
}
|
||||
|
||||
Text(model.repoId)
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
Label(
|
||||
model.isLocal ? "On Disk" : "Not Downloaded",
|
||||
systemImage: model.isLocal ? "checkmark.circle.fill" : "arrow.down.circle"
|
||||
)
|
||||
.font(.caption)
|
||||
.foregroundStyle(model.isLocal ? .green : .secondary)
|
||||
|
||||
Button(model.isLocal ? "Load" : "Download") {
|
||||
Task {
|
||||
await modelManager.loadModel(model)
|
||||
}
|
||||
}
|
||||
.disabled(modelManager.isLoading)
|
||||
|
||||
Button("Metadata…") {
|
||||
editingMetadataModel = model
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 10)
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private func localRow(_ model: ModelConfig) -> some View {
|
||||
HStack(alignment: .top, spacing: 14) {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
HStack(spacing: 8) {
|
||||
Text(model.displayName)
|
||||
.font(.headline)
|
||||
if !model.isCurated {
|
||||
Text("Custom")
|
||||
.font(.caption.weight(.semibold))
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 3)
|
||||
.background(.secondary.opacity(0.14), in: Capsule())
|
||||
}
|
||||
if modelManager.currentModel?.repoId == model.repoId {
|
||||
Text("Loaded")
|
||||
.font(.caption.weight(.semibold))
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 3)
|
||||
.background(.green.opacity(0.15), in: Capsule())
|
||||
}
|
||||
}
|
||||
|
||||
Text(model.repoId)
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
if let localSizeBytes = model.localSizeBytes {
|
||||
Text(sizeFormatter.string(fromByteCount: localSizeBytes))
|
||||
.font(.caption.monospacedDigit())
|
||||
.foregroundStyle(.secondary)
|
||||
.frame(width: 90, alignment: .trailing)
|
||||
}
|
||||
|
||||
Button("Load") {
|
||||
Task {
|
||||
await modelManager.loadModel(model)
|
||||
}
|
||||
}
|
||||
.disabled(modelManager.isLoading)
|
||||
|
||||
Button("Metadata…") {
|
||||
editingMetadataModel = model
|
||||
}
|
||||
|
||||
Button("Delete", role: .destructive) {
|
||||
pendingDelete = model
|
||||
}
|
||||
.disabled(modelManager.isLoading)
|
||||
}
|
||||
.padding(.vertical, 10)
|
||||
}
|
||||
|
||||
private func downloadEnteredModel() {
|
||||
let repoId = newRepoId.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !repoId.isEmpty else { return }
|
||||
|
||||
Task {
|
||||
await modelManager.addModel(repoId: repoId)
|
||||
if modelManager.errorMessage == nil {
|
||||
newRepoId = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private struct ModelMetadataEditorView: View {
|
||||
@Environment(\.dismiss) private var dismiss
|
||||
|
||||
let model: ModelConfig
|
||||
let baselineModel: ModelConfig
|
||||
let detectedLocalModel: LocalModelResolver.LocalModelInfo?
|
||||
let hasSavedOverride: Bool
|
||||
let onSave: (ModelMetadataOverride) -> Void
|
||||
let onReset: () -> Void
|
||||
|
||||
@State private var contextLengthText: String
|
||||
@State private var primaryLoaderKind: ModelConfig.LoaderKind
|
||||
@State private var supportsImages: Bool
|
||||
@State private var supportsTools: Bool
|
||||
|
||||
init(
|
||||
model: ModelConfig,
|
||||
baselineModel: ModelConfig,
|
||||
detectedLocalModel: LocalModelResolver.LocalModelInfo?,
|
||||
hasSavedOverride: Bool,
|
||||
onSave: @escaping (ModelMetadataOverride) -> Void,
|
||||
onReset: @escaping () -> Void
|
||||
) {
|
||||
self.model = model
|
||||
self.baselineModel = baselineModel
|
||||
self.detectedLocalModel = detectedLocalModel
|
||||
self.hasSavedOverride = hasSavedOverride
|
||||
self.onSave = onSave
|
||||
self.onReset = onReset
|
||||
_contextLengthText = State(initialValue: String(model.contextLength))
|
||||
_primaryLoaderKind = State(initialValue: model.primaryLoaderKind)
|
||||
_supportsImages = State(initialValue: model.supportsImages)
|
||||
_supportsTools = State(initialValue: model.supportsTools)
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
NavigationStack {
|
||||
Form {
|
||||
Section("Metadata") {
|
||||
TextField("Context length", text: $contextLengthText)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
|
||||
Picker("Primary loader", selection: $primaryLoaderKind) {
|
||||
ForEach(ModelConfig.LoaderKind.allCases, id: \.self) { loaderKind in
|
||||
Text(loaderKind.displayName).tag(loaderKind)
|
||||
}
|
||||
}
|
||||
|
||||
Toggle("Supports images", isOn: $supportsImages)
|
||||
Toggle("Supports tools", isOn: $supportsTools)
|
||||
}
|
||||
|
||||
Section("Comparison") {
|
||||
Text(defaultsSummary)
|
||||
.foregroundStyle(.secondary)
|
||||
|
||||
Grid(alignment: .leading, horizontalSpacing: 16, verticalSpacing: 8) {
|
||||
GridRow {
|
||||
Text("")
|
||||
Text("Effective")
|
||||
.font(.caption.weight(.semibold))
|
||||
.foregroundStyle(.secondary)
|
||||
Text(baselineHeading)
|
||||
.font(.caption.weight(.semibold))
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
comparisonRow(
|
||||
label: "Context",
|
||||
effective: currentOverride?.contextLength.description ?? "Invalid",
|
||||
baseline: baselineModel.contextLength > 0 ? "\(baselineModel.contextLength)" : "Unknown"
|
||||
)
|
||||
comparisonRow(
|
||||
label: "Loader",
|
||||
effective: primaryLoaderKind.displayName,
|
||||
baseline: baselineModel.primaryLoaderKind.displayName
|
||||
)
|
||||
comparisonRow(
|
||||
label: "Images",
|
||||
effective: yesNo(supportsImages),
|
||||
baseline: yesNo(baselineModel.supportsImages)
|
||||
)
|
||||
comparisonRow(
|
||||
label: "Tools",
|
||||
effective: yesNo(supportsTools),
|
||||
baseline: yesNo(baselineModel.supportsTools)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if let detectedLocalModel {
|
||||
Section("Discovered Source") {
|
||||
LabeledContent("Detected context") {
|
||||
Text(detectedLocalModel.contextLength > 0 ? "\(detectedLocalModel.contextLength)" : "Unknown")
|
||||
}
|
||||
LabeledContent("Detected loader order") {
|
||||
Text(detectedLocalModel.loaderKinds.map(\.displayName).joined(separator: ", "))
|
||||
}
|
||||
LabeledContent("Detected vision") {
|
||||
Text(yesNo(detectedLocalModel.supportsImages))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.formStyle(.grouped)
|
||||
.navigationTitle(model.displayName)
|
||||
.frame(minWidth: 520, minHeight: 380)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .cancellationAction) {
|
||||
Button("Cancel") {
|
||||
dismiss()
|
||||
}
|
||||
}
|
||||
|
||||
ToolbarItem(placement: .primaryAction) {
|
||||
Button("Save") {
|
||||
guard let currentOverride else { return }
|
||||
onSave(currentOverride)
|
||||
dismiss()
|
||||
}
|
||||
.disabled(currentOverride == nil)
|
||||
}
|
||||
|
||||
if hasSavedOverride {
|
||||
ToolbarItem(placement: .automatic) {
|
||||
Button("Reset to Detected") {
|
||||
onReset()
|
||||
dismiss()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var currentOverride: ModelMetadataOverride? {
|
||||
guard let contextLength = Int(contextLengthText.trimmingCharacters(in: .whitespacesAndNewlines)), contextLength >= 0 else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ModelMetadataOverride(
|
||||
contextLength: contextLength,
|
||||
primaryLoaderKind: primaryLoaderKind,
|
||||
supportsImages: supportsImages,
|
||||
supportsTools: supportsTools
|
||||
)
|
||||
}
|
||||
|
||||
private var defaultsSummary: String {
|
||||
if detectedLocalModel != nil {
|
||||
if hasSavedOverride {
|
||||
return "The editable fields show the effective overridden metadata. The comparison column shows the discovered baseline from the local model files."
|
||||
}
|
||||
return "The editable fields currently match the discovered baseline from the local model files. Save to store an override for this repo ID."
|
||||
}
|
||||
|
||||
if model.isCurated {
|
||||
return hasSavedOverride
|
||||
? "The editable fields show the effective overridden metadata. The comparison column shows the curated built-in baseline."
|
||||
: "The editable fields currently match the curated built-in baseline. Save to store an override for this repo ID."
|
||||
}
|
||||
|
||||
if hasSavedOverride {
|
||||
return "The editable fields show the effective overridden metadata. The comparison column shows the inferred baseline for this repo ID."
|
||||
}
|
||||
|
||||
return "The editable fields currently match the inferred baseline for this repo ID. Save to store an override."
|
||||
}
|
||||
|
||||
private var baselineHeading: String {
|
||||
if detectedLocalModel != nil {
|
||||
return "Detected"
|
||||
}
|
||||
if model.isCurated {
|
||||
return "Built-in"
|
||||
}
|
||||
return "Inferred"
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private func comparisonRow(label: String, effective: String, baseline: String) -> some View {
|
||||
GridRow {
|
||||
Text(label)
|
||||
Text(effective)
|
||||
.monospaced()
|
||||
Text(baseline)
|
||||
.foregroundStyle(.secondary)
|
||||
.monospaced()
|
||||
}
|
||||
}
|
||||
|
||||
private func yesNo(_ value: Bool) -> String {
|
||||
value ? "Yes" : "No"
|
||||
}
|
||||
}
|
||||
5
MLXServer/Views/ModelManagementWindow.swift
Normal file
5
MLXServer/Views/ModelManagementWindow.swift
Normal file
@@ -0,0 +1,5 @@
|
||||
import Foundation
|
||||
|
||||
enum ModelManagementWindow {
|
||||
static let windowID = "model-manager"
|
||||
}
|
||||
@@ -7,14 +7,14 @@ struct ModelPickerView: View {
|
||||
var body: some View {
|
||||
HStack(spacing: 8) {
|
||||
Picker("Model", selection: selectedModelBinding) {
|
||||
ForEach(ModelConfig.availableModels) { config in
|
||||
ForEach(modelManager.availableModels) { config in
|
||||
Label(
|
||||
config.displayName,
|
||||
config.isCurated ? config.displayName : config.repoId,
|
||||
systemImage: config.isLocal ? "checkmark.circle.fill" : "arrow.down.circle"
|
||||
).tag(config.id)
|
||||
}
|
||||
}
|
||||
.frame(width: 160)
|
||||
.frame(width: 260)
|
||||
.disabled(modelManager.isLoading)
|
||||
|
||||
// Re-download button (visible when a model is loaded)
|
||||
@@ -50,9 +50,20 @@ struct ModelPickerView: View {
|
||||
|
||||
private var selectedModelBinding: Binding<String> {
|
||||
Binding(
|
||||
get: { modelManager.currentModel?.id ?? ModelConfig.default.id },
|
||||
get: {
|
||||
if let currentId = modelManager.currentModel?.id {
|
||||
return currentId
|
||||
}
|
||||
if let defaultId = Preferences.defaultModelId,
|
||||
let config = modelManager.availableModels.first(where: { $0.id == defaultId || $0.repoId == defaultId }) {
|
||||
return config.id
|
||||
}
|
||||
return ModelConfig.default.id
|
||||
},
|
||||
set: { newId in
|
||||
guard let config = ModelConfig.availableModels.first(where: { $0.id == newId }) else { return }
|
||||
guard let config = modelManager.availableModels.first(where: { $0.id == newId }) ?? ModelConfig.resolve(newId) else {
|
||||
return
|
||||
}
|
||||
Task {
|
||||
await modelManager.loadModel(config)
|
||||
}
|
||||
|
||||
@@ -210,6 +210,7 @@ struct SceneManagementView: View {
|
||||
}
|
||||
|
||||
private struct SceneEditorView: View {
|
||||
@Environment(ModelManager.self) private var modelManager
|
||||
@Environment(SceneStore.self) private var sceneStore
|
||||
|
||||
let scene: ChatScene
|
||||
@@ -221,7 +222,7 @@ private struct SceneEditorView: View {
|
||||
|
||||
Picker("Model", selection: modelBinding) {
|
||||
Text("Current model").tag(Optional<String>.none)
|
||||
ForEach(ModelConfig.availableModels) { model in
|
||||
ForEach(modelManager.availableModels) { model in
|
||||
Text(model.displayName).tag(Optional(model.id))
|
||||
}
|
||||
}
|
||||
@@ -257,6 +258,9 @@ private struct SceneEditorView: View {
|
||||
}
|
||||
.formStyle(.grouped)
|
||||
.navigationTitle(scene.displayName)
|
||||
.onAppear {
|
||||
modelManager.refreshAvailableModels()
|
||||
}
|
||||
}
|
||||
|
||||
private var modelBinding: Binding<String?> {
|
||||
|
||||
@@ -2,6 +2,7 @@ import SwiftUI
|
||||
|
||||
struct SettingsView: View {
|
||||
@Environment(\.openWindow) private var openWindow
|
||||
@Environment(ModelManager.self) private var modelManager
|
||||
@Environment(SceneStore.self) private var sceneStore
|
||||
@State private var systemPrompt: String = Preferences.systemPrompt
|
||||
@State private var apiPort: String = String(Preferences.apiPort)
|
||||
@@ -29,7 +30,7 @@ struct SettingsView: View {
|
||||
Form {
|
||||
Section("Startup") {
|
||||
Picker("Default model", selection: $defaultModelId) {
|
||||
ForEach(ModelConfig.availableModels) { model in
|
||||
ForEach(modelManager.availableModels) { model in
|
||||
Text(model.displayName).tag(model.id)
|
||||
}
|
||||
}
|
||||
@@ -44,7 +45,7 @@ struct SettingsView: View {
|
||||
|
||||
Section("Generation Defaults") {
|
||||
Picker("Defaults for model", selection: $generationDefaultsModelId) {
|
||||
ForEach(ModelConfig.availableModels) { model in
|
||||
ForEach(modelManager.availableModels) { model in
|
||||
Text(model.displayName).tag(model.id)
|
||||
}
|
||||
}
|
||||
@@ -164,6 +165,15 @@ struct SettingsView: View {
|
||||
}
|
||||
.formStyle(.grouped)
|
||||
.frame(width: 450, height: 650)
|
||||
.onAppear {
|
||||
modelManager.refreshAvailableModels()
|
||||
if !modelManager.availableModels.contains(where: { $0.id == defaultModelId }) {
|
||||
defaultModelId = modelManager.availableModels.first?.id ?? ModelConfig.default.id
|
||||
}
|
||||
if !modelManager.availableModels.contains(where: { $0.id == generationDefaultsModelId }) {
|
||||
generationDefaultsModelId = defaultModelId
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var generationDefaultsBinding: Binding<GenerationSettings> {
|
||||
|
||||
180
MLXServerTests/Server/LocalModelResolverTests.swift
Normal file
180
MLXServerTests/Server/LocalModelResolverTests.swift
Normal file
@@ -0,0 +1,180 @@
|
||||
import Foundation
|
||||
import XCTest
|
||||
@testable import MLX_Server
|
||||
|
||||
final class LocalModelResolverTests: XCTestCase {
|
||||
func testDiscoverModelsInfersTextOnlyMetadataAndDirectorySize() throws {
|
||||
let base = try makeTempModelsRoot()
|
||||
let repoDirectory = try makeRepoDirectory(base: base, owner: "example", repo: "text-only")
|
||||
let configURL = repoDirectory.appendingPathComponent("config.json")
|
||||
let modelURL = repoDirectory.appendingPathComponent("model.safetensors")
|
||||
let tokenizerURL = repoDirectory.appendingPathComponent("tokenizer.json")
|
||||
|
||||
try writeJSON(
|
||||
[
|
||||
"architectures": ["LlamaForCausalLM"],
|
||||
"max_position_embeddings": 32768,
|
||||
],
|
||||
to: configURL
|
||||
)
|
||||
try Data(repeating: 0x11, count: 64).write(to: modelURL)
|
||||
try Data(repeating: 0x22, count: 19).write(to: tokenizerURL)
|
||||
|
||||
let expectedSize = Int64(
|
||||
try Data(contentsOf: configURL).count
|
||||
+ Data(contentsOf: modelURL).count
|
||||
+ Data(contentsOf: tokenizerURL).count
|
||||
)
|
||||
|
||||
let discovered = LocalModelResolver.discoverModels(in: base)
|
||||
let model = try XCTUnwrap(discovered.first)
|
||||
|
||||
XCTAssertEqual(model.repoId, "example/text-only")
|
||||
XCTAssertEqual(model.contextLength, 32768)
|
||||
XCTAssertFalse(model.supportsImages)
|
||||
XCTAssertEqual(model.loaderKinds, [.llm, .vlm])
|
||||
XCTAssertEqual(model.sizeBytes, expectedSize)
|
||||
}
|
||||
|
||||
func testDiscoverModelsInfersVisionMetadataFromProcessorFiles() throws {
|
||||
let base = try makeTempModelsRoot()
|
||||
let repoDirectory = try makeRepoDirectory(base: base, owner: "example", repo: "vision-model")
|
||||
try writeJSON(
|
||||
[
|
||||
"text_config": ["max_position_embeddings": 262144],
|
||||
"vision_config": ["hidden_size": 768],
|
||||
],
|
||||
to: repoDirectory.appendingPathComponent("config.json")
|
||||
)
|
||||
try writeJSON(["processor_class": "Qwen3VLProcessor"], to: repoDirectory.appendingPathComponent("tokenizer_config.json"))
|
||||
try Data(repeating: 0x33, count: 12).write(to: repoDirectory.appendingPathComponent("processor_config.json"))
|
||||
try Data(repeating: 0x44, count: 8).write(to: repoDirectory.appendingPathComponent("model.safetensors.index.json"))
|
||||
|
||||
let discovered = LocalModelResolver.discoverModels(in: base)
|
||||
let model = try XCTUnwrap(discovered.first)
|
||||
|
||||
XCTAssertEqual(model.repoId, "example/vision-model")
|
||||
XCTAssertEqual(model.contextLength, 262144)
|
||||
XCTAssertTrue(model.supportsImages)
|
||||
XCTAssertEqual(model.loaderKinds, [.vlm, .llm])
|
||||
}
|
||||
|
||||
func testMergedCatalogKeepsCuratedModelsAndAddsCustomLocalModels() {
|
||||
let localModels = [
|
||||
LocalModelResolver.LocalModelInfo(
|
||||
repoId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
directory: URL(fileURLWithPath: "/tmp/gemma"),
|
||||
sizeBytes: 1024,
|
||||
contextLength: 128000,
|
||||
loaderKinds: [.vlm, .llm],
|
||||
supportsImages: true
|
||||
),
|
||||
LocalModelResolver.LocalModelInfo(
|
||||
repoId: "custom-org/custom-model",
|
||||
directory: URL(fileURLWithPath: "/tmp/custom"),
|
||||
sizeBytes: 2048,
|
||||
contextLength: 65536,
|
||||
loaderKinds: [.llm, .vlm],
|
||||
supportsImages: false
|
||||
),
|
||||
]
|
||||
|
||||
let merged = ModelConfig.mergedModels(localModels: localModels)
|
||||
let gemma = merged.first(where: { $0.id == "gemma" })
|
||||
let custom = merged.first(where: { $0.repoId == "custom-org/custom-model" })
|
||||
|
||||
XCTAssertEqual(gemma?.localSizeBytes, 1024)
|
||||
XCTAssertEqual(custom?.id, "custom-org/custom-model")
|
||||
XCTAssertEqual(custom?.contextLength, 65536)
|
||||
XCTAssertFalse(custom?.isCurated ?? true)
|
||||
}
|
||||
|
||||
func testResolveUnknownRepoIdCreatesRemoteCustomConfig() throws {
|
||||
let config = try XCTUnwrap(ModelConfig.resolve("custom-owner/custom-repo"))
|
||||
|
||||
XCTAssertEqual(config.id, "custom-owner/custom-repo")
|
||||
XCTAssertEqual(config.repoId, "custom-owner/custom-repo")
|
||||
XCTAssertFalse(config.isCurated)
|
||||
}
|
||||
|
||||
func testMergedCatalogAppliesSavedMetadataOverride() {
|
||||
let repoId = "custom-org/override-model"
|
||||
Preferences.setModelMetadataOverride(
|
||||
ModelMetadataOverride(
|
||||
contextLength: 123456,
|
||||
primaryLoaderKind: .vlm,
|
||||
supportsImages: true,
|
||||
supportsTools: true
|
||||
),
|
||||
forRepoId: repoId
|
||||
)
|
||||
defer {
|
||||
Preferences.removeModelMetadataOverride(forRepoId: repoId)
|
||||
}
|
||||
|
||||
let localModels = [
|
||||
LocalModelResolver.LocalModelInfo(
|
||||
repoId: repoId,
|
||||
directory: URL(fileURLWithPath: "/tmp/custom-override"),
|
||||
sizeBytes: 2048,
|
||||
contextLength: 65536,
|
||||
loaderKinds: [.llm, .vlm],
|
||||
supportsImages: false
|
||||
),
|
||||
]
|
||||
|
||||
let merged = ModelConfig.mergedModels(localModels: localModels)
|
||||
let overridden = merged.first(where: { $0.repoId == repoId })
|
||||
|
||||
XCTAssertEqual(overridden?.contextLength, 123456)
|
||||
XCTAssertEqual(overridden?.primaryLoaderKind, .vlm)
|
||||
XCTAssertTrue(overridden?.supportsImages ?? false)
|
||||
XCTAssertTrue(overridden?.supportsTools ?? false)
|
||||
}
|
||||
|
||||
func testResolveUnknownRepoIdUsesSavedMetadataOverride() throws {
|
||||
let repoId = "custom-owner/custom-repo-with-override"
|
||||
Preferences.setModelMetadataOverride(
|
||||
ModelMetadataOverride(
|
||||
contextLength: 8192,
|
||||
primaryLoaderKind: .llm,
|
||||
supportsImages: false,
|
||||
supportsTools: true
|
||||
),
|
||||
forRepoId: repoId
|
||||
)
|
||||
defer {
|
||||
Preferences.removeModelMetadataOverride(forRepoId: repoId)
|
||||
}
|
||||
|
||||
let config = try XCTUnwrap(ModelConfig.resolve(repoId))
|
||||
|
||||
XCTAssertEqual(config.contextLength, 8192)
|
||||
XCTAssertEqual(config.primaryLoaderKind, .llm)
|
||||
XCTAssertFalse(config.supportsImages)
|
||||
XCTAssertTrue(config.supportsTools)
|
||||
}
|
||||
|
||||
private func makeTempModelsRoot() throws -> URL {
|
||||
let root = FileManager.default.temporaryDirectory
|
||||
.appendingPathComponent(UUID().uuidString, isDirectory: true)
|
||||
try FileManager.default.createDirectory(at: root, withIntermediateDirectories: true)
|
||||
addTeardownBlock {
|
||||
try? FileManager.default.removeItem(at: root)
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
private func makeRepoDirectory(base: URL, owner: String, repo: String) throws -> URL {
|
||||
let directory = base
|
||||
.appendingPathComponent(owner, isDirectory: true)
|
||||
.appendingPathComponent(repo, isDirectory: true)
|
||||
try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true)
|
||||
return directory
|
||||
}
|
||||
|
||||
private func writeJSON(_ object: Any, to url: URL) throws {
|
||||
let data = try JSONSerialization.data(withJSONObject: object, options: [.prettyPrinted, .sortedKeys])
|
||||
try data.write(to: url)
|
||||
}
|
||||
}
|
||||
@@ -44,7 +44,8 @@ This is intended for targeted validation while keeping the normal default as the
|
||||
|
||||
- **Chat interface** with markdown rendering and model-aware image attachments (file picker, drag & drop, clipboard paste, Finder copy-paste on vision-capable models)
|
||||
- **Scene-based chat starts** — New Chat opens a scene picker with Neutral plus saved scenes, each with an optional model override, a scene prompt layered onto the base system prompt, an auto-sent starter prompt, and optional generation-setting overrides for chat-specific behavior
|
||||
- **Model picker** in toolbar with local/download status indicators and re-download button
|
||||
- **Model picker** in toolbar with curated defaults plus any locally discovered MLX models on disk
|
||||
- **Models window** in the menu for downloading a model by HuggingFace ID, inspecting on-disk model sizes, and deleting local model folders
|
||||
- **Download progress modal** — shows file progress, percentage, and speed when downloading a new model
|
||||
- **Thinking mode** — models like Qwen3.5 can reason internally before responding; thinking content appears in a collapsible box. Toggle on/off in Settings.
|
||||
- **Streaming responses** with live token display
|
||||
|
||||
Reference in New Issue
Block a user