diff --git a/MLXServer.xcodeproj/project.pbxproj b/MLXServer.xcodeproj/project.pbxproj index da7e92a..ec2b104 100644 --- a/MLXServer.xcodeproj/project.pbxproj +++ b/MLXServer.xcodeproj/project.pbxproj @@ -9,12 +9,14 @@ /* Begin PBXBuildFile section */ 0168AEE16009097901363E16 /* ModelManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 922CBDC9206737BD04AF2874 /* ModelManager.swift */; }; 07119250A7F9D6ECE7F6B8FD /* SceneCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0F03A123A8908714A89315FE /* SceneCommands.swift */; }; + 0BC7203552A161BC852975EA /* GenerationSettingsEditor.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */; }; 165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */ = {isa = PBXBuildFile; fileRef = 145B888FBDD4F931512C5473 /* Preferences.swift */; }; 189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */ = {isa = PBXBuildFile; fileRef = E73B165A1822729C907791AE /* ToolCallParser.swift */; }; 1A8833E3CCD3289C95E282A2 /* ChatDocumentManifest.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1607BDDE53C575627DCC6896 /* ChatDocumentManifest.swift */; }; 1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */; }; 20FFB5DBF75AA6C359AAE31C /* SceneManagementView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 37FEB592E5E717F817B03151 /* SceneManagementView.swift */; }; 221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */; }; + 2640EDCA9033D85C0B785557 /* GenerationSettings.swift in Sources */ = {isa = PBXBuildFile; fileRef = 6FAF7455BD387CD2061E0CBF /* GenerationSettings.swift */; }; 29879D696584B96CC56560DF /* ChatExporter.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7C9BAD674E29688ACE53B0B /* ChatExporter.swift */; }; 2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */; }; 2D08769282BD71C170DB0943 /* InferenceStats.swift in Sources */ = {isa = PBXBuildFile; fileRef = E35452B166893B25E765FF70 /* InferenceStats.swift */; }; @@ -37,6 +39,7 @@ 7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; }; 80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; }; 834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */; }; + 847B445654860396AF5A8280 /* GenerationSettingsTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */; }; 84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */; }; 85FB1EB49D76A9F21E181346 /* ChatScene.swift in Sources */ = {isa = PBXBuildFile; fileRef = C04EE8E6418EC6E9B66999B0 /* ChatScene.swift */; }; 8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */; }; @@ -98,11 +101,14 @@ 4147321383E94E9F17A0154E /* SettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsView.swift; sourceTree = ""; }; 4239CFF94B819C35A8D4D617 /* MonitorView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MonitorView.swift; sourceTree = ""; }; 49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoderTests.swift; sourceTree = ""; }; + 57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettingsTests.swift; sourceTree = ""; }; 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptBuilderTests.swift; sourceTree = ""; }; 615F8A7C9ABCADEB215D31BD /* StreamingSSEEncoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamingSSEEncoder.swift; sourceTree = ""; }; 64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenPrefixCacheTests.swift; sourceTree = ""; }; 6B3AA91D2C7842D7366F9A41 /* ChatDocumentPackage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentPackage.swift; sourceTree = ""; }; 6EE59189918D06B8D2F588FC /* MLXServer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLXServer.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 6FAF7455BD387CD2061E0CBF /* GenerationSettings.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettings.swift; sourceTree = ""; }; + 7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerationSettingsEditor.swift; sourceTree = ""; }; 7C1A89C076E717F87A60397D /* ImageDecoder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageDecoder.swift; sourceTree = ""; }; 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LiveCountersTests.swift; sourceTree = ""; }; 922CBDC9206737BD04AF2874 /* ModelManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManager.swift; sourceTree = ""; }; @@ -190,6 +196,7 @@ E43535D68448F1752D91C3A9 /* APIServerRewriteTests.swift */, FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */, B758F596F4F3E68793B045BB /* ChatViewModelTests.swift */, + 57AC0815F72BDD32FC54C88A /* GenerationSettingsTests.swift */, E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */, 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */, D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */, @@ -238,6 +245,7 @@ E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */, DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */, 2DC8C86D397B1FCA08E07CBD /* DownloadModalView.swift */, + 7AE2A32FBB744696DEA77435 /* GenerationSettingsEditor.swift */, C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */, 4239CFF94B819C35A8D4D617 /* MonitorView.swift */, 37FEB592E5E717F817B03151 /* SceneManagementView.swift */, @@ -263,6 +271,7 @@ children = ( A4B359324B5FD8D106C74338 /* ChatMessage.swift */, C04EE8E6418EC6E9B66999B0 /* ChatScene.swift */, + 6FAF7455BD387CD2061E0CBF /* GenerationSettings.swift */, E35452B166893B25E765FF70 /* InferenceStats.swift */, 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */, ); @@ -404,6 +413,7 @@ CBC9DB0799C4ADF2DC9319DA /* APIServerRewriteTests.swift in Sources */, 962083CCCC4AC848E0BBBC99 /* CancellationTokenTests.swift in Sources */, 95A612524552AF5CC3B1AE62 /* ChatViewModelTests.swift in Sources */, + 847B445654860396AF5A8280 /* GenerationSettingsTests.swift in Sources */, E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */, 67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */, 8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */, @@ -437,6 +447,8 @@ 5946258F1DE88CE904584E0B /* ContentView.swift in Sources */, C07A377244DCD67F4FE709FE /* DownloadModalView.swift in Sources */, 4DC033E45880B2948B47DEB1 /* FocusedValues.swift in Sources */, + 2640EDCA9033D85C0B785557 /* GenerationSettings.swift in Sources */, + 0BC7203552A161BC852975EA /* GenerationSettingsEditor.swift in Sources */, A146BBA70CFBEC505BDCDF0D /* ImageDecoder.swift in Sources */, EC4FC68608DDFA6A3DF133CC /* InferenceEngine.swift in Sources */, 2D08769282BD71C170DB0943 /* InferenceStats.swift in Sources */, diff --git a/MLXServer/Documents/ChatDocumentManifest.swift b/MLXServer/Documents/ChatDocumentManifest.swift index 7a61e6d..74f513a 100644 --- a/MLXServer/Documents/ChatDocumentManifest.swift +++ b/MLXServer/Documents/ChatDocumentManifest.swift @@ -11,7 +11,7 @@ struct ChatDocumentManifest: Codable { var messages: [StoredChatMessage] var uiState: StoredChatUIState - static let currentSchemaVersion = 1 + static let currentSchemaVersion = 2 struct StoredModelInfo: Codable, Hashable { var id: String @@ -23,6 +23,69 @@ struct ChatDocumentManifest: Codable { var systemPrompt: String var thinkingEnabled: Bool var temperature: Double + var topP: Double + var topK: Int + var minP: Double + var maxTokens: Int + var repetitionPenalty: Double? + var presencePenalty: Double? + var frequencyPenalty: Double? + + init(systemPrompt: String, generationSettings: GenerationSettings) { + self.systemPrompt = systemPrompt + self.thinkingEnabled = generationSettings.thinkingEnabled + self.temperature = generationSettings.temperature + self.topP = generationSettings.topP + self.topK = generationSettings.topK + self.minP = generationSettings.minP + self.maxTokens = generationSettings.maxTokens + self.repetitionPenalty = generationSettings.repetitionPenalty + self.presencePenalty = generationSettings.presencePenalty + self.frequencyPenalty = generationSettings.frequencyPenalty + } + + var generationSettings: GenerationSettings { + GenerationSettings( + temperature: temperature, + topP: topP, + topK: topK, + minP: minP, + maxTokens: maxTokens, + repetitionPenalty: repetitionPenalty, + presencePenalty: presencePenalty, + frequencyPenalty: frequencyPenalty, + thinkingEnabled: thinkingEnabled + ).normalized() + } + + private enum CodingKeys: String, CodingKey { + case systemPrompt + case thinkingEnabled + case temperature + case topP + case topK + case minP + case maxTokens + case repetitionPenalty + case presencePenalty + case frequencyPenalty + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let fallback = GenerationSettings() + + systemPrompt = try container.decodeIfPresent(String.self, forKey: .systemPrompt) ?? "" + thinkingEnabled = try container.decodeIfPresent(Bool.self, forKey: .thinkingEnabled) ?? fallback.thinkingEnabled + temperature = try container.decodeIfPresent(Double.self, forKey: .temperature) ?? fallback.temperature + topP = try container.decodeIfPresent(Double.self, forKey: .topP) ?? fallback.topP + topK = try container.decodeIfPresent(Int.self, forKey: .topK) ?? fallback.topK + minP = try container.decodeIfPresent(Double.self, forKey: .minP) ?? fallback.minP + maxTokens = try container.decodeIfPresent(Int.self, forKey: .maxTokens) ?? fallback.maxTokens + repetitionPenalty = try container.decodeIfPresent(Double.self, forKey: .repetitionPenalty) + presencePenalty = try container.decodeIfPresent(Double.self, forKey: .presencePenalty) + frequencyPenalty = try container.decodeIfPresent(Double.self, forKey: .frequencyPenalty) + } } struct StoredChatUIState: Codable, Hashable { diff --git a/MLXServer/Documents/ChatDocumentMigration.swift b/MLXServer/Documents/ChatDocumentMigration.swift index d7b0782..a6a0091 100644 --- a/MLXServer/Documents/ChatDocumentMigration.swift +++ b/MLXServer/Documents/ChatDocumentMigration.swift @@ -12,6 +12,8 @@ enum ChatDocumentMigration { switch envelope.schemaVersion { case 1: return try decoder.decode(ChatDocumentManifest.self, from: data) + case 2: + return try decoder.decode(ChatDocumentManifest.self, from: data) default: throw ChatDocumentError.unsupportedSchemaVersion(envelope.schemaVersion) } diff --git a/MLXServer/Models/ChatScene.swift b/MLXServer/Models/ChatScene.swift index d936ba7..59bde57 100644 --- a/MLXServer/Models/ChatScene.swift +++ b/MLXServer/Models/ChatScene.swift @@ -6,19 +6,41 @@ struct ChatScene: Codable, Identifiable, Hashable { var modelId: String? var systemPrompt: String var starterPrompt: String + var generationOverrides: GenerationSettingsOverride init( id: UUID = UUID(), name: String, modelId: String? = nil, systemPrompt: String = "", - starterPrompt: String = "" + starterPrompt: String = "", + generationOverrides: GenerationSettingsOverride = .none ) { self.id = id self.name = name self.modelId = modelId self.systemPrompt = systemPrompt self.starterPrompt = starterPrompt + self.generationOverrides = generationOverrides + } + + private enum CodingKeys: String, CodingKey { + case id + case name + case modelId + case systemPrompt + case starterPrompt + case generationOverrides + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + id = try container.decode(UUID.self, forKey: .id) + name = try container.decode(String.self, forKey: .name) + modelId = try container.decodeIfPresent(String.self, forKey: .modelId) + systemPrompt = try container.decodeIfPresent(String.self, forKey: .systemPrompt) ?? "" + starterPrompt = try container.decodeIfPresent(String.self, forKey: .starterPrompt) ?? "" + generationOverrides = try container.decodeIfPresent(GenerationSettingsOverride.self, forKey: .generationOverrides) ?? .none } var trimmedName: String { diff --git a/MLXServer/Models/GenerationSettings.swift b/MLXServer/Models/GenerationSettings.swift new file mode 100644 index 0000000..7a4e5ba --- /dev/null +++ b/MLXServer/Models/GenerationSettings.swift @@ -0,0 +1,157 @@ +import Foundation + +struct GenerationSettings: Codable, Hashable, Sendable { + var temperature: Double + var topP: Double + var topK: Int + var minP: Double + var maxTokens: Int + var repetitionPenalty: Double? + var presencePenalty: Double? + var frequencyPenalty: Double? + var thinkingEnabled: Bool + + init( + temperature: Double = 0.7, + topP: Double = 1.0, + topK: Int = 0, + minP: Double = 0.0, + maxTokens: Int = 4096, + repetitionPenalty: Double? = nil, + presencePenalty: Double? = nil, + frequencyPenalty: Double? = nil, + thinkingEnabled: Bool = true + ) { + self.temperature = temperature + self.topP = topP + self.topK = topK + self.minP = minP + self.maxTokens = maxTokens + self.repetitionPenalty = repetitionPenalty + self.presencePenalty = presencePenalty + self.frequencyPenalty = frequencyPenalty + self.thinkingEnabled = thinkingEnabled + } + + func normalized() -> GenerationSettings { + GenerationSettings( + temperature: max(0, temperature), + topP: min(max(topP, 0), 1), + topK: max(0, topK), + minP: min(max(minP, 0), 1), + maxTokens: max(1, maxTokens), + repetitionPenalty: Self.normalizePositive(repetitionPenalty), + presencePenalty: Self.normalizeSignedPenalty(presencePenalty), + frequencyPenalty: Self.normalizeSignedPenalty(frequencyPenalty), + thinkingEnabled: thinkingEnabled + ) + } + + func applying(_ overrides: GenerationSettingsOverride) -> GenerationSettings { + GenerationSettings( + temperature: overrides.temperature ?? temperature, + topP: overrides.topP ?? topP, + topK: overrides.topK ?? topK, + minP: overrides.minP ?? minP, + maxTokens: overrides.maxTokens ?? maxTokens, + repetitionPenalty: overrides.repetitionPenalty ?? repetitionPenalty, + presencePenalty: overrides.presencePenalty ?? presencePenalty, + frequencyPenalty: overrides.frequencyPenalty ?? frequencyPenalty, + thinkingEnabled: overrides.thinkingEnabled ?? thinkingEnabled + ) + .normalized() + } + + static func modelDefault(for modelId: String, legacyThinkingEnabled: Bool = true) -> GenerationSettings { + let fallback = ModelConfig.resolve(modelId)?.defaultGenerationSettings ?? .generalDefault + var resolved = fallback + if !legacyThinkingEnabled { + resolved.thinkingEnabled = false + } + return resolved.normalized() + } + + static let generalDefault = GenerationSettings() + + static let technicalDefault = GenerationSettings( + temperature: 0.35, + topP: 0.9, + topK: 40, + minP: 0.0, + maxTokens: 4096, + repetitionPenalty: 1.05, + presencePenalty: nil, + frequencyPenalty: nil, + thinkingEnabled: true + ) + + static let roleplayDefault = GenerationSettings( + temperature: 0.85, + topP: 0.95, + topK: 60, + minP: 0.0, + maxTokens: 4096, + repetitionPenalty: 1.02, + presencePenalty: nil, + frequencyPenalty: nil, + thinkingEnabled: false + ) + + private static func normalizePositive(_ value: Double?) -> Double? { + guard let value else { return nil } + return value > 0 ? value : nil + } + + private static func normalizeSignedPenalty(_ value: Double?) -> Double? { + guard let value else { return nil } + return min(max(value, -2), 2) + } +} + +struct GenerationSettingsOverride: Codable, Hashable, Sendable { + var temperature: Double? + var topP: Double? + var topK: Int? + var minP: Double? + var maxTokens: Int? + var repetitionPenalty: Double? + var presencePenalty: Double? + var frequencyPenalty: Double? + var thinkingEnabled: Bool? + + init( + temperature: Double? = nil, + topP: Double? = nil, + topK: Int? = nil, + minP: Double? = nil, + maxTokens: Int? = nil, + repetitionPenalty: Double? = nil, + presencePenalty: Double? = nil, + frequencyPenalty: Double? = nil, + thinkingEnabled: Bool? = nil + ) { + self.temperature = temperature + self.topP = topP + self.topK = topK + self.minP = minP + self.maxTokens = maxTokens + self.repetitionPenalty = repetitionPenalty + self.presencePenalty = presencePenalty + self.frequencyPenalty = frequencyPenalty + self.thinkingEnabled = thinkingEnabled + } + + static let none = GenerationSettingsOverride() + + var hasOverrides: Bool { + temperature != nil + || topP != nil + || topK != nil + || minP != nil + || maxTokens != nil + || repetitionPenalty != nil + || presencePenalty != nil + || frequencyPenalty != nil + || thinkingEnabled != nil + } +} \ No newline at end of file diff --git a/MLXServer/Models/ModelConfig.swift b/MLXServer/Models/ModelConfig.swift index dba35aa..841c1e9 100644 --- a/MLXServer/Models/ModelConfig.swift +++ b/MLXServer/Models/ModelConfig.swift @@ -15,6 +15,7 @@ struct ModelConfig: Identifiable, Hashable { let loaderKind: LoaderKind let supportsImages: Bool let supportsTools: Bool + let defaultGenerationSettings: GenerationSettings /// All models supported by the app. static let availableModels: [ModelConfig] = [ @@ -25,7 +26,8 @@ struct ModelConfig: Identifiable, Hashable { contextLength: 128_000, loaderKind: .vlm, supportsImages: true, - supportsTools: true + supportsTools: true, + defaultGenerationSettings: .technicalDefault ), ModelConfig( id: "qwen", @@ -34,7 +36,8 @@ struct ModelConfig: Identifiable, Hashable { contextLength: 256_000, loaderKind: .vlm, supportsImages: true, - supportsTools: true + supportsTools: true, + defaultGenerationSettings: .technicalDefault ), ModelConfig( id: "qwen3.5-0.8b", @@ -43,7 +46,8 @@ struct ModelConfig: Identifiable, Hashable { contextLength: 256_000, loaderKind: .vlm, supportsImages: true, - supportsTools: true + supportsTools: true, + defaultGenerationSettings: .technicalDefault ), ModelConfig( id: "qwen3.5-9b", @@ -52,7 +56,8 @@ struct ModelConfig: Identifiable, Hashable { contextLength: 256_000, loaderKind: .vlm, supportsImages: true, - supportsTools: true + supportsTools: true, + defaultGenerationSettings: .technicalDefault ), ModelConfig( id: "stheno", @@ -61,7 +66,8 @@ struct ModelConfig: Identifiable, Hashable { contextLength: 8_192, loaderKind: .llm, supportsImages: false, - supportsTools: false + supportsTools: false, + defaultGenerationSettings: .roleplayDefault ), ] diff --git a/MLXServer/Server/APIModels.swift b/MLXServer/Server/APIModels.swift index 61563ee..52faf3f 100644 --- a/MLXServer/Server/APIModels.swift +++ b/MLXServer/Server/APIModels.swift @@ -152,15 +152,52 @@ struct APIChatCompletionRequest: Codable { let messages: [APIChatMessage] let temperature: Double? let top_p: Double? + let top_k: Int? + let min_p: Double? let max_tokens: Int? let stream: Bool? let stop: StopSequence? let tools: [APIToolDefinition]? let tool_choice: AnyCodable? + let repetition_penalty: Double? let frequency_penalty: Double? let presence_penalty: Double? let n: Int? + init( + model: String?, + messages: [APIChatMessage], + temperature: Double? = nil, + top_p: Double? = nil, + max_tokens: Int? = nil, + stream: Bool? = nil, + stop: StopSequence? = nil, + tools: [APIToolDefinition]? = nil, + tool_choice: AnyCodable? = nil, + frequency_penalty: Double? = nil, + presence_penalty: Double? = nil, + n: Int? = nil, + top_k: Int? = nil, + min_p: Double? = nil, + repetition_penalty: Double? = nil + ) { + self.model = model + self.messages = messages + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.min_p = min_p + self.max_tokens = max_tokens + self.stream = stream + self.stop = stop + self.tools = tools + self.tool_choice = tool_choice + self.repetition_penalty = repetition_penalty + self.frequency_penalty = frequency_penalty + self.presence_penalty = presence_penalty + self.n = n + } + enum StopSequence: Codable { case single(String) case multiple([String]) diff --git a/MLXServer/Server/APIServer.swift b/MLXServer/Server/APIServer.swift index ccc6ed6..28fba32 100644 --- a/MLXServer/Server/APIServer.swift +++ b/MLXServer/Server/APIServer.swift @@ -15,12 +15,19 @@ final class APIServer { let matchedTokenCount: Int } + struct DebugGenerationSettingsEvent: Sendable { + let requestId: String + let modelId: String + let settings: GenerationSettings + } + private struct ActiveRequest { let connection: NWConnection let cancellation: CancellationToken } nonisolated(unsafe) static var debugLookupEventHandler: (@Sendable (DebugLookupEvent) -> Void)? + nonisolated(unsafe) static var debugGenerationSettingsEventHandler: (@Sendable (DebugGenerationSettingsEvent) -> Void)? var isRunning = false var port: Int = 1234 @@ -256,15 +263,26 @@ final class APIServer { modelManager.touchActivity() - let isStream = request.stream ?? false - let temperature = request.temperature ?? 0.7 - let topP = request.top_p ?? 1.0 - let maxTokens = request.max_tokens ?? 4096 let requestId = "chatcmpl-\(UUID().uuidString.prefix(12).lowercased())" let created = Int(Date().timeIntervalSince1970) let modelName = request.model ?? modelManager.currentModel?.repoId ?? "unknown" let currentModel = modelManager.currentModel let contextLength = modelManager.currentModel?.contextLength ?? 0 + let baseSettings = Preferences.generationSettings(forModelId: currentModel?.id ?? ModelConfig.default.id) + let generationSettings = baseSettings.applying( + GenerationSettingsOverride( + temperature: request.temperature, + topP: request.top_p, + topK: request.top_k, + minP: request.min_p, + maxTokens: request.max_tokens, + repetitionPenalty: request.repetition_penalty, + presencePenalty: request.presence_penalty, + frequencyPenalty: request.frequency_penalty + ) + ) + let isStream = request.stream ?? false + let maxTokens = generationSettings.maxTokens if let tools = request.tools, !tools.isEmpty, currentModel?.supportsTools != true { sendResponse( @@ -281,10 +299,14 @@ final class APIServer { let preparedPrompt = PromptBuilder.build( from: request, modelId: currentModelRepoId, - thinkingEnabled: Preferences.enableThinking + thinkingEnabled: generationSettings.thinkingEnabled ) let isQwen = currentModelRepoId.lowercased().contains("qwen") + Self.debugGenerationSettingsEventHandler?( + DebugGenerationSettingsEvent(requestId: requestId, modelId: currentModelRepoId, settings: generationSettings) + ) + if preparedPrompt.containsImages, currentModel?.supportsImages != true { LiveCounters.shared.requestCompleted(requestId: requestId, generationTokens: 0) sendResponse( @@ -315,8 +337,16 @@ final class APIServer { let generateParams = GenerateParameters( maxTokens: maxTokens, - temperature: Float(temperature), - topP: Float(topP) + temperature: Float(generationSettings.temperature), + topP: Float(generationSettings.topP), + topK: generationSettings.topK, + minP: Float(generationSettings.minP), + repetitionPenalty: generationSettings.repetitionPenalty.map(Float.init), + repetitionContextSize: 128, + presencePenalty: generationSettings.presencePenalty.map(Float.init), + presenceContextSize: 128, + frequencyPenalty: generationSettings.frequencyPenalty.map(Float.init), + frequencyContextSize: 128 ) let currentModelId = modelManager.currentModel?.id ?? modelName let engine = InferenceEngine(container: container) diff --git a/MLXServer/Utilities/Preferences.swift b/MLXServer/Utilities/Preferences.swift index 11ff9a2..e894788 100644 --- a/MLXServer/Utilities/Preferences.swift +++ b/MLXServer/Utilities/Preferences.swift @@ -6,6 +6,7 @@ enum Preferences { private static let jsonEncoder = JSONEncoder() private static let jsonDecoder = JSONDecoder() + private static let legacyThinkingDefault = true // MARK: - Last used model @@ -79,12 +80,53 @@ enum Preferences { // MARK: - Thinking mode private static let enableThinkingKey = "enableThinking" + private static let modelGenerationSettingsKey = "modelGenerationSettings" /// Whether to enable thinking/reasoning mode for models that support it (e.g. Qwen3.5). /// When disabled, the model skips internal reasoning and responds directly. static var enableThinking: Bool { - get { defaults.object(forKey: enableThinkingKey) == nil ? true : defaults.bool(forKey: enableThinkingKey) } - set { defaults.set(newValue, forKey: enableThinkingKey) } + get { + let modelId = defaultModelId ?? lastModelId ?? ModelConfig.default.id + if modelGenerationSettingsMap[modelId] != nil { + return generationSettings(forModelId: modelId).thinkingEnabled + } + return defaults.object(forKey: enableThinkingKey) == nil ? Self.legacyThinkingDefault : defaults.bool(forKey: enableThinkingKey) + } + set { + let modelId = defaultModelId ?? lastModelId ?? ModelConfig.default.id + var settings = generationSettings(forModelId: modelId) + settings.thinkingEnabled = newValue + setGenerationSettings(settings, forModelId: modelId) + defaults.set(newValue, forKey: enableThinkingKey) + } + } + + static func generationSettings(forModelId modelId: String) -> GenerationSettings { + let legacyThinking = defaults.object(forKey: enableThinkingKey) == nil ? Self.legacyThinkingDefault : defaults.bool(forKey: enableThinkingKey) + return (modelGenerationSettingsMap[modelId] ?? GenerationSettings.modelDefault(for: modelId, legacyThinkingEnabled: legacyThinking)).normalized() + } + + static func setGenerationSettings(_ settings: GenerationSettings, forModelId modelId: String) { + var map = modelGenerationSettingsMap + let normalized = settings.normalized() + map[modelId] = normalized + modelGenerationSettingsMap = map + defaults.set(normalized.thinkingEnabled, forKey: enableThinkingKey) + } + + static func hasGenerationSettings(forModelId modelId: String) -> Bool { + modelGenerationSettingsMap[modelId] != nil + } + + private static var modelGenerationSettingsMap: [String: GenerationSettings] { + get { + guard let data = defaults.data(forKey: modelGenerationSettingsKey) else { return [:] } + return (try? jsonDecoder.decode([String: GenerationSettings].self, from: data)) ?? [:] + } + set { + guard let data = try? jsonEncoder.encode(newValue) else { return } + defaults.set(data, forKey: modelGenerationSettingsKey) + } } // MARK: - Idle unload diff --git a/MLXServer/ViewModels/ChatViewModel.swift b/MLXServer/ViewModels/ChatViewModel.swift index bfeefed..bc84931 100644 --- a/MLXServer/ViewModels/ChatViewModel.swift +++ b/MLXServer/ViewModels/ChatViewModel.swift @@ -28,8 +28,7 @@ final class ChatViewModel { private var documentId = UUID() private var documentCreatedAt = Date() private var documentSystemPromptOverride: String? - private var documentThinkingOverride: Bool? - private var documentTemperature = 0.7 + private var documentGenerationSettingsOverride: GenerationSettings? let modelManager: ModelManager let apiServer = APIServer() @@ -55,12 +54,25 @@ final class ChatViewModel { guard let container = modelManager.modelContainer else { return } if chatSession == nil { let systemPrompt = effectiveSystemPrompt + let generationSettings = effectiveGenerationSettings // Pass enable_thinking to the Jinja chat template context. // Qwen3.5 and similar models use this to control reasoning mode. - let thinkingContext: [String: any Sendable]? = effectiveThinkingEnabled + let thinkingContext: [String: any Sendable]? = generationSettings.thinkingEnabled ? nil : ["enable_thinking": false] - let generateParameters = GenerateParameters(temperature: Float(documentTemperature)) + let generateParameters = GenerateParameters( + maxTokens: generationSettings.maxTokens, + temperature: Float(generationSettings.temperature), + topP: Float(generationSettings.topP), + topK: generationSettings.topK, + minP: Float(generationSettings.minP), + repetitionPenalty: generationSettings.repetitionPenalty.map(Float.init), + repetitionContextSize: 128, + presencePenalty: generationSettings.presencePenalty.map(Float.init), + presenceContextSize: 128, + frequencyPenalty: generationSettings.frequencyPenalty.map(Float.init), + frequencyContextSize: 128 + ) let history = conversation.messages.compactMap(historyMessage(from:)) if history.isEmpty { chatSession = ChatSession( @@ -96,8 +108,17 @@ final class ChatViewModel { return parts.joined(separator: "\n\n") } - private var effectiveThinkingEnabled: Bool { - documentThinkingOverride ?? Preferences.enableThinking + private var effectiveGenerationSettings: GenerationSettings { + if let documentGenerationSettingsOverride { + return documentGenerationSettingsOverride + } + + let modelId = activeScene?.resolvedModel?.id + ?? modelManager.currentModel?.id + ?? Preferences.defaultModelId + ?? ModelConfig.default.id + return Preferences.generationSettings(forModelId: modelId) + .applying(activeScene?.generationOverrides ?? .none) } func send() { @@ -269,8 +290,7 @@ final class ChatViewModel { documentId = package.manifest.documentId documentCreatedAt = package.manifest.createdAt documentSystemPromptOverride = package.manifest.settings.systemPrompt - documentThinkingOverride = package.manifest.settings.thinkingEnabled - documentTemperature = package.manifest.settings.temperature + documentGenerationSettingsOverride = package.manifest.settings.generationSettings resetSession() lastSavedSnapshotHash = try snapshotHash() hasUnsavedChanges = false @@ -316,8 +336,7 @@ final class ChatViewModel { documentId = UUID() documentCreatedAt = Date() documentSystemPromptOverride = nil - documentThinkingOverride = nil - documentTemperature = 0.7 + documentGenerationSettingsOverride = nil } private func restoreMessage( @@ -398,11 +417,7 @@ final class ChatViewModel { updatedAt: updatedAt, appVersion: Bundle.main.object(forInfoDictionaryKey: "CFBundleShortVersionString") as? String ?? "1.0.0", model: currentStoredModelInfo, - settings: .init( - systemPrompt: effectiveSystemPrompt, - thinkingEnabled: effectiveThinkingEnabled, - temperature: documentTemperature - ), + settings: .init(systemPrompt: effectiveSystemPrompt, generationSettings: effectiveGenerationSettings), messages: messages, uiState: .init( draftInput: inputText, @@ -443,11 +458,7 @@ final class ChatViewModel { documentId: documentId, createdAt: documentCreatedAt, model: currentStoredModelInfo, - settings: .init( - systemPrompt: effectiveSystemPrompt, - thinkingEnabled: effectiveThinkingEnabled, - temperature: documentTemperature - ), + settings: .init(systemPrompt: effectiveSystemPrompt, generationSettings: effectiveGenerationSettings), messages: makeManifest(updatedAt: documentCreatedAt).messages, uiState: .init(draftInput: inputText, scrollAnchorMessageId: conversation.messages.last?.id) ) diff --git a/MLXServer/ViewModels/SceneStore.swift b/MLXServer/ViewModels/SceneStore.swift index 3f3cf34..112d905 100644 --- a/MLXServer/ViewModels/SceneStore.swift +++ b/MLXServer/ViewModels/SceneStore.swift @@ -16,7 +16,8 @@ final class SceneStore { name: scene.displayName, modelId: scene.modelId, systemPrompt: scene.systemPrompt, - starterPrompt: scene.starterPrompt + starterPrompt: scene.starterPrompt, + generationOverrides: scene.generationOverrides ) } else { nextScene = .empty diff --git a/MLXServer/Views/GenerationSettingsEditor.swift b/MLXServer/Views/GenerationSettingsEditor.swift new file mode 100644 index 0000000..a4022d1 --- /dev/null +++ b/MLXServer/Views/GenerationSettingsEditor.swift @@ -0,0 +1,144 @@ +import SwiftUI + +private let generationDoubleFormat = FloatingPointFormatStyle.number.precision(.fractionLength(0...2)) +private let generationIntegerFormat = IntegerFormatStyle.number.grouping(.never) + +struct GenerationDefaultsEditor: View { + @Binding var settings: GenerationSettings + + var body: some View { + Toggle("Enable thinking mode", isOn: $settings.thinkingEnabled) + doubleRow("Temperature", value: $settings.temperature) + doubleRow("Top P", value: $settings.topP) + intRow("Top K", value: $settings.topK) + doubleRow("Min P", value: $settings.minP) + intRow("Max tokens", value: $settings.maxTokens) + optionalDoubleRow("Repetition penalty", value: $settings.repetitionPenalty) + optionalDoubleRow("Presence penalty", value: $settings.presencePenalty) + optionalDoubleRow("Frequency penalty", value: $settings.frequencyPenalty) + } + + private func doubleRow(_ title: String, value: Binding) -> some View { + HStack { + Text(title) + Spacer() + TextField(title, value: value, format: generationDoubleFormat) + .multilineTextAlignment(.trailing) + .frame(width: 90) + } + } + + private func intRow(_ title: String, value: Binding) -> some View { + HStack { + Text(title) + Spacer() + TextField(title, value: value, format: generationIntegerFormat) + .multilineTextAlignment(.trailing) + .frame(width: 90) + } + } + + private func optionalDoubleRow(_ title: String, value: Binding) -> some View { + HStack { + Text(title) + Spacer() + TextField(title, value: binding(for: value), format: generationDoubleFormat) + .multilineTextAlignment(.trailing) + .frame(width: 90) + Button(value.wrappedValue == nil ? "Set" : "Clear") { + if value.wrappedValue == nil { + value.wrappedValue = 1.0 + } else { + value.wrappedValue = nil + } + } + .buttonStyle(.link) + } + } + + private func binding(for value: Binding) -> Binding { + Binding( + get: { value.wrappedValue ?? 1.0 }, + set: { value.wrappedValue = $0 } + ) + } +} + +struct GenerationOverridesEditor: View { + @Binding var overrides: GenerationSettingsOverride + let inheritedSettings: GenerationSettings + let inheritedSource: String + + var body: some View { + Picker("Thinking mode", selection: $overrides.thinkingEnabled) { + Text("Inherited (\(inheritedSettings.thinkingEnabled ? "Enabled" : "Disabled"))").tag(Optional.none) + Text("Enabled").tag(Optional(true)) + Text("Disabled").tag(Optional(false)) + } + + optionalDoubleRow("Temperature", value: $overrides.temperature, inheritedValue: inheritedSettings.temperature) + optionalDoubleRow("Top P", value: $overrides.topP, inheritedValue: inheritedSettings.topP) + optionalIntRow("Top K", value: $overrides.topK, inheritedValue: inheritedSettings.topK) + optionalDoubleRow("Min P", value: $overrides.minP, inheritedValue: inheritedSettings.minP) + optionalIntRow("Max tokens", value: $overrides.maxTokens, inheritedValue: inheritedSettings.maxTokens) + optionalDoubleRow("Repetition penalty", value: $overrides.repetitionPenalty, inheritedValue: inheritedSettings.repetitionPenalty ?? 0) + optionalDoubleRow("Presence penalty", value: $overrides.presencePenalty, inheritedValue: inheritedSettings.presencePenalty ?? 0) + optionalDoubleRow("Frequency penalty", value: $overrides.frequencyPenalty, inheritedValue: inheritedSettings.frequencyPenalty ?? 0) + + Text("Unset fields inherit from \(inheritedSource). The values shown are the effective starting values for this scene.") + .font(.caption) + .foregroundStyle(.secondary) + } + + private func optionalDoubleRow(_ title: String, value: Binding, inheritedValue: Double) -> some View { + HStack { + Text(title) + Spacer() + TextField(title, value: Binding( + get: { value.wrappedValue ?? inheritedValue }, + set: { value.wrappedValue = $0 } + ), format: generationDoubleFormat) + .multilineTextAlignment(.trailing) + .frame(width: 90) + if value.wrappedValue == nil { + Text("Inherited") + .font(.caption) + .foregroundStyle(.secondary) + } + Button(value.wrappedValue == nil ? "Override" : "Clear") { + if value.wrappedValue == nil { + value.wrappedValue = inheritedValue + } else { + value.wrappedValue = nil + } + } + .buttonStyle(.link) + } + } + + private func optionalIntRow(_ title: String, value: Binding, inheritedValue: Int) -> some View { + HStack { + Text(title) + Spacer() + TextField(title, value: Binding( + get: { value.wrappedValue ?? inheritedValue }, + set: { value.wrappedValue = $0 } + ), format: generationIntegerFormat) + .multilineTextAlignment(.trailing) + .frame(width: 90) + if value.wrappedValue == nil { + Text("Inherited") + .font(.caption) + .foregroundStyle(.secondary) + } + Button(value.wrappedValue == nil ? "Override" : "Clear") { + if value.wrappedValue == nil { + value.wrappedValue = inheritedValue + } else { + value.wrappedValue = nil + } + } + .buttonStyle(.link) + } + } +} \ No newline at end of file diff --git a/MLXServer/Views/SceneManagementView.swift b/MLXServer/Views/SceneManagementView.swift index 20ff1a1..1ea7747 100644 --- a/MLXServer/Views/SceneManagementView.swift +++ b/MLXServer/Views/SceneManagementView.swift @@ -246,6 +246,14 @@ private struct SceneEditorView: View { .font(.caption) .foregroundStyle(.secondary) } + + Section("Generation Overrides") { + GenerationOverridesEditor( + overrides: generationOverridesBinding, + inheritedSettings: inheritedGenerationSettings, + inheritedSource: inheritedGenerationSource + ) + } } .formStyle(.grouped) .navigationTitle(scene.displayName) @@ -272,4 +280,35 @@ private struct SceneEditorView: View { } ) } + + private var generationOverridesBinding: Binding { + Binding( + get: { sceneStore.scene(id: scene.id)?.generationOverrides ?? scene.generationOverrides }, + set: { newValue in + sceneStore.updateScene(id: scene.id) { + $0.generationOverrides = newValue + } + } + ) + } + + private var effectiveModelId: String { + sceneStore.scene(id: scene.id)?.modelId + ?? scene.modelId + ?? Preferences.defaultModelId + ?? Preferences.lastModelId + ?? ModelConfig.default.id + } + + private var inheritedGenerationSettings: GenerationSettings { + Preferences.generationSettings(forModelId: effectiveModelId) + } + + private var inheritedGenerationSource: String { + let modelName = ModelConfig.resolve(effectiveModelId)?.displayName ?? effectiveModelId + if Preferences.hasGenerationSettings(forModelId: effectiveModelId) { + return "saved \(modelName) defaults" + } + return "built-in \(modelName) defaults" + } } \ No newline at end of file diff --git a/MLXServer/Views/SettingsView.swift b/MLXServer/Views/SettingsView.swift index d4fddb5..f79609a 100644 --- a/MLXServer/Views/SettingsView.swift +++ b/MLXServer/Views/SettingsView.swift @@ -8,7 +8,7 @@ struct SettingsView: View { @State private var apiAutoStart: Bool = Preferences.apiAutoStart @State private var idleUnloadMinutes: String = String(Preferences.idleUnloadMinutes) @State private var defaultModelId: String = Preferences.defaultModelId ?? ModelConfig.default.id - @State private var enableThinking: Bool = Preferences.enableThinking + @State private var generationDefaultsModelId: String = Preferences.defaultModelId ?? ModelConfig.default.id @State private var kvQuantizationEnabled: Bool = Preferences.kvQuantizationEnabled @State private var kvQuantizationBits: Int = Preferences.kvQuantizationBits @@ -42,13 +42,16 @@ struct SettingsView: View { .foregroundStyle(.secondary) } - Section("Generation") { - Toggle("Enable thinking mode", isOn: $enableThinking) - .onChange(of: enableThinking) { - Preferences.enableThinking = enableThinking + Section("Generation Defaults") { + Picker("Defaults for model", selection: $generationDefaultsModelId) { + ForEach(ModelConfig.availableModels) { model in + Text(model.displayName).tag(model.id) } + } - Text("When enabled, models like Qwen3.5 reason internally before responding. Produces better answers but slower. Takes effect on the next conversation.") + GenerationDefaultsEditor(settings: generationDefaultsBinding) + + Text("These are the per-model defaults used by chat sessions and by the API server whenever a request omits a generation parameter. Lower temperature and stronger repetition penalties are usually better for technical work; higher temperature is usually better for improvisation and roleplay.") .font(.caption) .foregroundStyle(.secondary) } @@ -162,4 +165,11 @@ struct SettingsView: View { .formStyle(.grouped) .frame(width: 450, height: 650) } + + private var generationDefaultsBinding: Binding { + Binding( + get: { Preferences.generationSettings(forModelId: generationDefaultsModelId) }, + set: { Preferences.setGenerationSettings($0, forModelId: generationDefaultsModelId) } + ) + } } diff --git a/MLXServerTests/Server/APIServerRewriteTests.swift b/MLXServerTests/Server/APIServerRewriteTests.swift index 6899d77..7435649 100644 --- a/MLXServerTests/Server/APIServerRewriteTests.swift +++ b/MLXServerTests/Server/APIServerRewriteTests.swift @@ -1174,6 +1174,102 @@ final class APIServerRewriteTests: XCTestCase { XCTAssertGreaterThan(finalLiveSnapshot.totalCacheReusePromptTokens, afterDisconnectLiveSnapshot.totalCacheReusePromptTokens) } + func testAPIServerUsesModelDefaultsAndRequestOverridesTakePrecedence() async throws { + let modelId = self.genericModelId + let originalSettings = Preferences.generationSettings(forModelId: modelId) + let collector = GenerationSettingsEventCollector() + + Preferences.setGenerationSettings( + GenerationSettings( + temperature: 0.11, + topP: 0.77, + topK: 9, + minP: 0.04, + maxTokens: 3, + repetitionPenalty: 1.18, + presencePenalty: 0.25, + frequencyPenalty: 0.4, + thinkingEnabled: false + ), + forModelId: modelId + ) + APIServer.debugGenerationSettingsEventHandler = { event in + Task { + await collector.record(event) + } + } + + defer { + Preferences.setGenerationSettings(originalSettings, forModelId: modelId) + APIServer.debugGenerationSettingsEventHandler = nil + } + + let harness = try await makeHarness(initialModelId: modelId) + defer { harness.stop() } + + _ = try await sendChatCompletion( + APIChatCompletionRequest( + model: modelId, + messages: [ + APIChatMessage(role: "user", content: .text("Reply with one short word."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + stream: false + ), + port: harness.port + ) + + try await waitUntil(timeoutSeconds: 5) { + await collector.events().count == 1 + } + + let firstEvents = await collector.events() + let firstEvent = try XCTUnwrap(firstEvents.first) + XCTAssertEqual(firstEvent.settings.temperature, 0.11) + XCTAssertEqual(firstEvent.settings.topP, 0.77) + XCTAssertEqual(firstEvent.settings.topK, 9) + XCTAssertEqual(firstEvent.settings.minP, 0.04) + XCTAssertEqual(firstEvent.settings.maxTokens, 3) + XCTAssertEqual(firstEvent.settings.repetitionPenalty, 1.18) + XCTAssertEqual(firstEvent.settings.presencePenalty, 0.25) + XCTAssertEqual(firstEvent.settings.frequencyPenalty, 0.4) + XCTAssertFalse(firstEvent.settings.thinkingEnabled) + + _ = try await sendChatCompletion( + APIChatCompletionRequest( + model: modelId, + messages: [ + APIChatMessage(role: "user", content: .text("Reply with one short word."), name: nil, tool_calls: nil, tool_call_id: nil) + ], + temperature: 0.62, + top_p: 0.55, + max_tokens: 5, + stream: false, + frequency_penalty: 0.1, + presence_penalty: 0.2, + top_k: 4, + min_p: 0.02, + repetition_penalty: 1.05 + ), + port: harness.port + ) + + try await waitUntil(timeoutSeconds: 5) { + await collector.events().count == 2 + } + + let secondEvents = await collector.events() + let secondEvent = try XCTUnwrap(secondEvents.last) + XCTAssertEqual(secondEvent.settings.temperature, 0.62) + XCTAssertEqual(secondEvent.settings.topP, 0.55) + XCTAssertEqual(secondEvent.settings.topK, 4) + XCTAssertEqual(secondEvent.settings.minP, 0.02) + XCTAssertEqual(secondEvent.settings.maxTokens, 5) + XCTAssertEqual(secondEvent.settings.repetitionPenalty, 1.05) + XCTAssertEqual(secondEvent.settings.presencePenalty, 0.2) + XCTAssertEqual(secondEvent.settings.frequencyPenalty, 0.1) + XCTAssertFalse(secondEvent.settings.thinkingEnabled) + } + func testStreamingDisconnectStopsServerWorkWithinTwoHundredMilliseconds() async throws { let harness = try await makeHarness() defer { harness.stop() } @@ -1683,6 +1779,18 @@ private actor LookupEventCollector { } } +private actor GenerationSettingsEventCollector { + private var recorded: [APIServer.DebugGenerationSettingsEvent] = [] + + func record(_ event: APIServer.DebugGenerationSettingsEvent) { + recorded.append(event) + } + + func events() -> [APIServer.DebugGenerationSettingsEvent] { + recorded + } +} + private struct DetailedStreamingResult { let events: [StreamingEvent] let sawDone: Bool diff --git a/MLXServerTests/Server/GenerationSettingsTests.swift b/MLXServerTests/Server/GenerationSettingsTests.swift new file mode 100644 index 0000000..475c78b --- /dev/null +++ b/MLXServerTests/Server/GenerationSettingsTests.swift @@ -0,0 +1,80 @@ +import XCTest +@testable import MLX_Server + +final class GenerationSettingsTests: XCTestCase { + func testSceneOverridesApplyWithoutDiscardingModelDefaults() { + let base = GenerationSettings( + temperature: 0.2, + topP: 0.9, + topK: 12, + minP: 0.05, + maxTokens: 2048, + repetitionPenalty: 1.08, + presencePenalty: 0.3, + frequencyPenalty: 0.1, + thinkingEnabled: true + ) + + let overrides = GenerationSettingsOverride( + temperature: 0.8, + repetitionPenalty: 1.2, + thinkingEnabled: false + ) + + let resolved = base.applying(overrides) + + XCTAssertEqual(resolved.temperature, 0.8) + XCTAssertEqual(resolved.repetitionPenalty, 1.2) + XCTAssertEqual(resolved.topP, 0.9) + XCTAssertEqual(resolved.topK, 12) + XCTAssertEqual(resolved.maxTokens, 2048) + XCTAssertEqual(resolved.presencePenalty, 0.3) + XCTAssertFalse(resolved.thinkingEnabled) + } + + func testPreferencesStoreGenerationDefaultsPerModel() { + let gemmaId = "gemma" + let qwenId = "qwen3.5-0.8b" + let originalGemma = Preferences.generationSettings(forModelId: gemmaId) + let originalQwen = Preferences.generationSettings(forModelId: qwenId) + + defer { + Preferences.setGenerationSettings(originalGemma, forModelId: gemmaId) + Preferences.setGenerationSettings(originalQwen, forModelId: qwenId) + } + + Preferences.setGenerationSettings( + GenerationSettings(temperature: 0.15, topP: 0.85, maxTokens: 1024, repetitionPenalty: 1.1, thinkingEnabled: false), + forModelId: gemmaId + ) + Preferences.setGenerationSettings( + GenerationSettings(temperature: 0.95, topP: 1.0, maxTokens: 8192, repetitionPenalty: nil, thinkingEnabled: true), + forModelId: qwenId + ) + + let gemma = Preferences.generationSettings(forModelId: gemmaId) + let qwen = Preferences.generationSettings(forModelId: qwenId) + + XCTAssertEqual(gemma.temperature, 0.15) + XCTAssertEqual(gemma.topP, 0.85) + XCTAssertEqual(gemma.maxTokens, 1024) + XCTAssertEqual(gemma.repetitionPenalty, 1.1) + XCTAssertFalse(gemma.thinkingEnabled) + + XCTAssertEqual(qwen.temperature, 0.95) + XCTAssertEqual(qwen.maxTokens, 8192) + XCTAssertNil(qwen.repetitionPenalty) + XCTAssertTrue(qwen.thinkingEnabled) + } + + func testModelFallbackDefaultsComeFromModelDefinitions() { + let gemma = GenerationSettings.modelDefault(for: "gemma") + let qwen = GenerationSettings.modelDefault(for: "qwen") + let stheno = GenerationSettings.modelDefault(for: "stheno") + + XCTAssertEqual(gemma, .technicalDefault) + XCTAssertEqual(qwen, .technicalDefault) + XCTAssertEqual(stheno, .roleplayDefault) + XCTAssertNotEqual(gemma, stheno) + } +} \ No newline at end of file diff --git a/README.md b/README.md index 068f601..6186ab6 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ This is intended for targeted validation while keeping the normal default as the ## App Features - **Chat interface** with markdown rendering and model-aware image attachments (file picker, drag & drop, clipboard paste, Finder copy-paste on vision-capable models) -- **Scene-based chat starts** — New Chat opens a scene picker with Neutral plus saved scenes, each with an optional model override, a scene prompt layered onto the base system prompt, and an auto-sent starter prompt +- **Scene-based chat starts** — New Chat opens a scene picker with Neutral plus saved scenes, each with an optional model override, a scene prompt layered onto the base system prompt, an auto-sent starter prompt, and optional generation-setting overrides for chat-specific behavior - **Model picker** in toolbar with local/download status indicators and re-download button - **Download progress modal** — shows file progress, percentage, and speed when downloading a new model - **Thinking mode** — models like Qwen3.5 can reason internally before responding; thinking content appears in a collapsible box. Toggle on/off in Settings. @@ -52,7 +52,7 @@ This is intended for targeted validation while keeping the normal default as the - **Status bar** showing model name, context window, tokens/sec, token counts, GPU memory, API server status - **Keyboard shortcuts**: `Cmd+N` (new chat), `Cmd+O` (open chat document), `Cmd+S` (save chat document), `Cmd+Shift+S` (save chat document as), `Cmd+Shift+E` (export), `Cmd+Return` (send), `Escape` (stop), `Cmd+1/2/3/4/5` (switch models) - **Scene management** — create and edit reusable roleplay/task presets from the New Chat flow or Settings -- **Settings** (`Cmd+,`): default model, thinking mode toggle, base system prompt, scene management, API port, API auto-start, idle unload timeout +- **Settings** (`Cmd+,`): default model, per-model generation defaults (temperature, top-p/top-k, min-p, repetition/presence/frequency penalties, max tokens, thinking mode), base system prompt, scene management, API port, API auto-start, idle unload timeout - **Idle auto-unload** — model is unloaded after configurable idle time (resets on both user input and model output), reloaded on next request ## API Server @@ -65,6 +65,8 @@ The embedded API server (toggle in toolbar) runs on port 1234 by default. Standa Capability checks are enforced server-side. If a request sends images to a text-only model or tools to a model without tool support, the server returns a `400 invalid_request_error`. +When a chat-completions request omits generation parameters, the API server falls back to the saved per-model defaults from Settings. Request-supplied values still take precedence on a per-call basis. + ### Model Swapping Send any model ID or alias in the `model` field. If it differs from the currently loaded model, the server swaps automatically: