From 5313b7175e0edfbb25402e1e2691cee8958483b3 Mon Sep 17 00:00:00 2001 From: Chili Palmer Date: Tue, 17 Mar 2026 19:12:54 +0100 Subject: [PATCH] feat: complete rewrite to swift --- .gitignore | 10 +- CLAUDE.md | 55 +- MLXServer.xcodeproj/project.pbxproj | 488 +++++++ .../contents.xcworkspacedata | 7 + .../xcshareddata/swiftpm/Package.resolved | 159 +++ .../AccentColor.colorset/Contents.json | 11 + .../AppIcon.appiconset/Contents.json | 58 + MLXServer/Assets.xcassets/Contents.json | 6 + MLXServer/ContentView.swift | 115 ++ MLXServer/MLXServer.entitlements | 14 + MLXServer/MLXServerApp.swift | 33 + MLXServer/Models/ChatMessage.swift | 60 + MLXServer/Models/ModelConfig.swift | 56 + MLXServer/Server/APIModels.swift | 237 ++++ MLXServer/Server/APIServer.swift | 814 ++++++++++++ MLXServer/Server/ToolCallParser.swift | 190 +++ MLXServer/Server/ToolPromptBuilder.swift | 199 +++ MLXServer/Utilities/LocalModelResolver.swift | 47 + MLXServer/Utilities/Preferences.swift | 42 + MLXServer/ViewModels/ChatViewModel.swift | 158 +++ MLXServer/ViewModels/ModelManager.swift | 71 ++ MLXServer/Views/ChatInputView.swift | 128 ++ MLXServer/Views/ChatMessagesView.swift | 105 ++ MLXServer/Views/ModelPickerView.swift | 32 + MLXServer/Views/SettingsView.swift | 44 + MLXServer/Views/StatusBarView.swift | 77 ++ README.md | 105 +- build.sh | 41 + mlx_server/__init__.py | 0 mlx_server/__main__.py | 3 - mlx_server/engine.py | 1120 ----------------- mlx_server/main.py | 409 ------ mlx_server/models.py | 145 --- project.yml | 45 + pyproject.toml | 20 - run.sh | 47 - test_server.py | 296 ----- 37 files changed, 3325 insertions(+), 2122 deletions(-) create mode 100644 MLXServer.xcodeproj/project.pbxproj create mode 100644 MLXServer.xcodeproj/project.xcworkspace/contents.xcworkspacedata create mode 100644 MLXServer.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved create mode 100644 MLXServer/Assets.xcassets/AccentColor.colorset/Contents.json create mode 100644 
MLXServer/Assets.xcassets/AppIcon.appiconset/Contents.json create mode 100644 MLXServer/Assets.xcassets/Contents.json create mode 100644 MLXServer/ContentView.swift create mode 100644 MLXServer/MLXServer.entitlements create mode 100644 MLXServer/MLXServerApp.swift create mode 100644 MLXServer/Models/ChatMessage.swift create mode 100644 MLXServer/Models/ModelConfig.swift create mode 100644 MLXServer/Server/APIModels.swift create mode 100644 MLXServer/Server/APIServer.swift create mode 100644 MLXServer/Server/ToolCallParser.swift create mode 100644 MLXServer/Server/ToolPromptBuilder.swift create mode 100644 MLXServer/Utilities/LocalModelResolver.swift create mode 100644 MLXServer/Utilities/Preferences.swift create mode 100644 MLXServer/ViewModels/ChatViewModel.swift create mode 100644 MLXServer/ViewModels/ModelManager.swift create mode 100644 MLXServer/Views/ChatInputView.swift create mode 100644 MLXServer/Views/ChatMessagesView.swift create mode 100644 MLXServer/Views/ModelPickerView.swift create mode 100644 MLXServer/Views/SettingsView.swift create mode 100644 MLXServer/Views/StatusBarView.swift create mode 100755 build.sh delete mode 100644 mlx_server/__init__.py delete mode 100644 mlx_server/__main__.py delete mode 100644 mlx_server/engine.py delete mode 100644 mlx_server/main.py delete mode 100644 mlx_server/models.py create mode 100644 project.yml delete mode 100644 pyproject.toml delete mode 100755 run.sh delete mode 100644 test_server.py diff --git a/.gitignore b/.gitignore index dfa92aa..b3f936a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,5 @@ -__pycache__/ -*.py[cod] -*$py.class -*.egg-info/ -dist/ build/ -.venv/ -.env -*.log .DS_Store +*.log settings.local.json +xcuserdata/ diff --git a/CLAUDE.md b/CLAUDE.md index ea625ee..3af05c9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,50 +1,55 @@ # MLX Server -OpenAI-compatible API server for local LLMs on Apple Silicon via MLX. Supports Gemma 3 4B and Qwen3 VL 4B (vision + tool use). 
+Native macOS SwiftUI app for local LLMs on Apple Silicon via MLX. Provides a chat UI and an embedded OpenAI-compatible API server. Supports vision and tool use. ## Quick Start ```bash -# Activate virtual environment -source .venv/bin/activate +# Build (requires xcodegen: brew install xcodegen) +./build.sh -# Run with Gemma 3 (default) -./run.sh - -# Run with Qwen3 -./run.sh qwen - -# Or directly: -python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234 -python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port 1234 +# Run +open "build/Debug/MLX Server.app" ``` ## Project Structure -- `mlx_server/main.py` — FastAPI server, endpoints, CLI entrypoint -- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm) -- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types +- `MLXServer/MLXServerApp.swift` — App entry point, GPU cache config +- `MLXServer/ContentView.swift` — Main layout, toolbar, keyboard shortcuts +- `MLXServer/Models/ModelConfig.swift` — Model definitions (alias, repoId, contextLength), resolution +- `MLXServer/Models/ChatMessage.swift` — Chat message data model +- `MLXServer/ViewModels/ModelManager.swift` — Model loading/switching via VLMModelFactory, offline-first resolution +- `MLXServer/ViewModels/ChatViewModel.swift` — Chat state, ChatSession management, API server lifecycle +- `MLXServer/Server/APIServer.swift` — NWListener HTTP server, SSE streaming, KV cache reuse, vision, tool call handling +- `MLXServer/Server/APIModels.swift` — OpenAI-compatible Codable structs +- `MLXServer/Server/ToolCallParser.swift` — Parses tool calls from model output (Gemma tool_code, Qwen XML tags) +- `MLXServer/Server/ToolPromptBuilder.swift` — Model-specific tool prompt formatting +- `MLXServer/Utilities/LocalModelResolver.swift` — Resolves HF repo IDs to ~/.cache/huggingface/hub/ snapshots +- `MLXServer/Utilities/Preferences.swift` — UserDefaults wrapper +- `project.yml` — 
xcodegen project spec +- `build.sh` — Build script (xcodegen + xcodebuild) ## Supported Models | Alias | HuggingFace ID | Notes | |-------|---------------|-------| | `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks (128k context) | -| `gemma3n` | `mlx-community/gemma-3n-E4B-it-4bit` | Vision/audio/video + tool use via `tool_code` blocks (32k context, ~1.5x faster) | | `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags (256k context) | ## Key Design Decisions -- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load -- Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags -- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.) -- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation -- Context window size is read from each model's config at load time (Gemma 3 4B: 128k, Qwen3-VL 4B: 256k) +- Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) as the inference backend — supports both text and vision in a single model load +- Model-specific prompt formatting: Gemma uses `tool_code` blocks; Qwen uses `<tool_call>` XML tags +- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), `LocalModelResolver` resolves the local snapshot path directly — no network requests +- HTTP server built on `Network.framework` (`NWListener`) — no third-party server dependencies +- KV cache reuse across API requests — reuses `ChatSession` when conversation history prefix matches +- GPU cache limit set to 20 MB; cache cleared on model unload ## Dependencies -Managed via `uv` and `pyproject.toml`. Virtual environment in `.venv/`. 
+Managed via Swift Package Manager (declared in `project.yml` for xcodegen). -```bash -uv pip install -e "." -``` +| Package | Products | +|---------|----------| +| `mlx-swift-lm` | `MLXLLM`, `MLXVLM`, `MLXLMCommon` | +| `swift-markdown-ui` | `MarkdownUI` | diff --git a/MLXServer.xcodeproj/project.pbxproj b/MLXServer.xcodeproj/project.pbxproj new file mode 100644 index 0000000..ceb748b --- /dev/null +++ b/MLXServer.xcodeproj/project.pbxproj @@ -0,0 +1,488 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 77; + objects = { + +/* Begin PBXBuildFile section */ + 0168AEE16009097901363E16 /* ModelManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 922CBDC9206737BD04AF2874 /* ModelManager.swift */; }; + 165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */ = {isa = PBXBuildFile; fileRef = 145B888FBDD4F931512C5473 /* Preferences.swift */; }; + 189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */ = {isa = PBXBuildFile; fileRef = E73B165A1822729C907791AE /* ToolCallParser.swift */; }; + 2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */; }; + 4CB13DC1AC7A500DDBB443EC /* ChatInputView.swift in Sources */ = {isa = PBXBuildFile; fileRef = E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */; }; + 50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C67742651DB486871CEF1612 /* MLXServerApp.swift */; }; + 50DD129CCF2843482DEC3B96 /* APIServer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3D08828E16B17EF02C14243E /* APIServer.swift */; }; + 5946258F1DE88CE904584E0B /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 944C699FBB76C734C9DF2F2E /* ContentView.swift */; }; + 5C1E8FE1C521914CEF98D3AA /* ChatMessagesView.swift in Sources */ = {isa = PBXBuildFile; fileRef = DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */; }; + 621B7E4382199AC1378F5F9C /* 
StatusBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */; }; + 6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; }; + 7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; }; + 80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; }; + 84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */; }; + 945474365D0B3E961811909A /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = D5E8E1C2DD8D8AABB4306193 /* MLXVLM */; }; + B5AA6E3B4BE21676226B342B /* ChatViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = B8BD93859F0291F1A3E09DA5 /* ChatViewModel.swift */; }; + B6D3662995B885C102876B4A /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = 9090667D4134056AE66DC2F1 /* MLXLMCommon */; }; + D666A311788375E8A061C832 /* SettingsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4147321383E94E9F17A0154E /* SettingsView.swift */; }; + D96DDE66F76FDDA642629E17 /* APIModels.swift in Sources */ = {isa = PBXBuildFile; fileRef = F1A52E2C9964ADA9D841A89B /* APIModels.swift */; }; + F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = A98257123539E9E738213BFA /* MarkdownUI */; }; + FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = A4B359324B5FD8D106C74338 /* ChatMessage.swift */; }; + FCD48F8C132A2B830A15EEB4 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = 3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 145B888FBDD4F931512C5473 /* Preferences.swift 
*/ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Preferences.swift; sourceTree = ""; }; + 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ToolPromptBuilder.swift; sourceTree = ""; }; + 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelConfig.swift; sourceTree = ""; }; + 3AF462805202797F61422AEE /* MLXServer.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLXServer.entitlements; sourceTree = ""; }; + 3D08828E16B17EF02C14243E /* APIServer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServer.swift; sourceTree = ""; }; + 4147321383E94E9F17A0154E /* SettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsView.swift; sourceTree = ""; }; + 6EE59189918D06B8D2F588FC /* MLXServer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLXServer.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 922CBDC9206737BD04AF2874 /* ModelManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManager.swift; sourceTree = ""; }; + 944C699FBB76C734C9DF2F2E /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; + A4B359324B5FD8D106C74338 /* ChatMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessage.swift; sourceTree = ""; }; + B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StatusBarView.swift; sourceTree = ""; }; + B629DA084A9A40E54F8EA5FA /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + B8BD93859F0291F1A3E09DA5 /* 
ChatViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatViewModel.swift; sourceTree = ""; }; + C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelPickerView.swift; sourceTree = ""; }; + C67742651DB486871CEF1612 /* MLXServerApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLXServerApp.swift; sourceTree = ""; }; + D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolver.swift; sourceTree = ""; }; + DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessagesView.swift; sourceTree = ""; }; + E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatInputView.swift; sourceTree = ""; }; + E73B165A1822729C907791AE /* ToolCallParser.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ToolCallParser.swift; sourceTree = ""; }; + F1A52E2C9964ADA9D841A89B /* APIModels.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIModels.swift; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + A328B75C1B81B56CC7597F12 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + FCD48F8C132A2B830A15EEB4 /* MLXLLM in Frameworks */, + 945474365D0B3E961811909A /* MLXVLM in Frameworks */, + B6D3662995B885C102876B4A /* MLXLMCommon in Frameworks */, + F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 05B1BAE308E64D2FB2E73823 /* Utilities */ = { + isa = PBXGroup; + children = ( + D733A0D1D4AC25DDDA6C8684 /* 
LocalModelResolver.swift */, + 145B888FBDD4F931512C5473 /* Preferences.swift */, + ); + path = Utilities; + sourceTree = ""; + }; + 652987C2A419DBFC79E32CDE /* Products */ = { + isa = PBXGroup; + children = ( + 6EE59189918D06B8D2F588FC /* MLXServer.app */, + ); + name = Products; + sourceTree = ""; + }; + 6816BF8EF7C92384DD7C9177 /* MLXServer */ = { + isa = PBXGroup; + children = ( + B629DA084A9A40E54F8EA5FA /* Assets.xcassets */, + 944C699FBB76C734C9DF2F2E /* ContentView.swift */, + 3AF462805202797F61422AEE /* MLXServer.entitlements */, + C67742651DB486871CEF1612 /* MLXServerApp.swift */, + BD0E350482D91238B4B59721 /* Models */, + E13C1AAA0C49D0ED85EFD94D /* Server */, + 05B1BAE308E64D2FB2E73823 /* Utilities */, + D7A641B0969293E838F9147A /* ViewModels */, + 7B3BAACD850CBB35C7F4FB6C /* Views */, + ); + path = MLXServer; + sourceTree = ""; + }; + 7B3BAACD850CBB35C7F4FB6C /* Views */ = { + isa = PBXGroup; + children = ( + E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */, + DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */, + C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */, + 4147321383E94E9F17A0154E /* SettingsView.swift */, + B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */, + ); + path = Views; + sourceTree = ""; + }; + BD0E350482D91238B4B59721 /* Models */ = { + isa = PBXGroup; + children = ( + A4B359324B5FD8D106C74338 /* ChatMessage.swift */, + 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */, + ); + path = Models; + sourceTree = ""; + }; + D7A641B0969293E838F9147A /* ViewModels */ = { + isa = PBXGroup; + children = ( + B8BD93859F0291F1A3E09DA5 /* ChatViewModel.swift */, + 922CBDC9206737BD04AF2874 /* ModelManager.swift */, + ); + path = ViewModels; + sourceTree = ""; + }; + E13C1AAA0C49D0ED85EFD94D /* Server */ = { + isa = PBXGroup; + children = ( + F1A52E2C9964ADA9D841A89B /* APIModels.swift */, + 3D08828E16B17EF02C14243E /* APIServer.swift */, + E73B165A1822729C907791AE /* ToolCallParser.swift */, + 16AE82A64D1D07AE3CD8D33A /* 
ToolPromptBuilder.swift */, + ); + path = Server; + sourceTree = ""; + }; + E2540E47403820BAAFEF0560 = { + isa = PBXGroup; + children = ( + 6816BF8EF7C92384DD7C9177 /* MLXServer */, + 652987C2A419DBFC79E32CDE /* Products */, + ); + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + BCD7107EE884C9B2F4C2C40E /* MLXServer */ = { + isa = PBXNativeTarget; + buildConfigurationList = 732FBF6595F174F37E5F2835 /* Build configuration list for PBXNativeTarget "MLXServer" */; + buildPhases = ( + BC03844286F51DFAEF96B823 /* Sources */, + 4668CBE7984322217309B525 /* Resources */, + A328B75C1B81B56CC7597F12 /* Frameworks */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = MLXServer; + packageProductDependencies = ( + 3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */, + D5E8E1C2DD8D8AABB4306193 /* MLXVLM */, + 9090667D4134056AE66DC2F1 /* MLXLMCommon */, + A98257123539E9E738213BFA /* MarkdownUI */, + ); + productName = MLXServer; + productReference = 6EE59189918D06B8D2F588FC /* MLXServer.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 938BC479816FCA8527B731F9 /* Project object */ = { + isa = PBXProject; + attributes = { + BuildIndependentTargetsInParallel = YES; + LastUpgradeCheck = 1640; + TargetAttributes = { + }; + }; + buildConfigurationList = 9281433C2FA3F9393F1D48E5 /* Build configuration list for PBXProject "MLXServer" */; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + Base, + en, + ); + mainGroup = E2540E47403820BAAFEF0560; + minimizedProjectReferenceProxies = 1; + packageReferences = ( + D402301668D113A49B6DD32D /* XCRemoteSwiftPackageReference "swift-markdown-ui" */, + 1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */, + ); + preferredProjectObjectVersion = 77; + productRefGroup = 652987C2A419DBFC79E32CDE /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 
BCD7107EE884C9B2F4C2C40E /* MLXServer */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 4668CBE7984322217309B525 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + BC03844286F51DFAEF96B823 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + D96DDE66F76FDDA642629E17 /* APIModels.swift in Sources */, + 50DD129CCF2843482DEC3B96 /* APIServer.swift in Sources */, + 4CB13DC1AC7A500DDBB443EC /* ChatInputView.swift in Sources */, + FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */, + 5C1E8FE1C521914CEF98D3AA /* ChatMessagesView.swift in Sources */, + B5AA6E3B4BE21676226B342B /* ChatViewModel.swift in Sources */, + 5946258F1DE88CE904584E0B /* ContentView.swift in Sources */, + 6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */, + 50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */, + 80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */, + 0168AEE16009097901363E16 /* ModelManager.swift in Sources */, + 2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */, + 165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */, + D666A311788375E8A061C832 /* SettingsView.swift in Sources */, + 621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */, + 189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */, + 84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 6C0C08FC4653A138A768ECF0 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + 
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + MACOSX_DEPLOYMENT_TARGET = 15.0; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + SWIFT_VERSION = 5.0; + }; + name = Release; + }; + B906BFE6DBE690935A1B8B6E /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + 
CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "$(inherited)", + "DEBUG=1", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + MACOSX_DEPLOYMENT_TARGET = 15.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; 
+ SWIFT_VERSION = 5.0; + }; + name = Debug; + }; + BE93D60EEB354E6E242C3EDB /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_ALLOW_ENTITLEMENTS_MODIFICATION = YES; + CODE_SIGN_ENTITLEMENTS = MLXServer/MLXServer.entitlements; + CODE_SIGN_IDENTITY = "-"; + COMBINE_HIDPI_IMAGES = YES; + CURRENT_PROJECT_VERSION = 1; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.developer-tools"; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/../Frameworks", + ); + MACOSX_DEPLOYMENT_TARGET = 15.0; + MARKETING_VERSION = 1.0.0; + PRODUCT_BUNDLE_IDENTIFIER = com.mlxserver.app; + PRODUCT_NAME = "MLX Server"; + SDKROOT = macosx; + SWIFT_VERSION = 6.0; + }; + name = Debug; + }; + D7A086319EEBE6664326B437 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_ALLOW_ENTITLEMENTS_MODIFICATION = YES; + CODE_SIGN_ENTITLEMENTS = MLXServer/MLXServer.entitlements; + CODE_SIGN_IDENTITY = "-"; + COMBINE_HIDPI_IMAGES = YES; + CURRENT_PROJECT_VERSION = 1; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.developer-tools"; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/../Frameworks", + ); + MACOSX_DEPLOYMENT_TARGET = 15.0; + MARKETING_VERSION = 1.0.0; + PRODUCT_BUNDLE_IDENTIFIER = com.mlxserver.app; + PRODUCT_NAME = "MLX Server"; + SDKROOT = macosx; + SWIFT_VERSION = 6.0; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 732FBF6595F174F37E5F2835 /* Build configuration list for PBXNativeTarget "MLXServer" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + BE93D60EEB354E6E242C3EDB /* Debug */, + D7A086319EEBE6664326B437 /* Release */, + ); + 
defaultConfigurationIsVisible = 0; + defaultConfigurationName = Debug; + }; + 9281433C2FA3F9393F1D48E5 /* Build configuration list for PBXProject "MLXServer" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + B906BFE6DBE690935A1B8B6E /* Debug */, + 6C0C08FC4653A138A768ECF0 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Debug; + }; +/* End XCConfigurationList section */ + +/* Begin XCRemoteSwiftPackageReference section */ + 1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/ml-explore/mlx-swift-lm"; + requirement = { + branch = main; + kind = branch; + }; + }; + D402301668D113A49B6DD32D /* XCRemoteSwiftPackageReference "swift-markdown-ui" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/gonzalezreal/swift-markdown-ui"; + requirement = { + kind = upToNextMajorVersion; + minimumVersion = 2.4.0; + }; + }; +/* End XCRemoteSwiftPackageReference section */ + +/* Begin XCSwiftPackageProductDependency section */ + 3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */ = { + isa = XCSwiftPackageProductDependency; + package = 1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */; + productName = MLXLLM; + }; + 9090667D4134056AE66DC2F1 /* MLXLMCommon */ = { + isa = XCSwiftPackageProductDependency; + package = 1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */; + productName = MLXLMCommon; + }; + A98257123539E9E738213BFA /* MarkdownUI */ = { + isa = XCSwiftPackageProductDependency; + package = D402301668D113A49B6DD32D /* XCRemoteSwiftPackageReference "swift-markdown-ui" */; + productName = MarkdownUI; + }; + D5E8E1C2DD8D8AABB4306193 /* MLXVLM */ = { + isa = XCSwiftPackageProductDependency; + package = 1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */; + productName = MLXVLM; + }; +/* End XCSwiftPackageProductDependency section */ + 
}; + rootObject = 938BC479816FCA8527B731F9 /* Project object */; +} diff --git a/MLXServer.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/MLXServer.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000..919434a --- /dev/null +++ b/MLXServer.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/MLXServer.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/MLXServer.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved new file mode 100644 index 0000000..a814384 --- /dev/null +++ b/MLXServer.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -0,0 +1,159 @@ +{ + "originHash" : "418f7299ccb303e0e8992dfc960a3df5df98d527f18667aa162699027b29b6cd", + "pins" : [ + { + "identity" : "eventsource", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/EventSource.git", + "state" : { + "revision" : "a3a85a85214caf642abaa96ae664e4c772a59f6e", + "version" : "1.4.1" + } + }, + { + "identity" : "mlx-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift", + "state" : { + "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d", + "version" : "0.30.6" + } + }, + { + "identity" : "mlx-swift-lm", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift-lm", + "state" : { + "branch" : "main", + "revision" : "bc3c20ef4644c86f2b347debcfe1efe4308712a6" + } + }, + { + "identity" : "networkimage", + "kind" : "remoteSourceControl", + "location" : "https://github.com/gonzalezreal/NetworkImage", + "state" : { + "revision" : "2849f5323265386e200484b0d0f896e73c3411b9", + "version" : "6.0.1" + } + }, + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "9f542610331815e29cc3821d3b6f488db8715517", + "version" : "1.6.0" + } + }, + { + "identity" : "swift-atomics", + 
"kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-atomics.git", + "state" : { + "revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-cmark", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swiftlang/swift-cmark", + "state" : { + "revision" : "5d9bdaa4228b381639fff09403e39a04926e2dbe", + "version" : "0.7.1" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "8d9834a6189db730f6264db7556a7ffb751e99ee", + "version" : "1.4.0" + } + }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "fa308c07a6fa04a727212d793e761460e41049c3", + "version" : "4.3.0" + } + }, + { + "identity" : "swift-huggingface", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-huggingface.git", + "state" : { + "revision" : "b721959445b617d0bf03910b2b4aced345fd93bf", + "version" : "0.9.0" + } + }, + { + "identity" : "swift-jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-jinja.git", + "state" : { + "revision" : "f731f03bf746481d4fda07f817c3774390c4d5b9", + "version" : "2.3.2" + } + }, + { + "identity" : "swift-markdown-ui", + "kind" : "remoteSourceControl", + "location" : "https://github.com/gonzalezreal/swift-markdown-ui", + "state" : { + "revision" : "5f613358148239d0292c0cef674a3c2314737f9e", + "version" : "2.4.1" + } + }, + { + "identity" : "swift-nio", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio.git", + "state" : { + "revision" : "b31565862a8f39866af50bc6676160d8dda7de35", + "version" : "2.96.0" + } + }, + { + "identity" : "swift-numerics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-numerics", + 
"state" : { + "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", + "version" : "1.1.1" + } + }, + { + "identity" : "swift-system", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-system.git", + "state" : { + "revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df", + "version" : "1.6.4" + } + }, + { + "identity" : "swift-transformers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-transformers", + "state" : { + "revision" : "eed7264ac5e4ec5dfa6165c6e5c5577364344fe4", + "version" : "1.2.0" + } + }, + { + "identity" : "yyjson", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ibireme/yyjson.git", + "state" : { + "revision" : "8b4a38dc994a110abaec8a400615567bd996105f", + "version" : "0.12.0" + } + } + ], + "version" : 3 +} diff --git a/MLXServer/Assets.xcassets/AccentColor.colorset/Contents.json b/MLXServer/Assets.xcassets/AccentColor.colorset/Contents.json new file mode 100644 index 0000000..eb87897 --- /dev/null +++ b/MLXServer/Assets.xcassets/AccentColor.colorset/Contents.json @@ -0,0 +1,11 @@ +{ + "colors" : [ + { + "idiom" : "universal" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/MLXServer/Assets.xcassets/AppIcon.appiconset/Contents.json b/MLXServer/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000..3f00db4 --- /dev/null +++ b/MLXServer/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,58 @@ +{ + "images" : [ + { + "idiom" : "mac", + "scale" : "1x", + "size" : "16x16" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "16x16" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "32x32" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "32x32" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "128x128" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "128x128" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "256x256" + }, + { + "idiom" : "mac", + 
"scale" : "2x", + "size" : "256x256" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "512x512" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "512x512" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/MLXServer/Assets.xcassets/Contents.json b/MLXServer/Assets.xcassets/Contents.json new file mode 100644 index 0000000..73c0059 --- /dev/null +++ b/MLXServer/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/MLXServer/ContentView.swift b/MLXServer/ContentView.swift new file mode 100644 index 0000000..62ced89 --- /dev/null +++ b/MLXServer/ContentView.swift @@ -0,0 +1,115 @@ +import SwiftUI + +struct ContentView: View { + @Environment(ModelManager.self) private var modelManager + @State private var chatVM: ChatViewModel? + @State private var showLoadError = false + + var body: some View { + Group { + if let chatVM { + ChatView(viewModel: chatVM) + } else { + ProgressView("Initializing…") + } + } + .navigationTitle(modelManager.currentModel?.displayName ?? "MLX Server") + .onAppear { + if chatVM == nil { + chatVM = ChatViewModel(modelManager: modelManager) + // Auto-start API server if configured + if Preferences.apiAutoStart { + chatVM?.startAPIServer() + } + } + } + .onChange(of: modelManager.currentModel) { + chatVM?.resetSession() + // Persist last used model + if let id = modelManager.currentModel?.id { + Preferences.lastModelId = id + } + } + .onChange(of: modelManager.errorMessage) { + showLoadError = modelManager.errorMessage != nil + } + .alert("Model Error", isPresented: $showLoadError) { + Button("Retry") { + if let config = modelManager.currentModel ?? ModelConfig.availableModels.first { + Task { await modelManager.loadModel(config) } + } + } + Button("Cancel", role: .cancel) { + modelManager.errorMessage = nil + } + } message: { + Text(modelManager.errorMessage ?? 
"Unknown error loading model.") + } + .toolbar { + ToolbarItem(placement: .principal) { + ModelPickerView() + } + + ToolbarItemGroup(placement: .primaryAction) { + // API server toggle + Button { + if let chatVM { + if chatVM.apiServer.isRunning { + chatVM.stopAPIServer() + } else { + chatVM.startAPIServer() + } + } + } label: { + // Running → solid globe (green tint), click to stop + // Stopped → slashed globe, click to start + Label( + chatVM?.apiServer.isRunning == true ? "Stop API" : "Start API", + systemImage: chatVM?.apiServer.isRunning == true ? "network" : "network.slash" + ) + .foregroundStyle(chatVM?.apiServer.isRunning == true ? .green : .secondary) + } + .help(chatVM?.apiServer.isRunning == true ? "API server running on port \(Preferences.apiPort) — click to stop" : "Click to start API server") + + // New conversation + Button { + chatVM?.newConversation() + } label: { + Label("New Chat", systemImage: "plus.message") + } + .keyboardShortcut("n", modifiers: .command) + } + } + // Cmd+1/2/3 model switching + .background { + modelSwitchShortcuts + } + } + + @ViewBuilder + private var modelSwitchShortcuts: some View { + ForEach(Array(ModelConfig.availableModels.enumerated()), id: \.element.id) { index, config in + if index < 9 { + Button("") { + Task { await modelManager.loadModel(config) } + } + .keyboardShortcut(KeyEquivalent(Character(String(index + 1))), modifiers: .command) + .hidden() + } + } + } +} + +/// The main chat layout: messages + input area + status bar. 
+struct ChatView: View {
+    @Bindable var viewModel: ChatViewModel
+
+    var body: some View {
+        VStack(spacing: 0) {
+            ChatMessagesView(viewModel: viewModel)
+            Divider()
+            ChatInputView(viewModel: viewModel)
+            StatusBarView(viewModel: viewModel)
+        }
+    }
+}
diff --git a/MLXServer/MLXServer.entitlements b/MLXServer/MLXServer.entitlements
new file mode 100644
index 0000000..779c582
--- /dev/null
+++ b/MLXServer/MLXServer.entitlements
@@ -0,0 +1,14 @@
+
+
+
+
+ com.apple.security.app-sandbox
+
+ com.apple.security.network.client
+
+ com.apple.security.network.server
+
+ com.apple.security.files.user-selected.read-only
+
+
+
diff --git a/MLXServer/MLXServerApp.swift b/MLXServer/MLXServerApp.swift
new file mode 100644
index 0000000..993430c
--- /dev/null
+++ b/MLXServer/MLXServerApp.swift
@@ -0,0 +1,36 @@
+import SwiftUI
+import MLX
+
+/// Application entry point: owns the shared `ModelManager`, injects it into the
+/// SwiftUI environment, and kicks off the initial model load from the root view's task.
+@main
+struct MLXServerApp: App {
+    @State private var modelManager = ModelManager()
+
+    init() {
+        // 20 * 1024 * 1024 = 20 MiB, passed as MLX's GPU cache limit.
+        MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
+    }
+
+    var body: some Scene {
+        WindowGroup {
+            ContentView()
+                .environment(modelManager)
+                .task {
+                    // Auto-load the last used model. Fall back to the default when the persisted
+                    // id is missing or no longer names an entry in `availableModels`.
+                    let modelId = Preferences.lastModelId ?? ModelConfig.default.id
+                    let config = ModelConfig.availableModels.first(where: { $0.id == modelId }) ?? ModelConfig.default
+                    await modelManager.loadModel(config)
+                }
+        }
+        .windowStyle(.titleBar)
+        .defaultSize(width: 800, height: 700)
+
+        #if os(macOS)
+        Settings {
+            SettingsView()
+        }
+        #endif
+    }
+}
diff --git a/MLXServer/Models/ChatMessage.swift b/MLXServer/Models/ChatMessage.swift
new file mode 100644
index 0000000..5f8860d
--- /dev/null
+++ b/MLXServer/Models/ChatMessage.swift
@@ -0,0 +1,60 @@
+import AppKit
+import Foundation
+
+/// A single message in the chat conversation.
+struct ChatMessage: Identifiable {
+    let id = UUID()
+    let role: Role
+    var content: String
+    var images: [NSImage]
+    var isStreaming: Bool
+    let timestamp: Date
+
+    enum Role: String {
+        case system
+        case user
+        case assistant
+    }
+
+    init(role: Role, content: String, images: [NSImage] = [], isStreaming: Bool = false) {
+        self.role = role
+        self.content = content
+        self.images = images
+        self.isStreaming = isStreaming
+        self.timestamp = Date()
+    }
+}
+
+/// Observable, main-actor-bound conversation state holding all messages in display order.
+@Observable
+@MainActor
+final class Conversation {
+    var messages: [ChatMessage] = []
+
+    func addUserMessage(_ text: String, images: [NSImage] = []) {
+        messages.append(ChatMessage(role: .user, content: text, images: images))
+    }
+
+    /// Adds an empty assistant message (to be filled via streaming) and returns its index.
+    func addAssistantMessage() -> Int {
+        let msg = ChatMessage(role: .assistant, content: "", isStreaming: true)
+        messages.append(msg)
+        return messages.count - 1
+    }
+
+    /// Appends a text chunk to the message at `index`; any out-of-range index (negative or stale) is ignored.
+    func appendToMessage(at index: Int, chunk: String) {
+        guard messages.indices.contains(index) else { return }
+        messages[index].content += chunk
+    }
+
+    /// Marks the message at `index` as done streaming; any out-of-range index (negative or stale) is ignored.
+    func finalizeMessage(at index: Int) {
+        guard messages.indices.contains(index) else { return }
+        messages[index].isStreaming = false
+    }
+
+    func clear() {
+        messages.removeAll()
+    }
+}
diff --git a/MLXServer/Models/ModelConfig.swift b/MLXServer/Models/ModelConfig.swift
new file mode 100644
index 0000000..3e678d2
--- /dev/null
+++ b/MLXServer/Models/ModelConfig.swift
@@ -0,0 +1,56 @@
+import Foundation
+import MLXLMCommon
+
+/// Defines a supported model with its metadata.
+struct ModelConfig: Identifiable, Hashable {
+    let id: String // alias: "gemma", "gemma3n", "qwen"
+    let repoId: String // HuggingFace ID
+    let displayName: String
+    let contextLength: Int
+
+    /// All models supported by the app.
+ static let availableModels: [ModelConfig] = [ + ModelConfig( + id: "gemma", + repoId: "mlx-community/gemma-3-4b-it-4bit", + displayName: "Gemma 3 4B", + contextLength: 128_000 + ), + ModelConfig( + id: "qwen", + repoId: "mlx-community/Qwen3-VL-4B-Instruct-4bit", + displayName: "Qwen3 VL 4B", + contextLength: 256_000 + ), + ] + + static let `default` = availableModels[0] + + /// Whether this model is cached locally (no download needed). + var isLocal: Bool { + LocalModelResolver.isAvailable(repoId: repoId) + } + + /// Build a ModelConfiguration for mlx-swift-lm from this config. + var modelConfiguration: ModelConfiguration { + ModelConfiguration(id: repoId) + } + + /// Resolve a model string (alias, full repo ID, or partial match) to a ModelConfig. + /// Mirrors the Python server's `ModelManager.resolve_model()`. + static func resolve(_ requested: String) -> ModelConfig? { + // Exact alias match + if let config = availableModels.first(where: { $0.id == requested }) { + return config + } + // Exact repo ID match + if let config = availableModels.first(where: { $0.repoId == requested }) { + return config + } + // Partial match (e.g. "gemma-3-4b-it" matches the gemma entry) + if let config = availableModels.first(where: { requested.contains($0.id) || $0.repoId.contains(requested) || requested.contains($0.repoId) }) { + return config + } + return nil + } +} diff --git a/MLXServer/Server/APIModels.swift b/MLXServer/Server/APIModels.swift new file mode 100644 index 0000000..0841412 --- /dev/null +++ b/MLXServer/Server/APIModels.swift @@ -0,0 +1,237 @@ +import Foundation + +// MARK: - Request models + +struct APIFunctionDefinition: Codable { + let name: String + let description: String? + let parameters: [String: AnyCodable]? 
+} + +struct APIToolDefinition: Codable { + let type: String // "function" + let function: APIFunctionDefinition +} + +struct APIFunctionCall: Codable { + let name: String + let arguments: String // JSON string +} + +struct APIToolCall: Codable { + let index: Int + let id: String + let type: String // "function" + let function: APIFunctionCall + + init(index: Int = 0, id: String, type: String = "function", function: APIFunctionCall) { + self.index = index + self.id = id + self.type = type + self.function = function + } +} + +struct APIImageURL: Codable { + let url: String + let detail: String? +} + +struct APIContentPart: Codable { + let type: String // "text" or "image_url" + let text: String? + let image_url: APIImageURL? +} + +struct APIChatMessage: Codable { + let role: String + let content: MessageContent? + let name: String? + let tool_calls: [APIToolCall]? + let tool_call_id: String? + + enum MessageContent: Codable { + case text(String) + case parts([APIContentPart]) + + init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let text = try? container.decode(String.self) { + self = .text(text) + } else if let parts = try? container.decode([APIContentPart].self) { + self = .parts(parts) + } else { + self = .text("") + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .text(let text): + try container.encode(text) + case .parts(let parts): + try container.encode(parts) + } + } + + /// Extract plain text content. + var textContent: String { + switch self { + case .text(let t): return t + case .parts(let parts): + return parts.compactMap { $0.text }.joined() + } + } + + /// Extract image URLs/base64 data URIs. + var imageURLs: [String] { + switch self { + case .text: return [] + case .parts(let parts): + return parts.compactMap { $0.image_url?.url } + } + } + } +} + +struct APIChatCompletionRequest: Codable { + let model: String? 
+ let messages: [APIChatMessage] + let temperature: Double? + let top_p: Double? + let max_tokens: Int? + let stream: Bool? + let stop: StopSequence? + let tools: [APIToolDefinition]? + let tool_choice: AnyCodable? + let frequency_penalty: Double? + let presence_penalty: Double? + let n: Int? + + enum StopSequence: Codable { + case single(String) + case multiple([String]) + + init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let s = try? container.decode(String.self) { + self = .single(s) + } else if let arr = try? container.decode([String].self) { + self = .multiple(arr) + } else { + self = .multiple([]) + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .single(let s): try container.encode(s) + case .multiple(let arr): try container.encode(arr) + } + } + } +} + +// MARK: - Response models + +struct APIUsageInfo: Codable { + let prompt_tokens: Int + let completion_tokens: Int + let total_tokens: Int +} + +struct APIChoiceMessage: Codable { + let role: String + let content: String? + let tool_calls: [APIToolCall]? +} + +struct APIChoice: Codable { + let index: Int + let message: APIChoiceMessage + let finish_reason: String? +} + +struct APIChatCompletionResponse: Codable { + let id: String + let object: String + let created: Int + let model: String + let choices: [APIChoice] + let usage: APIUsageInfo +} + +// MARK: - Streaming response models + +struct APIDeltaMessage: Codable { + let role: String? + let content: String? + let tool_calls: [APIToolCall]? +} + +struct APIStreamChoice: Codable { + let index: Int + let delta: APIDeltaMessage + let finish_reason: String? +} + +struct APIChatCompletionChunk: Codable { + let id: String + let object: String + let created: Int + let model: String + let choices: [APIStreamChoice] + let usage: APIUsageInfo? 
+} + +// MARK: - Model listing + +struct APIModelInfo: Codable { + let id: String + let object: String + let created: Int + let owned_by: String + let context_window: Int? +} + +struct APIModelListResponse: Codable { + let object: String + let data: [APIModelInfo] +} + +// MARK: - Utility: type-erased Codable + +struct AnyCodable: Codable { + let value: Any + + init(_ value: Any) { + self.value = value + } + + init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let intVal = try? container.decode(Int.self) { value = intVal } + else if let doubleVal = try? container.decode(Double.self) { value = doubleVal } + else if let boolVal = try? container.decode(Bool.self) { value = boolVal } + else if let stringVal = try? container.decode(String.self) { value = stringVal } + else if let arrayVal = try? container.decode([AnyCodable].self) { value = arrayVal.map(\.value) } + else if let dictVal = try? container.decode([String: AnyCodable].self) { + value = dictVal.mapValues(\.value) + } else { value = NSNull() } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch value { + case let v as Int: try container.encode(v) + case let v as Double: try container.encode(v) + case let v as Bool: try container.encode(v) + case let v as String: try container.encode(v) + case let v as [Any]: try container.encode(v.map { AnyCodable($0) }) + case let v as [String: Any]: try container.encode(v.mapValues { AnyCodable($0) }) + default: try container.encodeNil() + } + } +} diff --git a/MLXServer/Server/APIServer.swift b/MLXServer/Server/APIServer.swift new file mode 100644 index 0000000..5944842 --- /dev/null +++ b/MLXServer/Server/APIServer.swift @@ -0,0 +1,814 @@ +import AppKit +import Foundation +import MLXLMCommon +import Network + +/// Lightweight HTTP server that exposes OpenAI-compatible endpoints. +/// Runs entirely in-process using NWListener (Network.framework, no third-party deps). 
+@Observable
+@MainActor
+final class APIServer {
+    var isRunning = false
+    var port: Int = 1234
+    var requestCount: Int = 0
+
+    private var listener: NWListener?
+    private var modelManager: ModelManager?
+
+    // Persistent ChatSession for KV cache reuse across requests
+    private var cachedSession: ChatSession?
+    private var cachedMessages: [Chat.Message]?
+    private var cachedModelId: String?
+
+    /// Creates the listener on `port` and starts accepting connections.
+    ///
+    /// Does nothing when the server is already running. A `port` that is not a valid
+    /// 16-bit port number is logged and ignored instead of trapping in the integer conversion.
+    /// `isRunning` flips to `true` only once the listener reports `.ready`.
+    func start(modelManager: ModelManager, port: Int = 1234) {
+        guard !isRunning else { return }
+        // `UInt16(port)` traps for values outside 0...65535, so convert with the failable initializers.
+        guard let portNumber = UInt16(exactly: port),
+              let endpointPort = NWEndpoint.Port(rawValue: portNumber) else {
+            print("[APIServer] Invalid port: \(port)")
+            return
+        }
+        self.modelManager = modelManager
+        self.port = port
+
+        do {
+            let params = NWParameters.tcp
+            params.allowLocalEndpointReuse = true
+            listener = try NWListener(using: params, on: endpointPort)
+
+            // Listener callbacks arrive on the global queue passed to `start(queue:)`;
+            // hop to the main actor before touching observable state.
+            listener?.stateUpdateHandler = { [weak self] state in
+                Task { @MainActor in
+                    switch state {
+                    case .ready:
+                        self?.isRunning = true
+                        print("[APIServer] Listening on port \(port)")
+                    case .failed(let error):
+                        self?.isRunning = false
+                        print("[APIServer] Failed: \(error)")
+                    case .cancelled:
+                        self?.isRunning = false
+                    default:
+                        break
+                    }
+                }
+            }
+
+            listener?.newConnectionHandler = { [weak self] connection in
+                Task { @MainActor in
+                    self?.handleConnection(connection)
+                }
+            }
+
+            listener?.start(queue: .global(qos: .userInitiated))
+        } catch {
+            print("[APIServer] Failed to start: \(error)")
+        }
+    }
+
+    /// Cancels the listener and drops the cached chat session state.
+    func stop() {
+        listener?.cancel()
+        listener = nil
+        isRunning = false
+        cachedSession = nil
+        cachedMessages = nil
+        cachedModelId = nil
+    }
+
+    // MARK: - Connection handling
+
+    /// Starts the accepted connection and begins reading its request with an empty buffer.
+    private func handleConnection(_ connection: NWConnection) {
+        connection.start(queue: .global(qos: .userInitiated))
+        receiveFullHTTPRequest(connection: connection, accumulated: Data())
+    }
+
+    /// Receive the full HTTP request, accumulating data until we have the complete body.
+    /// This handles large POST bodies (e.g. base64 images) that arrive in multiple chunks.
+    private func receiveFullHTTPRequest(connection: NWConnection, accumulated: Data) {
+        connection.receive(minimumIncompleteLength: 1, maximumLength: 4_194_304) {
+            [weak self] data, _, isComplete, error in
+            // A receive error means the connection is unusable. Without this check a failed
+            // connection that never reports `isComplete` would re-arm `receive` indefinitely.
+            guard let self, error == nil else { connection.cancel(); return }
+
+            var buffer = accumulated
+            if let data { buffer.append(data) }
+
+            // Try to determine if we have the full request
+            if let request = HTTPRequest.parse(buffer) {
+                // Check if we have enough body data based on Content-Length
+                if let clHeader = request.headers["content-length"],
+                   let contentLength = Int(clHeader),
+                   (request.body?.count ?? 0) < contentLength {
+                    // Need more data
+                    if isComplete {
+                        // Connection closed before we got all data
+                        Task { @MainActor in
+                            self.sendResponse(connection: connection, status: 400, body: #"{"error":"Incomplete request body"}"#)
+                        }
+                    } else {
+                        self.receiveFullHTTPRequest(connection: connection, accumulated: buffer)
+                    }
+                    return
+                }
+
+                // Headers and body are complete: count the request and dispatch it on the main actor.
+                Task { @MainActor in
+                    self.requestCount += 1
+                    await self.processHTTPRequest(request: request, connection: connection)
+                }
+            } else if isComplete {
+                // Peer finished sending but the bytes never parsed as an HTTP request.
+                Task { @MainActor in
+                    self.sendResponse(connection: connection, status: 400, body: #"{"error":"Bad Request"}"#)
+                }
+            } else {
+                // Keep accumulating
+                self.receiveFullHTTPRequest(connection: connection, accumulated: buffer)
+            }
+        }
+    }
+
+    private func processHTTPRequest(request: HTTPRequest, connection: NWConnection) async {
+        // CORS preflight
+        if request.method == "OPTIONS" {
+            sendResponse(connection: connection, status: 200, body: "", extraHeaders: corsHeaders())
+            return
+        }
+
+        switch (request.method, request.path) {
+        case ("GET", "/health"):
+            sendResponse(connection: connection, status: 200, body: #"{"status":"ok"}"#)
+
+        case ("GET", "/v1/models"):
+            await handleListModels(connection: connection)
+
+        case ("POST", "/v1/chat/completions"):
+            await handleChatCompletions(connection: connection, body: request.body)
+
+        default:
+            sendResponse(connection: connection, status: 
404, body: #"{"error":"Not Found"}"#) + } + } + + // MARK: - GET /v1/models + + private func handleListModels(connection: NWConnection) async { + let models = ModelConfig.availableModels.map { config in + APIModelInfo( + id: config.repoId, + object: "model", + created: Int(Date().timeIntervalSince1970), + owned_by: "local", + context_window: config.contextLength + ) + } + let response = APIModelListResponse(object: "list", data: models) + + if let json = try? JSONEncoder().encode(response) { + sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}") + } + } + + // MARK: - POST /v1/chat/completions + + private func handleChatCompletions(connection: NWConnection, body: Data?) async { + guard let body, let request = try? JSONDecoder().decode(APIChatCompletionRequest.self, from: body) else { + sendResponse(connection: connection, status: 400, body: #"{"error":"Invalid request body"}"#) + return + } + + guard let modelManager else { + sendResponse(connection: connection, status: 503, body: #"{"error":"No model loaded"}"#) + return + } + + // Model swapping: if the request specifies a different model, swap to it + if let requestedModel = request.model, !requestedModel.isEmpty { + if let targetConfig = ModelConfig.resolve(requestedModel) { + if modelManager.currentModel?.id != targetConfig.id { + print("[APIServer] Swapping model: \(modelManager.currentModel?.repoId ?? "none") -> \(targetConfig.repoId)") + cachedSession = nil + cachedMessages = nil + cachedModelId = nil + await modelManager.loadModel(targetConfig) + } + } + // If we can't resolve the model, continue with whatever is loaded + } + + guard modelManager.isReady, let container = modelManager.modelContainer else { + sendResponse(connection: connection, status: 503, body: #"{"error":"No model loaded"}"#) + return + } + + let isStream = request.stream ?? false + let temperature = request.temperature ?? 0.7 + let topP = request.top_p ?? 
1.0 + let maxTokens = request.max_tokens ?? 4096 + let requestId = "chatcmpl-\(UUID().uuidString.prefix(12).lowercased())" + let created = Int(Date().timeIntervalSince1970) + let modelName = request.model ?? modelManager.currentModel?.repoId ?? "unknown" + let contextLength = modelManager.currentModel?.contextLength ?? 0 + + // Convert API messages to Chat.Message, extracting images from content parts + var chatMessages: [Chat.Message] = [] + var images: [UserInput.Image] = [] + let currentModelRepoId = modelManager.currentModel?.repoId ?? modelName + + // Inject tool definitions into the system prompt if tools are provided + if let tools = request.tools, !tools.isEmpty { + let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId) + + // Check if there's already a system message + let hasSystem = request.messages.contains { $0.role == "system" } + if hasSystem { + // Append tool prompt to existing system message (handled below during conversion) + } else { + // For Gemma: inject as user message (Gemma doesn't support system role natively) + // For Qwen: inject as system message + if currentModelRepoId.lowercased().contains("qwen") { + chatMessages.append(Chat.Message(role: .system, content: toolSystemPrompt)) + } else { + chatMessages.append(Chat.Message(role: .user, content: toolSystemPrompt)) + chatMessages.append(Chat.Message(role: .assistant, content: "Understood. I will use the provided tools when appropriate.")) + } + } + } + + let toolsForInjection = request.tools + let isQwen = currentModelRepoId.lowercased().contains("qwen") + + for msg in request.messages { + let role: Chat.Message.Role = switch msg.role { + case "system": .system + case "assistant": .assistant + case "tool": .user + default: .user + } + + var text = msg.content?.textContent ?? 
"" + + // If this is a system message and tools are provided, append tool definitions + if msg.role == "system", let tools = toolsForInjection, !tools.isEmpty { + let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId) + text = text + "\n\n" + toolSystemPrompt + } + + // Format tool_call_id responses as tool_output for the model + if msg.role == "tool" { + if isQwen { + // Qwen expects tool results as-is in a user message + // (the role is already mapped to .user above) + } else { + // Gemma expects tool results wrapped in ```tool_output``` blocks + text = "```tool_output\n\(text)\n```" + } + } + + // Format assistant tool_calls back into model-native format + if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty { + let formattedCalls: String + if isQwen { + formattedCalls = ToolPromptBuilder.formatQwenToolCalls(toolCalls) + } else { + formattedCalls = ToolPromptBuilder.formatGemmaToolCalls(toolCalls) + } + text = (text.isEmpty ? "" : text + "\n") + formattedCalls + } + + // Extract base64 images from content parts + let imageURLs = msg.content?.imageURLs ?? [] + var messageImages: [UserInput.Image] = [] + for urlString in imageURLs { + if let image = decodeBase64Image(urlString) { + messageImages.append(image) + } + } + + // Attach images to this specific message + chatMessages.append(Chat.Message(role: role, content: text, images: messageImages)) + images.append(contentsOf: messageImages) + } + + // Context window check: estimate token count and reject if over limit + if contextLength > 0 { + let totalChars = chatMessages.reduce(0) { $0 + $1.content.count } + let estimatedTokens = totalChars * 10 / 35 // ~3.5 chars per token + let needed = estimatedTokens + maxTokens + if needed > contextLength { + let errorBody = """ + {"error":{"message":"This model's maximum context length is \(contextLength) tokens. 
\ + However, your messages resulted in approximately \(estimatedTokens) tokens and \ + \(maxTokens) tokens were requested for the completion (\(needed) total). \ + Please reduce the length of the messages or completion.",\ + "type":"invalid_request_error","code":"context_length_exceeded"}} + """ + sendResponse(connection: connection, status: 400, body: errorBody) + return + } + } + + let generateParams = GenerateParameters( + maxTokens: maxTokens, + temperature: Float(temperature), + topP: Float(topP) + ) + + // Feed all messages except the last as history, then send the last as the prompt + let allButLast = Array(chatMessages.dropLast()) + let lastMessage = chatMessages.last ?? Chat.Message(role: .user, content: "") + + // KV cache reuse: check if the cached session's history matches + let currentModelId = modelManager.currentModel?.id + let canReuse = cachedSession != nil + && cachedModelId == currentModelId + && cachedMessages != nil + && messagesMatch(cachedMessages!, allButLast) + + let session: ChatSession + if canReuse { + print("[APIServer] Reusing cached session (\(allButLast.count) history messages)") + session = cachedSession! 
+ session.generateParameters = generateParams + } else { + if cachedSession != nil { + print("[APIServer] History diverged, creating fresh session") + } + if !allButLast.isEmpty { + session = ChatSession( + container, + history: allButLast, + generateParameters: generateParams + ) + } else { + session = ChatSession( + container, + generateParameters: generateParams + ) + } + } + + // Extract images from the last message only (ChatSession.streamDetails takes images separately) + let lastImages = lastMessage.images + + if isStream { + await handleStreamingResponse( + connection: connection, + session: session, + prompt: lastMessage.content, + images: lastImages, + tools: request.tools, + requestId: requestId, + created: created, + modelName: modelName + ) + } else { + await handleNonStreamingResponse( + connection: connection, + session: session, + prompt: lastMessage.content, + images: lastImages, + tools: request.tools, + requestId: requestId, + created: created, + modelName: modelName + ) + } + + // Cache the session for reuse on next request + // allButLast + lastMessage (user) + assistant response = new cached history + cachedSession = session + cachedMessages = chatMessages // full messages including the one just sent + cachedModelId = currentModelId + } + + /// Decode a base64 data URI (data:image/png;base64,...) into a UserInput.Image. + private func decodeBase64Image(_ urlString: String) -> UserInput.Image? 
{ + // Handle data URIs: data:image/png;base64, + let base64String: String + if urlString.hasPrefix("data:") { + guard let commaIndex = urlString.firstIndex(of: ",") else { return nil } + base64String = String(urlString[urlString.index(after: commaIndex)...]) + } else { + // Could be a plain base64 string + base64String = urlString + } + + guard let data = Data(base64Encoded: base64String), + let nsImage = NSImage(data: data), + let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) else { + return nil + } + + return .ciImage(CIImage(cgImage: cgImage)) + } + + // MARK: - Non-streaming response + + private func handleNonStreamingResponse( + connection: NWConnection, + session: ChatSession, + prompt: String, + images: [UserInput.Image], + tools: [APIToolDefinition]?, + requestId: String, + created: Int, + modelName: String + ) async { + do { + var fullText = "" + var promptTokens = 0 + var completionTokens = 0 + var frameworkToolCalls: [MLXLMCommon.ToolCall] = [] + + let stream = session.streamDetails( + to: prompt, + images: images, + videos: [] + ) + + for try await generation in stream { + switch generation { + case .chunk(let text): + fullText += text + case .info(let info): + promptTokens = info.promptTokenCount + completionTokens = info.generationTokenCount + case .toolCall(let call): + frameworkToolCalls.append(call) + } + } + + // Parse tool calls: first check framework-detected ones, then our own text parser + var finishReason = "stop" + var responseContent: String? = fullText + var apiToolCalls: [APIToolCall]? = nil + + if !frameworkToolCalls.isEmpty { + // Framework natively detected tool calls (e.g. Qwen) + finishReason = "tool_calls" + apiToolCalls = frameworkToolCalls.enumerated().map { i, tc in + let argsJSON: String + let argsDict = tc.function.arguments.mapValues { $0.anyValue } + if let data = try? 
JSONSerialization.data(withJSONObject: argsDict), + let str = String(data: data, encoding: .utf8) { + argsJSON = str + } else { + argsJSON = "{}" + } + let callId = String(format: "call_%d_%08d", i, abs(tc.function.name.hashValue) % 100_000_000) + return APIToolCall( + index: i, + id: callId, + type: "function", + function: APIFunctionCall(name: tc.function.name, arguments: argsJSON) + ) + } + responseContent = fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty ? nil : fullText + } else if let tools, !tools.isEmpty { + // Try our own text parser (e.g. Gemma tool_code blocks) + let (cleanText, parsedCalls) = ToolCallParser.parse(text: fullText, tools: tools) + if !parsedCalls.isEmpty { + finishReason = "tool_calls" + apiToolCalls = parsedCalls.enumerated().map { i, tc in + APIToolCall( + index: i, + id: tc.id, + type: "function", + function: APIFunctionCall(name: tc.name, arguments: tc.arguments) + ) + } + responseContent = cleanText.isEmpty ? nil : cleanText + } + } + + let response = APIChatCompletionResponse( + id: requestId, + object: "chat.completion", + created: created, + model: modelName, + choices: [ + APIChoice( + index: 0, + message: APIChoiceMessage( + role: "assistant", + content: responseContent, + tool_calls: apiToolCalls + ), + finish_reason: finishReason + ) + ], + usage: APIUsageInfo( + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: promptTokens + completionTokens + ) + ) + + if let json = try? JSONEncoder().encode(response) { + sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? 
"{}") + } + } catch { + sendResponse(connection: connection, status: 500, body: #"{"error":"\#(error.localizedDescription)"}"#) + } + } + + // MARK: - Streaming (SSE) response + + private func handleStreamingResponse( + connection: NWConnection, + session: ChatSession, + prompt: String, + images: [UserInput.Image], + tools: [APIToolDefinition]?, + requestId: String, + created: Int, + modelName: String + ) async { + // Send SSE headers + let header = [ + "HTTP/1.1 200 OK", + "Content-Type: text/event-stream", + "Cache-Control: no-cache", + "Connection: keep-alive", + "Access-Control-Allow-Origin: *", + "", + "", + ].joined(separator: "\r\n") + + let headerSent = await withCheckedContinuation { continuation in + connection.send(content: header.data(using: .utf8), completion: .contentProcessed({ _ in + continuation.resume(returning: true) + })) + } + guard headerSent else { return } + + // Send initial role chunk + sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + id: requestId, + object: "chat.completion.chunk", + created: created, + model: modelName, + choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: "assistant", content: nil, tool_calls: nil), finish_reason: nil)], + usage: nil + )) + + // When tools are available, buffer full response to parse tool calls + // (otherwise raw tool-call markup leaks into streamed text) + let bufferForTools = tools != nil && !(tools?.isEmpty ?? 
true) + + var promptTokens = 0 + var completionTokens = 0 + var fullText = "" + var frameworkToolCalls: [MLXLMCommon.ToolCall] = [] + + do { + let stream = session.streamDetails( + to: prompt, + images: images, + videos: [] + ) + + for try await generation in stream { + switch generation { + case .chunk(let text): + completionTokens += 1 + fullText += text + + if !bufferForTools { + sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + id: requestId, + object: "chat.completion.chunk", + created: created, + model: modelName, + choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)], + usage: nil + )) + } + + case .info(let info): + promptTokens = info.promptTokenCount + completionTokens = info.generationTokenCount + + case .toolCall(let call): + frameworkToolCalls.append(call) + } + } + } catch { + let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n" + connection.send(content: errorEvent.data(using: .utf8), completion: .contentProcessed({ _ in })) + } + + // Post-generation: handle tool calls (framework-detected or text-parsed) + var finishReason = "stop" + + if !frameworkToolCalls.isEmpty { + // Framework natively detected tool calls (e.g. Qwen) + finishReason = "tool_calls" + + // Emit any buffered text content + if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + id: requestId, + object: "chat.completion.chunk", + created: created, + model: modelName, + choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)], + usage: nil + )) + } + + // Emit tool call chunks + for (i, tc) in frameworkToolCalls.enumerated() { + let argsDict = tc.function.arguments.mapValues { $0.anyValue } + let argsJSON: String + if let data = try? 
JSONSerialization.data(withJSONObject: argsDict), + let str = String(data: data, encoding: .utf8) { + argsJSON = str + } else { + argsJSON = "{}" + } + let callId = String(format: "call_%d_%08d", i, abs(tc.function.name.hashValue) % 100_000_000) + let apiToolCall = APIToolCall( + index: i, + id: callId, + type: "function", + function: APIFunctionCall(name: tc.function.name, arguments: argsJSON) + ) + sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + id: requestId, + object: "chat.completion.chunk", + created: created, + model: modelName, + choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: nil, tool_calls: [apiToolCall]), finish_reason: nil)], + usage: nil + )) + } + } else if bufferForTools { + // Text-parsed tool calls (e.g. Gemma tool_code blocks) + let (cleanText, parsed) = ToolCallParser.parse(text: fullText, tools: tools) + if !parsed.isEmpty { + finishReason = "tool_calls" + fullText = cleanText + } + + // Emit buffered content (cleaned of tool-call markup) + if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + id: requestId, + object: "chat.completion.chunk", + created: created, + model: modelName, + choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)], + usage: nil + )) + } + + // Emit tool call chunks + for (i, tc) in parsed.enumerated() { + let apiToolCall = APIToolCall( + index: i, + id: tc.id, + type: "function", + function: APIFunctionCall(name: tc.name, arguments: tc.arguments) + ) + sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + id: requestId, + object: "chat.completion.chunk", + created: created, + model: modelName, + choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: nil, tool_calls: [apiToolCall]), finish_reason: nil)], + usage: nil + )) + } + } + + // Final chunk with finish_reason and 
usage + sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk( + id: requestId, + object: "chat.completion.chunk", + created: created, + model: modelName, + choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: nil, tool_calls: nil), finish_reason: finishReason)], + usage: APIUsageInfo( + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: promptTokens + completionTokens + ) + )) + + // Send [DONE] and close + let done = "data: [DONE]\n\n" + connection.send(content: done.data(using: .utf8), completion: .contentProcessed({ _ in + connection.cancel() + })) + } + + private func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) { + guard let json = try? JSONEncoder().encode(chunk), + let jsonString = String(data: json, encoding: .utf8) else { return } + let event = "data: \(jsonString)\n\n" + connection.send(content: event.data(using: .utf8), completion: .contentProcessed({ _ in })) + } + + // MARK: - HTTP helpers + + private func sendResponse( + connection: NWConnection, + status: Int, + body: String, + extraHeaders: [String] = [] + ) { + let statusText = switch status { + case 200: "OK" + case 400: "Bad Request" + case 404: "Not Found" + case 500: "Internal Server Error" + case 503: "Service Unavailable" + default: "Error" + } + + var headers = [ + "HTTP/1.1 \(status) \(statusText)", + "Content-Type: application/json", + "Access-Control-Allow-Origin: *", + "Access-Control-Allow-Methods: GET, POST, OPTIONS", + "Access-Control-Allow-Headers: Content-Type, Authorization", + "Content-Length: \(body.utf8.count)", + ] + headers.append(contentsOf: extraHeaders) + headers.append("") + headers.append("") + + let response = headers.joined(separator: "\r\n") + body + connection.send(content: response.data(using: .utf8), completion: .contentProcessed({ _ in + connection.cancel() + })) + } + + private func corsHeaders() -> [String] { + [ + "Access-Control-Allow-Methods: GET, POST, OPTIONS", 
+ "Access-Control-Allow-Headers: Content-Type, Authorization", + "Access-Control-Max-Age: 86400", + ] + } + + /// Check if cached messages are a prefix of new messages (for KV cache reuse). + /// The cached messages include the full history from the previous request. + /// For cache reuse, all but the last message of the new request must match + /// all but the last message of the cached messages (the cached last was the + /// previous user prompt, which is now part of the history). + private func messagesMatch(_ cached: [Chat.Message], _ newHistory: [Chat.Message]) -> Bool { + // The cached messages are the full chatMessages from the previous request. + // For the cache to be reusable, the new history (allButLast) must match + // exactly what the session has already processed. + // After a request, the session has seen: cachedMessages' history + prompt + response. + // So on the next request, if newHistory == cachedMessages, the session already + // contains all of those turns and we can just send the new last message. + guard cached.count == newHistory.count else { return false } + for (a, b) in zip(cached, newHistory) { + if a.role != b.role || a.content != b.content { return false } + } + return true + } +} + +// MARK: - HTTP request parser + +private struct HTTPRequest { + let method: String + let path: String + let headers: [String: String] + let body: Data? + + /// Parse raw HTTP data into a structured request. + /// Uses raw Data operations to correctly handle binary body content. + static func parse(_ data: Data) -> HTTPRequest? { + // Find \r\n\r\n boundary between headers and body + let separator: [UInt8] = [0x0D, 0x0A, 0x0D, 0x0A] // \r\n\r\n + guard let separatorRange = data.firstRange(of: Data(separator)) else { + // No complete header yet — might need more data + // But if data is large enough, treat as malformed + return data.count > 65536 ? 
nil : nil + } + + let headerData = data[data.startIndex..= 2 else { return nil } + + let method = String(parts[0]) + let fullPath = String(parts[1]) + let path = fullPath.components(separatedBy: "?").first ?? fullPath + + var headers: [String: String] = [:] + for line in lines.dropFirst() { + let kv = line.split(separator: ":", maxSplits: 1) + if kv.count == 2 { + headers[String(kv[0]).trimmingCharacters(in: .whitespaces).lowercased()] = + String(kv[1]).trimmingCharacters(in: .whitespaces) + } + } + + // Body is everything after \r\n\r\n + let bodyStart = separatorRange.upperBound + let body: Data? = bodyStart < data.endIndex ? data[bodyStart.. XML tags. +enum ToolCallParser { + + struct ParsedToolCall { + let id: String + let name: String + let arguments: String // JSON string + } + + /// Parse tool calls from model output. Returns (cleanText, toolCalls). + static func parse(text: String, tools: [APIToolDefinition]?) -> (String, [ParsedToolCall]) { + // Try Qwen-style first ( tags) + let (qwenClean, qwenCalls) = parseQwen(text: text) + if !qwenCalls.isEmpty { + return (qwenClean, qwenCalls) + } + + // Try Gemma-style (```tool_code``` blocks) + let (gemmaClean, gemmaCalls) = parseGemma(text: text, tools: tools) + if !gemmaCalls.isEmpty { + return (gemmaClean, gemmaCalls) + } + + return (text, []) + } + + // MARK: - Gemma: ```tool_code``` blocks + + /// Parse Gemma's tool_code blocks: ```tool_code\nfunc_name(arg="value")\n``` + private static func parseGemma(text: String, tools: [APIToolDefinition]?) -> (String, [ParsedToolCall]) { + let pattern = #"```tool_code\s*(.*?)\s*```"# + guard let regex = try? 
NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else { + return (text, []) + } + + let nsText = text as NSString + let matches = regex.matches(in: text, range: NSRange(location: 0, length: nsText.length)) + guard !matches.isEmpty else { return (text, []) } + + var toolCalls: [ParsedToolCall] = [] + + // Build tool definitions map for parameter inference + var toolDefs: [String: [String]] = [:] + if let tools { + for tool in tools { + let paramNames = tool.function.parameters?["properties"]?.value as? [String: Any] + toolDefs[tool.function.name] = paramNames.map { Array($0.keys).sorted() } ?? [] + } + } + + for (i, match) in matches.enumerated() { + let callStr = nsText.substring(with: match.range(at: 1)).trimmingCharacters(in: .whitespacesAndNewlines) + + if let (name, args) = parsePythonCall(callStr, toolDefs: toolDefs) { + let argsJSON = (try? JSONSerialization.data(withJSONObject: args)) + .flatMap { String(data: $0, encoding: .utf8) } ?? "{}" + let callId = String(format: "call_%d_%08d", i, abs(callStr.hashValue) % 100_000_000) + toolCalls.append(ParsedToolCall(id: callId, name: name, arguments: argsJSON)) + } + } + + // Remove tool_code blocks from text + let cleanText = regex.stringByReplacingMatches( + in: text, range: NSRange(location: 0, length: nsText.length), + withTemplate: "" + ).trimmingCharacters(in: .whitespacesAndNewlines) + + return (cleanText, toolCalls) + } + + /// Parse a Python-style function call: func_name(arg1="value", arg2=42) + private static func parsePythonCall(_ callStr: String, toolDefs: [String: [String]]) -> (String, [String: Any])? { + // Match: func_name(args...) + let pattern = #"^(\w+)\s*\((.*)\)\s*$"# + guard let regex = try? 
NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else { + return nil + } + + let nsCall = callStr as NSString + let match = regex.firstMatch(in: callStr, range: NSRange(location: 0, length: nsCall.length)) + guard let match else { + // Shell-style: "func_name arg1" + let parts = callStr.split(separator: " ", maxSplits: 1) + guard let name = parts.first.map(String.init), !name.isEmpty else { return nil } + if parts.count > 1 { + let argValue = String(parts[1]).trimmingCharacters(in: .whitespacesAndNewlines) + let paramNames = toolDefs[name] ?? [] + let key = paramNames.first ?? "arg0" + return (name, [key: argValue]) + } + return (name, [:]) + } + + let name = nsCall.substring(with: match.range(at: 1)) + let argsStr = nsCall.substring(with: match.range(at: 2)).trimmingCharacters(in: .whitespacesAndNewlines) + + if argsStr.isEmpty { + return (name, [:]) + } + + // Parse keyword arguments: key="value", key2=42 + var args: [String: Any] = [:] + let kwPattern = #"(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|'(?:[^'\\]|\\.)*'|[^,]+)"# + if let kwRegex = try? 
NSRegularExpression(pattern: kwPattern, options: []) { + let kwMatches = kwRegex.matches(in: argsStr, range: NSRange(location: 0, length: (argsStr as NSString).length)) + for kwMatch in kwMatches { + let key = (argsStr as NSString).substring(with: kwMatch.range(at: 1)) + var val = (argsStr as NSString).substring(with: kwMatch.range(at: 2)).trimmingCharacters(in: .whitespacesAndNewlines) + // Strip quotes + if (val.hasPrefix("\"") && val.hasSuffix("\"")) || (val.hasPrefix("'") && val.hasSuffix("'")) { + val = String(val.dropFirst().dropLast()) + } + // Try to parse as number/bool + if let intVal = Int(val) { args[key] = intVal } + else if let doubleVal = Double(val) { args[key] = doubleVal } + else if val == "True" || val == "true" { args[key] = true } + else if val == "False" || val == "false" { args[key] = false } + else { args[key] = val } + } + } + + // If no keyword args found, try positional + if args.isEmpty { + let paramNames = toolDefs[name] ?? [] + // Try splitting by comma for positional args + let positionals = argsStr.split(separator: ",").map { + $0.trimmingCharacters(in: .whitespacesAndNewlines) + } + for (i, pos) in positionals.enumerated() { + var val = pos + if (val.hasPrefix("\"") && val.hasSuffix("\"")) || (val.hasPrefix("'") && val.hasSuffix("'")) { + val = String(val.dropFirst().dropLast()) + } + let key = i < paramNames.count ? paramNames[i] : "arg\(i)" + args[key] = val + } + } + + return (name, args) + } + + // MARK: - Qwen: XML tags + + /// Parse Qwen's tool_call tags: {"name":"func","arguments":{...}} + private static func parseQwen(text: String) -> (String, [ParsedToolCall]) { + let pattern = #"\s*(.*?)\s*"# + guard let regex = try? 
NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else { + return (text, []) + } + + let nsText = text as NSString + let matches = regex.matches(in: text, range: NSRange(location: 0, length: nsText.length)) + guard !matches.isEmpty else { return (text, []) } + + var toolCalls: [ParsedToolCall] = [] + + for (i, match) in matches.enumerated() { + let jsonStr = nsText.substring(with: match.range(at: 1)).trimmingCharacters(in: .whitespacesAndNewlines) + + guard let data = jsonStr.data(using: .utf8), + let obj = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + let name = obj["name"] as? String else { + continue + } + + var argsJSON = "{}" + if let args = obj["arguments"] { + if let argsDict = args as? [String: Any], + let argsData = try? JSONSerialization.data(withJSONObject: argsDict) { + argsJSON = String(data: argsData, encoding: .utf8) ?? "{}" + } else if let argsStr = args as? String { + argsJSON = argsStr + } + } + + let callId = String(format: "call_%d_%08d", i, abs(jsonStr.hashValue) % 100_000_000) + toolCalls.append(ParsedToolCall(id: callId, name: name, arguments: argsJSON)) + } + + let cleanText = regex.stringByReplacingMatches( + in: text, range: NSRange(location: 0, length: nsText.length), + withTemplate: "" + ).trimmingCharacters(in: .whitespacesAndNewlines) + + return (cleanText, toolCalls) + } +} diff --git a/MLXServer/Server/ToolPromptBuilder.swift b/MLXServer/Server/ToolPromptBuilder.swift new file mode 100644 index 0000000..3da1174 --- /dev/null +++ b/MLXServer/Server/ToolPromptBuilder.swift @@ -0,0 +1,199 @@ +import Foundation + +/// Builds model-specific system prompts that inform the model about available tools. +/// Mirrors the Python server's `_build_tool_system_prompt()` and `_build_qwen_tool_system_prompt()`. +enum ToolPromptBuilder { + + /// Build a tool system prompt appropriate for the current model. 
+ /// - Parameters: + /// - tools: OpenAI-format tool definitions + /// - modelId: The model's repo ID (to determine format) + /// - Returns: A system prompt string describing the available tools + static func buildSystemPrompt(tools: [APIToolDefinition], modelId: String) -> String { + if modelId.lowercased().contains("qwen") { + return buildQwenToolPrompt(tools: tools) + } else { + return buildGemmaToolPrompt(tools: tools) + } + } + + // MARK: - Gemma (tool_code format) + + /// Build the tool system prompt using Google's Gemma 3 tool_code convention. + private static func buildGemmaToolPrompt(tools: [APIToolDefinition]) -> String { + let funcDefs = tools.map { toolToPythonSignature($0.function) } + let functionsBlock = funcDefs.joined(separator: "\n\n") + + return """ + At each turn, if you decide to invoke any of the function(s), \ + it should be wrapped with ```tool_code```. \ + The python methods described below are imported and available, \ + you can only use defined methods. \ + The generated code should be readable and efficient. \ + The response to a method will be wrapped in ```tool_output``` \ + use it to call more tools or generate a helpful, friendly response. + + \(functionsBlock) + """ + } + + /// Convert an OpenAI function definition to a Python function signature with docstring. + private static func toolToPythonSignature(_ func: APIFunctionDefinition) -> String { + let name = `func`.name + let desc = `func`.description ?? "" + let properties = `func`.parameters?["properties"]?.value as? [String: Any] ?? [:] + let requiredArr = `func`.parameters?["required"]?.value as? [String] ?? [] + let required = Set(requiredArr) + + var paramParts: [String] = [] + var docArgs: [String] = [] + + // Sort keys for deterministic output + for pname in properties.keys.sorted() { + guard let pinfo = properties[pname] as? [String: Any] else { continue } + let ptype = jsonTypeToPython(pinfo["type"] as? String ?? "str") + let pdesc = pinfo["description"] as? String ?? 
"" + + if required.contains(pname) { + paramParts.append("\(pname): \(ptype)") + } else { + let defaultVal = jsonTypeDefault(pinfo["type"] as? String ?? "str") + paramParts.append("\(pname): \(ptype) = \(defaultVal)") + } + docArgs.append(pdesc.isEmpty ? " \(pname)" : " \(pname): \(pdesc)") + } + + let sig = "def \(name)(\(paramParts.joined(separator: ", "))):" + var docLines = [" \"\"\"\(desc)"] + if !docArgs.isEmpty { + docLines.append("") + docLines.append(" Args:") + docLines.append(contentsOf: docArgs) + } + docLines.append(" \"\"\"") + + return sig + "\n" + docLines.joined(separator: "\n") + } + + private static func jsonTypeToPython(_ type: String) -> String { + switch type { + case "string": return "str" + case "integer": return "int" + case "number": return "float" + case "boolean": return "bool" + case "array": return "list" + case "object": return "dict" + default: return "str" + } + } + + private static func jsonTypeDefault(_ type: String) -> String { + switch type { + case "string": return "None" + case "integer": return "0" + case "number": return "0.0" + case "boolean": return "False" + case "array": return "[]" + case "object": return "{}" + default: return "None" + } + } + + // MARK: - Qwen ( format) + + /// Build the tool system prompt for Qwen3 using its native JSON format. + private static func buildQwenToolPrompt(tools: [APIToolDefinition]) -> String { + var toolDescs: [[String: Any]] = [] + for tool in tools { + var funcDict: [String: Any] = [ + "name": tool.function.name, + "description": tool.function.description ?? "", + ] + if let params = tool.function.parameters { + funcDict["parameters"] = params.mapValues(\.value) + } + toolDescs.append([ + "type": "function", + "function": funcDict, + ]) + } + + let toolsJSON: String + if let data = try? 
JSONSerialization.data(withJSONObject: toolDescs, options: [.prettyPrinted, .sortedKeys]), + let str = String(data: data, encoding: .utf8) { + toolsJSON = str + } else { + toolsJSON = "[]" + } + + return """ + # Tools + + You are a helpful assistant with access to the following tools. \ + Use them when appropriate by responding with a JSON tool call. + + ## Available Tools + + \(toolsJSON) + + ## Tool Call Format + + When you need to call a tool, respond with: + + {"name": "", "arguments": {}} + + """ + } + + // MARK: - Format tool calls back into model-specific format for prompt history + + /// Format OpenAI-style tool calls back into Gemma's tool_code blocks for prompt history. + static func formatGemmaToolCalls(_ toolCalls: [APIToolCall]) -> String { + var parts: [String] = [] + for tc in toolCalls { + let name = tc.function.name + let argsStr = tc.function.arguments + if let data = argsStr.data(using: .utf8), + let args = try? JSONSerialization.jsonObject(with: data) as? [String: Any] { + let argParts = args.keys.sorted().map { key -> String in + let val = args[key]! + return "\(key)=\(pythonRepr(val))" + } + let callStr = "\(name)(\(argParts.joined(separator: ", ")))" + parts.append("```tool_code\n\(callStr)\n```") + } else { + parts.append("```tool_code\n\(name)()\n```") + } + } + return parts.joined(separator: "\n") + } + + /// Format OpenAI-style tool calls back into Qwen's tags for prompt history. + static func formatQwenToolCalls(_ toolCalls: [APIToolCall]) -> String { + var parts: [String] = [] + for tc in toolCalls { + let name = tc.function.name + let argsStr = tc.function.arguments + var callObj: [String: Any] = ["name": name] + if let data = argsStr.data(using: .utf8), + let args = try? JSONSerialization.jsonObject(with: data) { + callObj["arguments"] = args + } + if let data = try? 
JSONSerialization.data(withJSONObject: callObj), + let str = String(data: data, encoding: .utf8) { + parts.append("\n\(str)\n") + } + } + return parts.joined(separator: "\n") + } + + private static func pythonRepr(_ value: Any) -> String { + switch value { + case let s as String: return "\"\(s)\"" + case let i as Int: return "\(i)" + case let d as Double: return "\(d)" + case let b as Bool: return b ? "True" : "False" + default: return "\"\(value)\"" + } + } +} diff --git a/MLXServer/Utilities/LocalModelResolver.swift b/MLXServer/Utilities/LocalModelResolver.swift new file mode 100644 index 0000000..f869af5 --- /dev/null +++ b/MLXServer/Utilities/LocalModelResolver.swift @@ -0,0 +1,47 @@ +import Foundation + +/// Resolves HuggingFace model repos to local snapshot directories, +/// matching the cache layout used by Python's `huggingface_hub`. +/// +/// Cache structure: +/// ~/.cache/huggingface/hub/models--{org}--{name}/snapshots/{hash}/ +enum LocalModelResolver { + + /// The standard HuggingFace cache directory used by Python's `huggingface_hub`. + private static let cacheBase: URL = { + FileManager.default.homeDirectoryForCurrentUser + .appendingPathComponent(".cache/huggingface/hub", isDirectory: true) + }() + + /// Resolve a HuggingFace repo ID (e.g. "mlx-community/gemma-3-4b-it-4bit") + /// to its local snapshot directory, if it exists. + /// + /// Returns `nil` if the model hasn't been downloaded yet. + static func resolve(repoId: String) -> URL? { + // Convert "mlx-community/gemma-3-4b-it-4bit" → "models--mlx-community--gemma-3-4b-it-4bit" + let dirName = "models--" + repoId.replacingOccurrences(of: "/", with: "--") + let snapshotsDir = cacheBase + .appendingPathComponent(dirName, isDirectory: true) + .appendingPathComponent("snapshots", isDirectory: true) + + // Find the first (usually only) snapshot hash directory + guard let contents = try? 
FileManager.default.contentsOfDirectory( + at: snapshotsDir, + includingPropertiesForKeys: [.isDirectoryKey], + options: [.skipsHiddenFiles] + ) else { + return nil + } + + // Return the most recent snapshot (last alphabetically = latest hash) + return contents + .filter { (try? $0.resourceValues(forKeys: [.isDirectoryKey]).isDirectory) == true } + .sorted(by: { $0.lastPathComponent < $1.lastPathComponent }) + .last + } + + /// Check if a model is available locally. + static func isAvailable(repoId: String) -> Bool { + resolve(repoId: repoId) != nil + } +} diff --git a/MLXServer/Utilities/Preferences.swift b/MLXServer/Utilities/Preferences.swift new file mode 100644 index 0000000..710410d --- /dev/null +++ b/MLXServer/Utilities/Preferences.swift @@ -0,0 +1,42 @@ +import Foundation + +/// Persisted app preferences via UserDefaults. +enum Preferences { + nonisolated(unsafe) private static let defaults = UserDefaults.standard + + // MARK: - Last used model + + private static let lastModelKey = "lastModelId" + + static var lastModelId: String? { + get { defaults.string(forKey: lastModelKey) } + set { defaults.set(newValue, forKey: lastModelKey) } + } + + // MARK: - System prompt + + private static let systemPromptKey = "systemPrompt" + + static var systemPrompt: String { + get { defaults.string(forKey: systemPromptKey) ?? "" } + set { defaults.set(newValue, forKey: systemPromptKey) } + } + + // MARK: - API server + + private static let apiPortKey = "apiPort" + private static let apiAutoStartKey = "apiAutoStart" + + static var apiPort: Int { + get { + let val = defaults.integer(forKey: apiPortKey) + return val > 0 ? 
val : 1234 + } + set { defaults.set(newValue, forKey: apiPortKey) } + } + + static var apiAutoStart: Bool { + get { defaults.bool(forKey: apiAutoStartKey) } + set { defaults.set(newValue, forKey: apiAutoStartKey) } + } +} diff --git a/MLXServer/ViewModels/ChatViewModel.swift b/MLXServer/ViewModels/ChatViewModel.swift new file mode 100644 index 0000000..d874c23 --- /dev/null +++ b/MLXServer/ViewModels/ChatViewModel.swift @@ -0,0 +1,158 @@ +import AppKit +import Foundation +import MLX +import MLXLMCommon +import MLXVLM + +/// Drives the chat UI: sending messages, streaming responses, managing images. +@Observable +@MainActor +final class ChatViewModel { + var conversation = Conversation() + var inputText = "" + var attachedImages: [NSImage] = [] + var isGenerating = false + var tokensPerSecond: Double = 0 + var promptTokens: Int = 0 + var generationTokens: Int = 0 + + private var generationTask: Task? + private var chatSession: ChatSession? + + let modelManager: ModelManager + let apiServer = APIServer() + + init(modelManager: ModelManager) { + self.modelManager = modelManager + } + + /// Ensure a ChatSession exists for the current model. + private func ensureSession() { + guard let container = modelManager.modelContainer else { return } + if chatSession == nil { + let systemPrompt = Preferences.systemPrompt + chatSession = ChatSession( + container, + instructions: systemPrompt.isEmpty ? 
nil : systemPrompt, + generateParameters: GenerateParameters(temperature: 0.7) + ) + } + } + + func send() { + let text = inputText.trimmingCharacters(in: .whitespacesAndNewlines) + guard !text.isEmpty, modelManager.isReady else { return } + + ensureSession() + guard let session = chatSession else { return } + + let images = attachedImages + inputText = "" + attachedImages = [] + + conversation.addUserMessage(text, images: images) + let assistantIndex = conversation.addAssistantMessage() + + isGenerating = true + tokensPerSecond = 0 + promptTokens = 0 + generationTokens = 0 + + // Convert NSImages to UserInput.Image + let inputImages: [UserInput.Image] = images.compactMap { nsImage in + guard let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) else { + return nil + } + return .ciImage(CIImage(cgImage: cgImage)) + } + + generationTask = Task { + do { + let stream = session.streamDetails( + to: text, + images: inputImages, + videos: [] + ) + + var tokenCount = 0 + let startTime = Date() + + for try await generation in stream { + if Task.isCancelled { break } + + switch generation { + case .chunk(let text): + conversation.appendToMessage(at: assistantIndex, chunk: text) + tokenCount += 1 + let elapsed = Date().timeIntervalSince(startTime) + if elapsed > 0 { + tokensPerSecond = Double(tokenCount) / elapsed + } + generationTokens = tokenCount + + case .info(let info): + promptTokens = info.promptTokenCount + if info.tokensPerSecond > 0 { + tokensPerSecond = info.tokensPerSecond + } + + case .toolCall: + break + } + } + } catch { + if !Task.isCancelled { + conversation.appendToMessage( + at: assistantIndex, + chunk: "\n\n[Error: \(error.localizedDescription)]" + ) + } + } + + conversation.finalizeMessage(at: assistantIndex) + isGenerating = false + generationTask = nil + } + } + + func stop() { + generationTask?.cancel() + generationTask = nil + isGenerating = false + + if let last = conversation.messages.indices.last, + 
conversation.messages[last].isStreaming { + conversation.finalizeMessage(at: last) + } + } + + func attachImage(_ image: NSImage) { + attachedImages.append(image) + } + + func removeImage(at index: Int) { + guard attachedImages.indices.contains(index) else { return } + attachedImages.remove(at: index) + } + + func newConversation() { + stop() + conversation.clear() + resetSession() + } + + /// Reset the chat session (e.g. on model switch or new conversation). + func resetSession() { + chatSession = nil + } + + // MARK: - API Server + + func startAPIServer() { + apiServer.start(modelManager: modelManager, port: Preferences.apiPort) + } + + func stopAPIServer() { + apiServer.stop() + } +} diff --git a/MLXServer/ViewModels/ModelManager.swift b/MLXServer/ViewModels/ModelManager.swift new file mode 100644 index 0000000..129d527 --- /dev/null +++ b/MLXServer/ViewModels/ModelManager.swift @@ -0,0 +1,71 @@ +import Foundation +import MLX +import MLXLMCommon +import MLXVLM + +/// Manages model loading, switching, and generation. +@Observable +@MainActor +final class ModelManager { + var currentModel: ModelConfig? + var modelContainer: ModelContainer? + var isLoading = false + var downloadProgress: Double = 0 + var loadingModelName: String = "" + var errorMessage: String? + + /// Load a model, unloading the current one first. + /// Prefers the local snapshot from ~/.cache/huggingface/hub/ (shared with the Python server). + /// Only downloads if the model isn't cached locally. 
+ func loadModel(_ config: ModelConfig) async { + if currentModel?.id == config.id && modelContainer != nil { + return // already loaded + } + + unloadModel() + isLoading = true + downloadProgress = 0 + loadingModelName = config.displayName + errorMessage = nil + + do { + let container: ModelContainer + let progressHandler: @Sendable (Progress) -> Void = { progress in + Task { @MainActor in + self.downloadProgress = progress.fractionCompleted + } + } + + let configuration: ModelConfiguration + if let localDir = LocalModelResolver.resolve(repoId: config.repoId) { + configuration = ModelConfiguration(directory: localDir) + } else { + configuration = config.modelConfiguration + } + + container = try await VLMModelFactory.shared.loadContainer( + configuration: configuration, + progressHandler: progressHandler + ) + + self.modelContainer = container + self.currentModel = config + } catch { + self.errorMessage = "Failed to load model: \(error.localizedDescription)" + } + + isLoading = false + } + + /// Unload the current model and free GPU memory. + func unloadModel() { + modelContainer = nil + currentModel = nil + MLX.GPU.clearCache() + } + + /// Whether a model is ready for generation. 
+ var isReady: Bool { + modelContainer != nil && !isLoading + } +} diff --git a/MLXServer/Views/ChatInputView.swift b/MLXServer/Views/ChatInputView.swift new file mode 100644 index 0000000..51ddf4a --- /dev/null +++ b/MLXServer/Views/ChatInputView.swift @@ -0,0 +1,128 @@ +import SwiftUI +import UniformTypeIdentifiers + +struct ChatInputView: View { + @Bindable var viewModel: ChatViewModel + + var body: some View { + VStack(spacing: 8) { + // Image preview strip + if !viewModel.attachedImages.isEmpty { + ScrollView(.horizontal, showsIndicators: false) { + HStack(spacing: 8) { + ForEach(Array(viewModel.attachedImages.enumerated()), id: \.offset) { index, image in + ZStack(alignment: .topTrailing) { + Image(nsImage: image) + .resizable() + .aspectRatio(contentMode: .fill) + .frame(width: 60, height: 60) + .clipShape(RoundedRectangle(cornerRadius: 8)) + + Button { + viewModel.removeImage(at: index) + } label: { + Image(systemName: "xmark.circle.fill") + .font(.caption) + .foregroundStyle(.white) + .background(Circle().fill(.black.opacity(0.5))) + } + .buttonStyle(.plain) + .offset(x: 4, y: -4) + } + } + } + .padding(.horizontal, 12) + } + } + + // Input row + HStack(alignment: .bottom, spacing: 8) { + // Image attach button + Button { + pickImage() + } label: { + Image(systemName: "photo.badge.plus") + .font(.title3) + } + .buttonStyle(.plain) + .disabled(!viewModel.modelManager.isReady) + + // Text field + TextField("Message…", text: $viewModel.inputText, axis: .vertical) + .textFieldStyle(.plain) + .lineLimit(1...8) + .onSubmit { + if !NSEvent.modifierFlags.contains(.shift) { + viewModel.send() + } + } + + // Send or Stop button + if viewModel.isGenerating { + Button { + viewModel.stop() + } label: { + Image(systemName: "stop.circle.fill") + .font(.title2) + .foregroundStyle(.red) + } + .buttonStyle(.plain) + .keyboardShortcut(.escape, modifiers: []) + } else { + Button { + viewModel.send() + } label: { + Image(systemName: "arrow.up.circle.fill") + .font(.title2) + 
.foregroundStyle(Color.accentColor) + } + .buttonStyle(.plain) + .disabled(viewModel.inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty || !viewModel.modelManager.isReady) + .keyboardShortcut(.return, modifiers: .command) + } + } + .padding(.horizontal, 12) + .padding(.vertical, 10) + } + .padding(.top, 4) + .onDrop(of: [.image], isTargeted: nil) { providers in + for provider in providers { + _ = provider.loadObject(ofClass: NSImage.self) { image, _ in + if let image = image as? NSImage { + Task { @MainActor in + viewModel.attachImage(image) + } + } + } + } + return true + } + // Cmd+V paste for images + .onPasteCommand(of: [.image, .png, .jpeg, .tiff]) { providers in + for provider in providers { + _ = provider.loadObject(ofClass: NSImage.self) { image, _ in + if let image = image as? NSImage { + Task { @MainActor in + viewModel.attachImage(image) + } + } + } + } + } + } + + private func pickImage() { + let panel = NSOpenPanel() + panel.allowedContentTypes = [.image] + panel.allowsMultipleSelection = true + panel.canChooseDirectories = false + + if panel.runModal() == .OK { + for url in panel.urls { + if let image = NSImage(contentsOf: url) { + viewModel.attachImage(image) + } + } + } + } +} diff --git a/MLXServer/Views/ChatMessagesView.swift b/MLXServer/Views/ChatMessagesView.swift new file mode 100644 index 0000000..8f22d0e --- /dev/null +++ b/MLXServer/Views/ChatMessagesView.swift @@ -0,0 +1,105 @@ +import MarkdownUI +import SwiftUI + +struct ChatMessagesView: View { + let viewModel: ChatViewModel + + var body: some View { + ScrollViewReader { proxy in + ScrollView { + LazyVStack(alignment: .leading, spacing: 12) { + if viewModel.conversation.messages.isEmpty { + emptyState + } else { + ForEach(viewModel.conversation.messages) { message in + MessageBubbleView(message: message) + .id(message.id) + } + } + } + .padding() + } + .onChange(of: viewModel.conversation.messages.last?.content) { + scrollToBottom(proxy: proxy) + } + .onChange(of: 
viewModel.conversation.messages.count) { + scrollToBottom(proxy: proxy) + } + } + } + + private var emptyState: some View { + VStack(spacing: 8) { + Spacer() + Image(systemName: "message") + .font(.system(size: 40)) + .foregroundStyle(.secondary) + Text("Start a conversation") + .font(.title3) + .foregroundStyle(.secondary) + if viewModel.modelManager.currentModel == nil { + Text("Select a model from the toolbar to begin") + .font(.caption) + .foregroundStyle(.tertiary) + } + Spacer() + } + .frame(maxWidth: .infinity, minHeight: 300) + } + + private func scrollToBottom(proxy: ScrollViewProxy) { + if let lastId = viewModel.conversation.messages.last?.id { + withAnimation(.easeOut(duration: 0.2)) { + proxy.scrollTo(lastId, anchor: .bottom) + } + } + } +} + +struct MessageBubbleView: View { + let message: ChatMessage + + var body: some View { + HStack { + if message.role == .user { Spacer(minLength: 60) } + + VStack(alignment: message.role == .user ? .trailing : .leading, spacing: 6) { + // Show attached images + if !message.images.isEmpty { + HStack(spacing: 4) { + ForEach(Array(message.images.enumerated()), id: \.offset) { _, image in + Image(nsImage: image) + .resizable() + .aspectRatio(contentMode: .fill) + .frame(width: 80, height: 80) + .clipShape(RoundedRectangle(cornerRadius: 8)) + } + } + } + + // Message content + if !message.content.isEmpty || message.isStreaming { + Group { + if message.role == .assistant { + Markdown(message.content + (message.isStreaming ? " ●" : "")) + .textSelection(.enabled) + } else { + Text(message.content) + .textSelection(.enabled) + } + } + .padding(.horizontal, 12) + .padding(.vertical, 8) + .background( + message.role == .user + ? 
Color.accentColor.opacity(0.15) + : Color.secondary.opacity(0.1) + ) + .clipShape(RoundedRectangle(cornerRadius: 12)) + } + } + + if message.role == .assistant { Spacer(minLength: 60) } + } + } +} diff --git a/MLXServer/Views/ModelPickerView.swift b/MLXServer/Views/ModelPickerView.swift new file mode 100644 index 0000000..0c01fe0 --- /dev/null +++ b/MLXServer/Views/ModelPickerView.swift @@ -0,0 +1,32 @@ +import SwiftUI + +struct ModelPickerView: View { + @Environment(ModelManager.self) private var modelManager + + var body: some View { + HStack(spacing: 8) { + Picker("Model", selection: selectedModelBinding) { + ForEach(ModelConfig.availableModels) { config in + Label( + config.displayName, + systemImage: config.isLocal ? "checkmark.circle.fill" : "arrow.down.circle" + ).tag(config.id) + } + } + .frame(width: 160) + .disabled(modelManager.isLoading) + } + } + + private var selectedModelBinding: Binding { + Binding( + get: { modelManager.currentModel?.id ?? ModelConfig.default.id }, + set: { newId in + guard let config = ModelConfig.availableModels.first(where: { $0.id == newId }) else { return } + Task { + await modelManager.loadModel(config) + } + } + ) + } +} diff --git a/MLXServer/Views/SettingsView.swift b/MLXServer/Views/SettingsView.swift new file mode 100644 index 0000000..79ae560 --- /dev/null +++ b/MLXServer/Views/SettingsView.swift @@ -0,0 +1,44 @@ +import SwiftUI + +struct SettingsView: View { + @State private var systemPrompt: String = Preferences.systemPrompt + @State private var apiPort: String = String(Preferences.apiPort) + @State private var apiAutoStart: Bool = Preferences.apiAutoStart + + var body: some View { + Form { + Section("System Prompt") { + TextEditor(text: $systemPrompt) + .font(.body.monospaced()) + .frame(minHeight: 80) + .onChange(of: systemPrompt) { + Preferences.systemPrompt = systemPrompt + } + + Text("Applied to new conversations. 
Leave empty for no system prompt.") + .font(.caption) + .foregroundStyle(.secondary) + } + + Section("API Server") { + HStack { + Text("Port") + TextField("1234", text: $apiPort) + .frame(width: 80) + .onChange(of: apiPort) { + if let port = Int(apiPort), port > 0, port < 65536 { + Preferences.apiPort = port + } + } + } + + Toggle("Start API server automatically", isOn: $apiAutoStart) + .onChange(of: apiAutoStart) { + Preferences.apiAutoStart = apiAutoStart + } + } + } + .formStyle(.grouped) + .frame(width: 450, height: 300) + } +} diff --git a/MLXServer/Views/StatusBarView.swift b/MLXServer/Views/StatusBarView.swift new file mode 100644 index 0000000..1849442 --- /dev/null +++ b/MLXServer/Views/StatusBarView.swift @@ -0,0 +1,77 @@ +import MLX +import SwiftUI + +struct StatusBarView: View { + let viewModel: ChatViewModel + @Environment(ModelManager.self) private var modelManager + + var body: some View { + HStack(spacing: 16) { + // Model info + if modelManager.isLoading { + let pct = Int(modelManager.downloadProgress * 100) + Text("Loading \(modelManager.loadingModelName)… \(pct)%") + .font(.caption.monospacedDigit()) + .foregroundStyle(.orange) + } else if let model = modelManager.currentModel { + Label(model.displayName, systemImage: "cpu") + .font(.caption) + .foregroundStyle(.secondary) + + Text("\(model.contextLength / 1000)k ctx") + .font(.caption) + .foregroundStyle(.tertiary) + } else { + Text("No model loaded") + .font(.caption) + .foregroundStyle(.secondary) + } + + Spacer() + + // GPU memory + let activeMB = Double(MLX.GPU.activeMemory) / 1_048_576 + if activeMB > 0 { + Text(String(format: "GPU: %.0f MB", activeMB)) + .font(.caption.monospacedDigit()) + .foregroundStyle(.tertiary) + } + + // Token generation speed + if viewModel.isGenerating { + Text(String(format: "%.1f tok/s", viewModel.tokensPerSecond)) + .font(.caption.monospacedDigit()) + .foregroundStyle(.secondary) + } + + // Token counts + if viewModel.promptTokens > 0 || 
viewModel.generationTokens > 0 { + Text("\(viewModel.promptTokens)→\(viewModel.generationTokens) tok") + .font(.caption.monospacedDigit()) + .foregroundStyle(.tertiary) + } + + // API server status + if viewModel.apiServer.isRunning { + Label("API :\(viewModel.apiServer.port)", systemImage: "network") + .font(.caption) + .foregroundStyle(.green) + } else { + Label("API off", systemImage: "network.slash") + .font(.caption) + .foregroundStyle(.tertiary) + } + + // Error + if let error = modelManager.errorMessage { + Label(error, systemImage: "exclamationmark.triangle") + .font(.caption) + .foregroundStyle(.red) + .lineLimit(1) + } + } + .padding(.horizontal, 12) + .padding(.vertical, 4) + .background(.bar) + } +} diff --git a/README.md b/README.md index 4586f81..4bcbb2d 100644 --- a/README.md +++ b/README.md @@ -1,63 +1,57 @@ # MLX Server -OpenAI-compatible API server for running local LLMs on Apple Silicon via [MLX](https://github.com/ml-explore/mlx). Supports vision and tool use with automatic model swapping — only one model is loaded in memory at a time, switched on demand based on the request's `model` field. +Native macOS app for running local LLMs on Apple Silicon via [MLX](https://github.com/ml-explore/mlx). Built with SwiftUI, it provides both a **chat UI** and an embedded **OpenAI-compatible API server**. Supports vision and tool use with automatic model swapping. ## Supported Models | Alias | Model | Context | Capabilities | |-------|-------|---------|-------------| | `gemma` | `mlx-community/gemma-3-4b-it-4bit` | 128k | Vision, tool use (`tool_code` blocks) | -| `gemma3n` | `mlx-community/gemma-3n-E4B-it-4bit` | 32k | Vision/audio/video, tool use (`tool_code` blocks), ~1.5x faster | | `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | 256k | Vision, tool use (`` tags) | ## Quick Start +Requires macOS 15+, Xcode 16.4+, and `xcodegen` (`brew install xcodegen`). 
+ ```bash -source .venv/bin/activate - -# Start with Gemma 3 (default) -./run.sh - -# Start with Qwen3 -./run.sh qwen - -# Or directly -python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234 +./build.sh # Debug build +open "build/Debug/MLX Server.app" ``` -The server starts at `http://127.0.0.1:1234`. +## App Features -## API +- **Chat interface** with markdown rendering, image attachments (file picker, drag & drop, clipboard paste) +- **Model picker** in toolbar with local/download status indicators +- **Streaming responses** with live token display +- **Status bar** showing model name, context window, tokens/sec, token counts, GPU memory, API server status +- **Keyboard shortcuts**: `Cmd+N` (new chat), `Cmd+Return` (send), `Escape` (stop), `Cmd+1/2/3` (switch models) +- **Settings** (`Cmd+,`): system prompt, API port, API auto-start -Standard OpenAI-compatible endpoints: +## API Server -- `GET /v1/models` — lists all available models with `context_window` sizes +The embedded API server (toggle in toolbar) runs on port 1234 by default. Standard OpenAI-compatible endpoints: + +- `GET /v1/models` — lists available models with `context_window` sizes - `POST /v1/chat/completions` — chat completions (streaming and non-streaming) - `GET /health` — health check ### Model Swapping -Send any available model ID (or alias) in the `model` field. If it differs from the currently loaded model, the server unloads the old one and loads the new one automatically: +Send any model ID or alias in the `model` field. 
If it differs from the currently loaded model, the server swaps automatically: ```bash -# Uses Gemma curl http://localhost:1234/v1/chat/completions \ -H "Content-Type: application/json" \ - -d '{"model": "mlx-community/gemma-3-4b-it-4bit", "messages": [{"role": "user", "content": "Hello"}]}' - -# Swaps to Qwen -curl http://localhost:1234/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{"model": "mlx-community/Qwen3-VL-4B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}]}' + -d '{"model": "gemma", "messages": [{"role": "user", "content": "Hello"}]}' ``` ### Vision -Pass images as base64 data URIs or URLs in the `image_url` content part: +Pass images as base64 data URIs in the `image_url` content part: ```json { - "model": "mlx-community/gemma-3-4b-it-4bit", + "model": "gemma", "messages": [{ "role": "user", "content": [ @@ -68,37 +62,50 @@ Pass images as base64 data URIs or URLs in the `image_url` content part: } ``` -### Context Window Management - -Each model's context window is read from its HuggingFace config (`max_position_embeddings`) and reported in `/v1/models` via the `context_window` field. Clients can use this to manage conversation length proactively. - -If a request exceeds the context window, the server: - -1. Automatically summarizes older messages (keeping system messages and the last 6 messages intact) -2. Retries with the compressed conversation -3. Returns an OpenAI-compatible `context_length_exceeded` error if it still doesn't fit - ### Tool Use -Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting and parses tool calls from the output automatically. - -## Installation - -Requires Python 3.11+ and Apple Silicon. - -```bash -uv pip install -e "." -``` +Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting (Gemma `tool_code` blocks, Qwen `<tool_call>` XML tags) and parses tool calls from output automatically.
When tools are present during streaming, output is buffered to strip tool-call markup before sending to the client. ## Project Structure ``` -mlx_server/ - main.py — FastAPI server, endpoints, CLI entrypoint - engine.py — Model loading, prompt building, generation (mlx_vlm) - models.py — Pydantic models for OpenAI API types +MLXServer/ +├── MLXServerApp.swift — App entry point, GPU cache config +├── ContentView.swift — Main layout, toolbar, keyboard shortcuts +├── Models/ +│ ├── ModelConfig.swift — Model definitions, alias/repoId resolution +│ └── ChatMessage.swift — Chat message data model +├── ViewModels/ +│ ├── ModelManager.swift — Model loading/switching via VLMModelFactory +│ └── ChatViewModel.swift — Chat state, ChatSession, API server lifecycle +├── Views/ +│ ├── ModelPickerView.swift — Toolbar model selector +│ ├── ChatMessagesView.swift — Scrollable message list with markdown +│ ├── ChatInputView.swift — Text input + image attach +│ ├── StatusBarView.swift — Model info, tok/s, GPU memory, API status +│ └── SettingsView.swift — System prompt + API settings +├── Server/ +│ ├── APIServer.swift — NWListener HTTP server, SSE streaming, KV cache reuse +│ ├── APIModels.swift — OpenAI-compatible Codable structs +│ ├── ToolCallParser.swift — Parses tool calls from model output +│ └── ToolPromptBuilder.swift — Model-specific tool prompt formatting +└── Utilities/ + ├── LocalModelResolver.swift — Offline-first HuggingFace cache resolution + └── Preferences.swift — UserDefaults wrapper + +project.yml — xcodegen project spec (dependencies, settings, deployment target) +build.sh — One-command build script (xcodegen + xcodebuild) ``` +## Key Design Decisions + +- Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) for inference — supports both text and vision in a single model load +- **Offline-first**: `LocalModelResolver` checks `~/.cache/huggingface/hub/` for locally-cached snapshots before downloading +- **KV cache reuse** across API requests — reuses `ChatSession` 
when conversation history prefix matches +- HTTP server built on `Network.framework` (`NWListener`) — no third-party server dependencies +- Model-specific prompt formatting: Gemma uses `tool_code` blocks, Qwen uses `` XML tags +- GPU cache limit set to 20 MB; cache cleared on model unload + ## Design Notes - Uses `mlx_vlm` (not `mlx_lm`) as the backend — supports both text and vision in a single model load diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..aacc921 --- /dev/null +++ b/build.sh @@ -0,0 +1,41 @@ +#!/bin/bash +set -euo pipefail + +PROJECT_DIR="$(cd "$(dirname "$0")" && pwd)" +BUILD_DIR="$PROJECT_DIR/build" +CONFIG="${1:-Debug}" +APP_NAME="MLX Server" + +echo "==> Building $APP_NAME ($CONFIG)" + +# Regenerate Xcode project from project.yml (picks up any new/removed files) +if command -v xcodegen &>/dev/null; then + xcodegen generate --spec "$PROJECT_DIR/project.yml" --project "$PROJECT_DIR" 2>&1 | grep -v '^$' +fi + +# Build — filter to show only app source compilation, errors, and result +xcodebuild \ + -project "$PROJECT_DIR/MLXServer.xcodeproj" \ + -scheme MLXServer \ + -destination 'platform=macOS' \ + -configuration "$CONFIG" \ + SYMROOT="$BUILD_DIR" \ + build 2>&1 | \ + grep -E "(CompileSwift .* 'MLXServer'|error:|warning:.*MLXServer/|BUILD )" | \ + sed "s|.*CompileSwift normal arm64 Compiling ||" | \ + sed "s| (in target 'MLXServer' from project 'MLXServer')||" + +APP_PATH="$BUILD_DIR/$CONFIG/$APP_NAME.app" + +if [ -d "$APP_PATH" ] && [ -f "$APP_PATH/Contents/MacOS/$APP_NAME" ]; then + echo "" + echo "==> Build succeeded" + echo " $APP_PATH" + echo "" + echo " Run: open \"$APP_PATH\"" + echo " Or: \"$APP_PATH/Contents/MacOS/$APP_NAME\"" +else + echo "" + echo "==> Build failed" + exit 1 +fi diff --git a/mlx_server/__init__.py b/mlx_server/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mlx_server/__main__.py b/mlx_server/__main__.py deleted file mode 100644 index 46ce780..0000000 --- 
a/mlx_server/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from mlx_server.main import main - -main() diff --git a/mlx_server/engine.py b/mlx_server/engine.py deleted file mode 100644 index c164c51..0000000 --- a/mlx_server/engine.py +++ /dev/null @@ -1,1120 +0,0 @@ -"""Model loading and inference engine using mlx_vlm (supports both text and vision).""" - -from __future__ import annotations - -import base64 -import io -import json -import logging -import re -import tempfile -import threading -from collections.abc import Generator -from pathlib import Path - -import mlx.core as mx -import mlx_vlm -from PIL import Image - -logger = logging.getLogger(__name__) - -DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit" - -# Known model aliases for quick selection -MODEL_ALIASES: dict[str, str] = { - "gemma": "mlx-community/gemma-3-4b-it-4bit", - "gemma3n": "mlx-community/gemma-3n-E4B-it-4bit", - "qwen": "mlx-community/Qwen3-VL-4B-Instruct-4bit", -} - -# Fallback context lengths for models whose config doesn't expose -# max_position_embeddings (e.g. gemma3n uses a MatFormer architecture). -_CONTEXT_LENGTH_OVERRIDES: dict[str, int] = { - "gemma3n": 32768, -} - - -def _resolve_local_model_path(repo_id: str) -> Path | None: - """If a HuggingFace model is already cached locally, return its snapshot path. - - This avoids any network requests (HEAD checks) when the model files are - already present on disk — critical for offline use. 
- """ - # If it's already a local directory, just use it - local = Path(repo_id) - if local.is_dir(): - return local - - # Check the HF cache: ~/.cache/huggingface/hub/models--org--name/snapshots/ - cache_root = Path.home() / ".cache" / "huggingface" / "hub" - safe_name = "models--" + repo_id.replace("/", "--") - model_cache = cache_root / safe_name - - if not model_cache.is_dir(): - return None - - # Read the ref to find the snapshot hash - refs_dir = model_cache / "refs" - snapshot_dir = model_cache / "snapshots" - if refs_dir.is_dir() and snapshot_dir.is_dir(): - main_ref = refs_dir / "main" - if main_ref.is_file(): - commit_hash = main_ref.read_text().strip() - snap = snapshot_dir / commit_hash - if snap.is_dir(): - logger.info( - "Found locally cached model at %s — skipping online check", snap - ) - return snap - - # Fallback: use the first (most recent) snapshot if refs/main is missing - if snapshot_dir.is_dir(): - snapshots = sorted(snapshot_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True) - if snapshots: - logger.info( - "Found locally cached model at %s — skipping online check", - snapshots[0], - ) - return snapshots[0] - - return None - - -# ------------------------------------------------------------------ -# Helpers for Gemma 3 tool_code format -# ------------------------------------------------------------------ - -_JSON_TO_PYTHON_TYPE = { - "string": "str", - "integer": "int", - "number": "float", - "boolean": "bool", - "array": "list", - "object": "dict", -} - -_JSON_TYPE_DEFAULTS = { - "string": '""', - "integer": "0", - "number": "0.0", - "boolean": "False", - "array": "[]", - "object": "{}", -} - - -def _json_type_to_python(json_type: str) -> str: - return _JSON_TO_PYTHON_TYPE.get(json_type, "str") - - -def _json_type_default(json_type: str) -> str: - return _JSON_TYPE_DEFAULTS.get(json_type, "None") - - -def _python_repr(value) -> str: - """Produce a Python-repr-style string for a value.""" - if isinstance(value, str): - return 
repr(value) - if isinstance(value, bool): - return "True" if value else "False" - if isinstance(value, (int, float)): - return str(value) - return repr(value) - - -def _parse_python_call(call_str: str, tool_defs: dict[str, dict] | None = None) -> tuple[str, dict]: - """Parse a function call string into (name, args_dict). - - Handles multiple formats: - 1. Python-style: func_name(arg1="value1", arg2=42) - 2. Shell-style: func_name arg1 arg2 (common with small LLMs) - 3. Mixed: func_name("value") (positional args) - - tool_defs maps function names to their parameter schemas, used to - infer which parameter a positional/shell-style argument maps to. - """ - import ast - - call_str = call_str.strip() - - # Try Python-style: function_name(...) - m = re.match(r"(\w+)\s*\((.*)\)\s*$", call_str, re.DOTALL) - if m: - name = m.group(1) - args_str = m.group(2).strip() - - if not args_str: - return name, {} - - # Try parsing as a Python function call via dict() - try: - tree = ast.parse(f"dict({args_str})", mode="eval") - call_node = tree.body - args = {} - # Handle keyword arguments: func(key="val") - for kw in call_node.keywords: - args[kw.arg] = ast.literal_eval(kw.value) - # Handle positional arguments: func("val1", "val2") - if call_node.args and not args: - param_names = _get_param_names(name, tool_defs) - for i, arg_node in enumerate(call_node.args): - val = ast.literal_eval(arg_node) - if i < len(param_names): - args[param_names[i]] = val - else: - args[f"arg{i}"] = val - if args: - return name, args - except Exception: - pass - - # Fallback: regex-based key=value parsing - args = {} - for pair_match in re.finditer(r"(\w+)\s*=\s*(.+?)(?:,\s*(?=\w+\s*=)|$)", args_str, re.DOTALL): - key = pair_match.group(1) - val_str = pair_match.group(2).strip() - try: - args[key] = ast.literal_eval(val_str) - except Exception: - args[key] = val_str - return name, args - - # Shell-style: "func_name arg1 arg2" or "func_name some/path" - # Also handles: "func_name -flag arg" (common with 
shell tools) - parts = call_str.split(None, 1) - if parts and re.match(r"^\w+$", parts[0]): - name = parts[0] - if len(parts) == 1: - return name, {} - - rest = parts[1].strip() - param_names = _get_param_names(name, tool_defs) - first_param = param_names[0] if param_names else "input" - return name, {first_param: rest} - - # Last resort: treat the entire block as a command for the first - # known tool that looks like a shell/command tool, or just fail - raise ValueError(f"Cannot parse as function call: {call_str!r}") - - -def _get_param_names(func_name: str, tool_defs: dict[str, dict] | None) -> list[str]: - """Get ordered parameter names for a function from tool definitions.""" - if not tool_defs or func_name not in tool_defs: - return [] - params = tool_defs[func_name].get("parameters", {}) - properties = params.get("properties", {}) - required = params.get("required", []) - # Required params first, then optional - optional = [k for k in properties if k not in required] - return list(required) + optional - - -class PromptCache: - """Manages KV cache reuse across requests with shared prompt prefixes. - - Gemma 3 uses a mix of KVCache (full attention every 6th layer) and - RotatingKVCache (sliding window, 1024 tokens). Since RotatingKVCache - cannot be safely trimmed mid-sequence, we only reuse the cache when - the ENTIRE cached token sequence is a prefix of the new prompt. - - In multi-turn chat this is the common case: each new request extends - the previous prompt with the assistant response + new user message. 
- """ - - def __init__(self): - self._cache: list | None = None - self._cached_token_ids: list[int] | None = None - - def get_reusable_length(self, new_token_ids: list[int]) -> int: - """Return cached length if the entire cache is a valid prefix, else 0.""" - if self._cached_token_ids is None or self._cache is None: - return 0 - cached_len = len(self._cached_token_ids) - if cached_len > len(new_token_ids): - return 0 - for i in range(cached_len): - if self._cached_token_ids[i] != new_token_ids[i]: - return 0 - return cached_len - - def update(self, cache: list, token_ids: list[int]) -> None: - """Store cache and the token IDs it was built from.""" - self._cache = cache - self._cached_token_ids = list(token_ids) - - def clear(self) -> None: - self._cache = None - self._cached_token_ids = None - - @property - def cache(self): - return self._cache - - -class InferenceEngine: - """Manages model loading and text/vision generation.""" - - def __init__(self, model_path: str = DEFAULT_MODEL): - self.model_path = model_path - self.model = None - self.processor = None - self.config = None - self._model_type: str = "" # e.g. 
"gemma3", "qwen3" - self._lock = threading.Lock() - self._prompt_cache = PromptCache() - - def load(self) -> None: - logger.info("Loading model %s ...", self.model_path) - - # Prefer the local cache to avoid any network requests - local_path = _resolve_local_model_path(self.model_path) - load_path = str(local_path) if local_path else self.model_path - - self.model, self.processor = mlx_vlm.load(load_path) - - # Load model config for chat template - from transformers import AutoConfig - - self.config = AutoConfig.from_pretrained(load_path, trust_remote_code=True) - - # Detect model family for prompt-format branching - self._model_type = getattr(self.config, "model_type", "").lower() - logger.info("Model loaded successfully (type=%s).", self._model_type) - - def unload(self) -> None: - """Release model weights and caches to free memory.""" - logger.info("Unloading model %s ...", self.model_path) - self._prompt_cache.clear() - self.model = None - self.processor = None - self.config = None - self._model_type = "" - # Force garbage collection + clear MLX cache to reclaim memory - import gc - gc.collect() - mx.metal.clear_cache() - - @property - def is_qwen(self) -> bool: - return "qwen" in self._model_type - - @property - def is_gemma(self) -> bool: - return "gemma" in self._model_type - - @property - def context_length(self) -> int: - """Max context length from the model config.""" - if self.config is None: - return 0 - # Some architectures don't expose max_position_embeddings in config - if self._model_type in _CONTEXT_LENGTH_OVERRIDES: - return _CONTEXT_LENGTH_OVERRIDES[self._model_type] - # VLMs nest the LLM config under text_config - text_cfg = getattr(self.config, "text_config", self.config) - return getattr(text_cfg, "max_position_embeddings", 0) - - def count_tokens(self, text: str) -> int: - """Count tokens in a text string. 
Thread-safe, no lock needed.""" - tokenizer = self._get_tokenizer() - return len(tokenizer.encode(text)) - - def summarize_messages(self, messages: list[dict]) -> str: - """Summarize a list of conversation messages into a concise text. - - Calls generate() internally (acquires and releases the lock). - """ - # Build a readable transcript from the messages - transcript_lines = [] - for msg in messages: - role = msg.get("role", "unknown") - content = self._get_text_content(msg.get("content")) - # Include tool call info if present - if msg.get("tool_calls"): - tool_names = [ - tc.get("function", tc).get("name", "?") - for tc in msg["tool_calls"] - ] - content += f" [called tools: {', '.join(tool_names)}]" - if content.strip(): - transcript_lines.append(f"{role}: {content.strip()}") - - transcript = "\n".join(transcript_lines) - - summary_instruction = [{ - "role": "user", - "content": ( - "Summarize the following conversation concisely. " - "Preserve key facts, decisions, tool results, and context " - "needed to continue the conversation naturally. " - "Be brief but complete.\n\n" - f"\n{transcript}\n" - ), - }] - - prompt, _ = self.build_prompt(summary_instruction, tools=None) - summary_text, _, _ = self.generate( - prompt=prompt, - images=None, - max_tokens=1024, - temperature=0.2, - ) - return summary_text.strip() - - # ------------------------------------------------------------------ - # Image helpers - # ------------------------------------------------------------------ - - @staticmethod - def _decode_image_url(url: str) -> str: - """Convert a data URI or URL to a file path that mlx_vlm can consume.""" - if url.startswith("data:"): - # data:image/png;base64,iVBOR... 
- header, b64data = url.split(",", 1) - img_bytes = base64.b64decode(b64data) - img = Image.open(io.BytesIO(img_bytes)) - tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - img.save(tmp, format="PNG") - tmp.close() - return tmp.name - # Assume it's a URL or local path – mlx_vlm handles URLs natively - return url - - # ------------------------------------------------------------------ - # Prompt building - # ------------------------------------------------------------------ - - def build_prompt( - self, - messages: list[dict], - tools: list[dict] | None = None, - ) -> tuple[str, list[str]]: - """Build a prompt string and collect image paths from messages. - - Returns (prompt_str, image_paths). - """ - if self.is_qwen: - return self._build_prompt_qwen(messages, tools) - return self._build_prompt_gemma(messages, tools) - - def _build_prompt_gemma( - self, - messages: list[dict], - tools: list[dict] | None = None, - ) -> tuple[str, list[str]]: - """Gemma 3 prompt builder (tool_code format, no system role).""" - image_paths: list[str] = [] - formatted_messages: list[dict] = [] - - for msg in messages: - role = msg["role"] - content = msg.get("content") - tool_calls = msg.get("tool_calls") - tool_call_id = msg.get("tool_call_id") - - if role == "system": - text = self._get_text_content(content) - # Inject tool definitions into system prompt - if tools: - text = self._inject_tools_into_system(text, tools) - formatted_messages.append({"role": "user", "content": text}) - # Gemma 3 doesn't have a system role; we use the user role - # and add a model acknowledgment - formatted_messages.append({ - "role": "assistant", - "content": "Understood. 
I will follow these instructions.", - }) - elif role == "user": - text, imgs = self._extract_content_parts(content) - image_paths.extend(imgs) - formatted_messages.append({"role": "user", "content": text}) - elif role == "assistant": - text = self._get_text_content(content) or "" - if tool_calls: - # Format tool calls in the way Gemma 3 expects - tc_text = self._format_tool_calls_for_prompt(tool_calls) - text = (text + "\n" + tc_text).strip() - formatted_messages.append({"role": "assistant", "content": text}) - elif role == "tool": - # Tool results use Gemma 3's tool_output format - tool_text = self._get_text_content(content) or "" - result_msg = f"```tool_output\n{tool_text}\n```" - formatted_messages.append({"role": "user", "content": result_msg}) - - # If the first system prompt had no tools but we have tools, inject at start - if tools and not any(m.get("role") == "system" for m in messages): - tool_system = self._build_tool_system_prompt(tools) - formatted_messages.insert(0, {"role": "user", "content": tool_system}) - formatted_messages.insert(1, { - "role": "assistant", - "content": "Understood. I will follow these instructions and use tools when appropriate.", - }) - - # Gemma 3 requires strictly alternating user/assistant turns. - # Merge consecutive same-role messages and ensure it starts with user. 
- formatted_messages = self._merge_consecutive_roles(formatted_messages) - - # Apply chat template via mlx_vlm - prompt = mlx_vlm.apply_chat_template( - self.processor, - self.config, - formatted_messages, - add_generation_prompt=True, - num_images=len(image_paths), - ) - - return prompt, image_paths - - def _build_prompt_qwen( - self, - messages: list[dict], - tools: list[dict] | None = None, - ) -> tuple[str, list[str]]: - """Qwen3 prompt builder (native system role, JSON tool calls).""" - image_paths: list[str] = [] - formatted_messages: list[dict] = [] - - # Qwen3 supports system role natively — inject tools there - has_system = any(m.get("role") == "system" for m in messages) - if tools and not has_system: - formatted_messages.append({ - "role": "system", - "content": self._build_qwen_tool_system_prompt(tools), - }) - - for msg in messages: - role = msg["role"] - content = msg.get("content") - tool_calls = msg.get("tool_calls") - - if role == "system": - text = self._get_text_content(content) - if tools: - text = text + "\n\n" + self._build_qwen_tool_system_prompt(tools) - formatted_messages.append({"role": "system", "content": text}) - elif role == "user": - text, imgs = self._extract_content_parts(content) - image_paths.extend(imgs) - formatted_messages.append({"role": "user", "content": text}) - elif role == "assistant": - text = self._get_text_content(content) or "" - if tool_calls: - tc_text = self._format_qwen_tool_calls_for_prompt(tool_calls) - text = (text + "\n" + tc_text).strip() - formatted_messages.append({"role": "assistant", "content": text}) - elif role == "tool": - tool_text = self._get_text_content(content) or "" - formatted_messages.append({"role": "user", "content": tool_text}) - - # Apply chat template via mlx_vlm - prompt = mlx_vlm.apply_chat_template( - self.processor, - self.config, - formatted_messages, - add_generation_prompt=True, - num_images=len(image_paths), - ) - - return prompt, image_paths - - @staticmethod - def 
_merge_consecutive_roles(messages: list[dict]) -> list[dict]: - """Merge consecutive messages with the same role into one. - - Gemma 3's chat template enforces strict user/assistant alternation. - """ - if not messages: - return messages - - merged = [messages[0].copy()] - for msg in messages[1:]: - if msg["role"] == merged[-1]["role"]: - # Merge content with the previous message - merged[-1]["content"] = ( - merged[-1].get("content", "") + "\n\n" + msg.get("content", "") - ) - else: - merged.append(msg.copy()) - - # Ensure conversation starts with user - if merged and merged[0]["role"] != "user": - merged.insert(0, {"role": "user", "content": ""}) - - return merged - - def _get_text_content(self, content) -> str: - if content is None: - return "" - if isinstance(content, str): - return content - # list of content parts - parts = [] - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - parts.append(part["text"]) - return "\n".join(parts) - - def _extract_content_parts(self, content) -> tuple[str, list[str]]: - """Extract text and image paths from content parts.""" - if isinstance(content, str): - return content, [] - if content is None: - return "", [] - - texts = [] - images = [] - for part in content: - if isinstance(part, dict): - if part.get("type") == "text": - texts.append(part["text"]) - elif part.get("type") == "image_url": - url = part["image_url"]["url"] - images.append(self._decode_image_url(url)) - return "\n".join(texts), images - - def _inject_tools_into_system(self, system_text: str, tools: list[dict]) -> str: - tool_block = self._build_tool_system_prompt(tools) - return f"{system_text}\n\n{tool_block}" - - def _build_tool_system_prompt(self, tools: list[dict]) -> str: - """Build the tool system prompt using Google's official Gemma 3 format. 
- - Uses the tool_code/tool_output convention recommended by Google: - - Tools defined as Python function signatures with docstrings - - Model outputs calls in ```tool_code``` fenced blocks - - Results returned in ```tool_output``` fenced blocks - """ - func_defs = [] - for tool in tools: - func = tool.get("function", tool) - func_defs.append(self._tool_to_python_signature(func)) - - functions_block = "\n\n".join(func_defs) - - return ( - "At each turn, if you decide to invoke any of the function(s), " - "it should be wrapped with ```tool_code```. " - "The python methods described below are imported and available, " - "you can only use defined methods. " - "The generated code should be readable and efficient. " - "The response to a method will be wrapped in ```tool_output``` " - "use it to call more tools or generate a helpful, friendly response.\n" - "\n" - f"{functions_block}" - ) - - @staticmethod - def _tool_to_python_signature(func: dict) -> str: - """Convert an OpenAI function definition to a Python function signature with docstring.""" - name = func["name"] - desc = func.get("description", "") - params = func.get("parameters", {}) - properties = params.get("properties", {}) - required = set(params.get("required", [])) - - # Build parameter list - param_parts = [] - doc_args = [] - for pname, pinfo in properties.items(): - ptype = _json_type_to_python(pinfo.get("type", "str")) - pdesc = pinfo.get("description", "") - if pname in required: - param_parts.append(f"{pname}: {ptype}") - else: - default = _json_type_default(pinfo.get("type", "str")) - param_parts.append(f"{pname}: {ptype} = {default}") - doc_args.append(f" {pname}: {pdesc}" if pdesc else f" {pname}") - - sig = f"def {name}({', '.join(param_parts)}):" - doc_lines = [f' """{desc}'] - if doc_args: - doc_lines.append("") - doc_lines.append(" Args:") - doc_lines.extend(doc_args) - doc_lines.append(' """') - - return sig + "\n" + "\n".join(doc_lines) - - def _format_tool_calls_for_prompt(self, 
tool_calls: list[dict]) -> str: - """Format OpenAI-style tool calls back into Gemma 3 tool_code blocks.""" - parts = [] - for tc in tool_calls: - func = tc.get("function", tc) - name = func["name"] - args = func.get("arguments", "{}") - if isinstance(args, str): - args = json.loads(args) - # Format as Python function call - arg_parts = [f"{k}={_python_repr(v)}" for k, v in args.items()] - call_str = f"{name}({', '.join(arg_parts)})" - parts.append(f"```tool_code\n{call_str}\n```") - return "\n".join(parts) - - # ------------------------------------------------------------------ - # Qwen3 tool helpers - # ------------------------------------------------------------------ - - @staticmethod - def _build_qwen_tool_system_prompt(tools: list[dict]) -> str: - """Build the tool system prompt for Qwen3 using its native JSON format.""" - tool_descs = [] - for tool in tools: - func = tool.get("function", tool) - tool_descs.append({ - "type": "function", - "function": { - "name": func["name"], - "description": func.get("description", ""), - "parameters": func.get("parameters", {}), - }, - }) - tools_json = json.dumps(tool_descs, indent=2) - return ( - "# Tools\n\n" - "You are a helpful assistant with access to the following tools. 
" - "Use them when appropriate by responding with a JSON tool call.\n\n" - "## Available Tools\n\n" - f"{tools_json}\n\n" - "## Tool Call Format\n\n" - "When you need to call a tool, respond with:\n" - '\n{"name": "", "arguments": {}}\n' - ) - - @staticmethod - def _format_qwen_tool_calls_for_prompt(tool_calls: list[dict]) -> str: - """Format OpenAI-style tool calls back into Qwen3's XML tag format.""" - parts = [] - for tc in tool_calls: - func = tc.get("function", tc) - name = func["name"] - args = func.get("arguments", "{}") - if isinstance(args, str): - args = json.loads(args) - call_obj = {"name": name, "arguments": args} - parts.append(f"\n{json.dumps(call_obj)}\n") - return "\n".join(parts) - - # ------------------------------------------------------------------ - # Prefix cache & generation - # ------------------------------------------------------------------ - - # Common kwargs for mlx_vlm generate calls - # Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache - _GENERATE_KWARGS: dict = {} - - # Keys in the prep dict that are internal bookkeeping, not kwargs for - # mlx_vlm.stream_generate. - _PREP_INTERNAL_KEYS = frozenset({ - "input_ids", "pixel_values", "mask", "prompt_cache", - "_full_token_ids", "_prompt_token_count", - }) - - def _extra_generate_kwargs( - self, images: list[str] | None, prep: dict | None = None, - ) -> dict: - """Build per-request kwargs for mlx_vlm.stream_generate. - - Includes model-specific keys from prepare_inputs (e.g. image_grid_thw - for Qwen3-VL) and works around a chunked-prefill bug where - visual_pos_masks is None for text-only requests. 
- """ - extra: dict = dict(self._GENERATE_KWARGS) - if self.is_qwen and not images: - extra["prefill_step_size"] = None - # Forward any model-specific keys that prepare_inputs returned - if prep: - for k, v in prep.items(): - if k not in self._PREP_INTERNAL_KEYS: - extra[k] = v - return extra - - def _get_tokenizer(self): - """Get the underlying tokenizer from the processor.""" - proc = self.processor - return proc.tokenizer if hasattr(proc, "tokenizer") else proc - - def _prepare_generation( - self, prompt: str, images: list[str] | None = None - ) -> dict: - """Tokenize prompt, check prefix cache, return generation kwargs. - - Returns a dict with keys: - input_ids, pixel_values, mask, prompt_cache, - _full_token_ids, _prompt_token_count - """ - from mlx_vlm.models import cache as cache_module - from mlx_vlm.utils import prepare_inputs - - model_type = getattr(self.config, "model_type", "") - add_special_tokens = ( - not hasattr(self.processor, "chat_template") - if model_type in ("gemma3", "gemma3n") - else True - ) - image_token_index = getattr(self.model.config, "image_token_index", None) - - # Tokenize the full prompt (+ process pixel values if images present) - inputs = prepare_inputs( - self.processor, - images=images if images else None, - prompts=prompt, - image_token_index=image_token_index, - add_special_tokens=add_special_tokens, - ) - full_input_ids = inputs["input_ids"] - pixel_values = inputs.get("pixel_values") - mask = inputs.get("attention_mask") - - # Collect any model-specific extra keys from prepare_inputs - # (e.g. image_grid_thw for Qwen3-VL) so they reach the model. 
- _KNOWN_KEYS = {"input_ids", "pixel_values", "attention_mask"} - extra_inputs = {k: v for k, v in inputs.items() if k not in _KNOWN_KEYS} - - full_token_list = full_input_ids.flatten().tolist() - prefix_len = self._prompt_cache.get_reusable_length(full_token_list) - - if prefix_len > 0: - suffix_token_list = full_token_list[prefix_len:] - - # If the suffix contains image placeholder tokens, we can't skip - # the vision encoder — fall back to full processing. - if ( - image_token_index is not None - and image_token_index in suffix_token_list - ): - logger.info( - "New images in suffix — prefix cache invalidated" - ) - prefix_len = 0 - - if prefix_len > 0: - suffix_ids = mx.array([suffix_token_list]) - logger.info( - "Prefix cache hit: reusing %d/%d tokens (%.1f%%), " - "processing %d new tokens", - prefix_len, - len(full_token_list), - 100 * prefix_len / len(full_token_list), - len(suffix_token_list), - ) - return { - "input_ids": suffix_ids, - "pixel_values": None, # images already in cached KV - "mask": None, - "prompt_cache": self._prompt_cache.cache, - "_full_token_ids": full_token_list, - "_prompt_token_count": len(full_token_list), - } - - # Cache miss — create a fresh KV cache - # VLM models expose .language_model; text-only models are the LM directly - lm = getattr(self.model, "language_model", self.model) - cache = cache_module.make_prompt_cache(lm) - logger.info( - "Prefix cache miss: processing %d tokens from scratch", - len(full_token_list), - ) - return { - "input_ids": full_input_ids, - "pixel_values": pixel_values, - "mask": mask, - "prompt_cache": cache, - "_full_token_ids": full_token_list, - "_prompt_token_count": len(full_token_list), - **extra_inputs, - } - - def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None: - """Persist the KV cache and token IDs after generation.""" - full_sequence = prep["_full_token_ids"] + generated_tokens - self._prompt_cache.update(prep["prompt_cache"], full_sequence) - - def generate( - self, - 
prompt: str, - images: list[str] | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - top_p: float = 0.9, - stop: list[str] | None = None, - repetition_penalty: float = 1.1, - ) -> tuple[str, int, int]: - """Generate a complete response. Returns (text, prompt_tokens, completion_tokens).""" - with self._lock: - prep = self._prepare_generation(prompt, images) - prompt_token_count = prep["_prompt_token_count"] - - # Ensure stopping criteria is initialised (Gemma-specific; optional for others) - tokenizer = self._get_tokenizer() - if hasattr(tokenizer, "stopping_criteria"): - tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) - - text = "" - generated_tokens: list[int] = [] - gen_tokens = 0 - - for result in mlx_vlm.stream_generate( - self.model, - self.processor, - prompt, - input_ids=prep["input_ids"], - pixel_values=prep.get("pixel_values"), - mask=prep.get("mask"), - prompt_cache=prep["prompt_cache"], - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - **self._extra_generate_kwargs(images, prep), - ): - text += result.text - if result.token is not None: - generated_tokens.append(result.token) - gen_tokens = result.generation_tokens - - self._save_cache(prep, generated_tokens) - - if stop: - text = self._apply_stop(text, stop) - return text, prompt_token_count, gen_tokens - - def stream_generate( - self, - prompt: str, - images: list[str] | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - top_p: float = 0.9, - stop: list[str] | None = None, - repetition_penalty: float = 1.1, - ) -> Generator[tuple[str, bool, int, int], None, None]: - """Stream tokens. 
Yields (token_text, is_final, prompt_tokens, gen_tokens).""" - with self._lock: - prep = self._prepare_generation(prompt, images) - prompt_token_count = prep["_prompt_token_count"] - - # Ensure stopping criteria is initialised (Gemma-specific; optional for others) - tokenizer = self._get_tokenizer() - if hasattr(tokenizer, "stopping_criteria"): - tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) - - accumulated = "" - generated_tokens: list[int] = [] - gen_tokens = 0 - - try: - for result in mlx_vlm.stream_generate( - self.model, - self.processor, - prompt, - input_ids=prep["input_ids"], - pixel_values=prep.get("pixel_values"), - mask=prep.get("mask"), - prompt_cache=prep["prompt_cache"], - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - **self._extra_generate_kwargs(images, prep), - ): - token_text = result.text - accumulated += token_text - if result.token is not None: - generated_tokens.append(result.token) - gen_tokens = result.generation_tokens - - if stop and self._check_stop(accumulated, stop): - trimmed = self._apply_stop(accumulated, stop) - safe_delta = trimmed[ - len(accumulated) - len(token_text) : - ] - yield safe_delta, True, prompt_token_count, gen_tokens - return - - yield token_text, False, prompt_token_count, gen_tokens - - # Final yield to signal completion - yield "", True, prompt_token_count, gen_tokens - finally: - self._save_cache(prep, generated_tokens) - - @staticmethod - def _apply_stop(text: str, stop: list[str]) -> str: - for s in stop: - idx = text.find(s) - if idx != -1: - text = text[:idx] - return text - - @staticmethod - def _check_stop(text: str, stop: list[str]) -> bool: - return any(s in text for s in stop) - - # ------------------------------------------------------------------ - # Tool call parsing from model output - # ------------------------------------------------------------------ - - def parse_tool_calls( - self, text: str, tools: list[dict] | 
None = None - ) -> tuple[str, list[dict]]: - """Parse tool calls from model output. - - Supports both Gemma 3's ```tool_code``` blocks and Qwen3's - XML tags. - - Returns (clean_text, tool_calls) where tool_calls is a list of - {"id": str, "type": "function", "function": {"name": str, "arguments": str}}. - """ - if self.is_qwen: - return self._parse_tool_calls_qwen(text) - return self._parse_tool_calls_gemma(text, tools) - - @staticmethod - def _parse_tool_calls_gemma( - text: str, tools: list[dict] | None = None - ) -> tuple[str, list[dict]]: - """Parse Gemma 3 tool_code blocks.""" - tool_defs: dict[str, dict] = {} - if tools: - for tool in tools: - func = tool.get("function", tool) - tool_defs[func["name"]] = func - - tool_calls = [] - pattern = r"```tool_code\s*(.*?)\s*```" - matches = re.findall(pattern, text, re.DOTALL) - - clean_text = re.sub(r"```tool_code\s*.*?\s*```", "", text, flags=re.DOTALL).strip() - - for i, match in enumerate(matches): - call_str = match.strip() - try: - name, args = _parse_python_call(call_str, tool_defs) - tool_calls.append({ - "id": f"call_{i}_{hash(call_str) % 10**8:08d}", - "type": "function", - "function": { - "name": name, - "arguments": json.dumps(args), - }, - }) - except Exception as e: - logger.warning("Failed to parse tool_code call %r: %s", call_str, e) - - return clean_text, tool_calls - - @staticmethod - def _parse_tool_calls_qwen(text: str) -> tuple[str, list[dict]]: - """Parse Qwen3 XML tags.""" - tool_calls = [] - pattern = r"\s*(.*?)\s*" - matches = re.findall(pattern, text, re.DOTALL) - - clean_text = re.sub(r"\s*.*?\s*", "", text, flags=re.DOTALL).strip() - - for i, match in enumerate(matches): - try: - call_obj = json.loads(match.strip()) - name = call_obj.get("name", "") - args = call_obj.get("arguments", {}) - if isinstance(args, str): - args = json.loads(args) - tool_calls.append({ - "id": f"call_{i}_{hash(match) % 10**8:08d}", - "type": "function", - "function": { - "name": name, - "arguments": 
json.dumps(args), - }, - }) - except Exception as e: - logger.warning("Failed to parse tool_call tag %r: %s", match, e) - - return clean_text, tool_calls - - -class ModelManager: - """Registry of available models with on-demand loading and swapping. - - Only one model is loaded in memory at a time. When a request targets a - different model, the current one is unloaded first. - """ - - def __init__(self, default_model: str = DEFAULT_MODEL): - self._lock = threading.Lock() - self._engine: InferenceEngine | None = None - self._current_model: str | None = None - self._default_model = default_model - - @property - def available_models(self) -> list[str]: - """All model IDs that clients can request.""" - return list(MODEL_ALIASES.values()) - - @property - def available_aliases(self) -> dict[str, str]: - """Short alias -> full HuggingFace model path.""" - return dict(MODEL_ALIASES) - - def resolve_model(self, requested: str) -> str: - """Resolve a model string to a full HuggingFace model path. - - Accepts aliases ('gemma', 'qwen') or full paths. - """ - if requested in MODEL_ALIASES: - return MODEL_ALIASES[requested] - if requested in MODEL_ALIASES.values(): - return requested - # Accept partial matches (e.g. 
'gemma-3-4b-it' matches the gemma entry) - for alias, full_path in MODEL_ALIASES.items(): - if requested in full_path or requested in alias: - return full_path - # Unknown model — return as-is and let loading fail if invalid - return requested - - def get_engine(self, requested_model: str | None = None) -> InferenceEngine: - """Return an engine for the requested model, swapping if necessary.""" - target = self.resolve_model(requested_model) if requested_model else self._default_model - - with self._lock: - if self._engine is not None and self._current_model == target: - return self._engine - - # Need to swap - if self._engine is not None: - logger.info( - "Swapping model: %s -> %s", self._current_model, target - ) - self._engine.unload() - self._engine = None - self._current_model = None - - engine = InferenceEngine(model_path=target) - engine.load() - self._engine = engine - self._current_model = target - return self._engine - - def get_context_length(self, model_id: str) -> int | None: - """Get context length for a model from its cached config, without loading it.""" - local_path = _resolve_local_model_path(model_id) - if local_path is None: - return None - config_file = local_path / "config.json" - if not config_file.is_file(): - return None - try: - config = json.loads(config_file.read_text()) - model_type = config.get("model_type", "") - # Check override table for models that don't expose it in config - if model_type in _CONTEXT_LENGTH_OVERRIDES: - return _CONTEXT_LENGTH_OVERRIDES[model_type] - # VLMs nest under text_config - text_cfg = config.get("text_config", config) - return text_cfg.get("max_position_embeddings") - except Exception: - return None - - def preload(self, model: str | None = None) -> None: - """Eagerly load a model at startup.""" - self.get_engine(model) diff --git a/mlx_server/main.py b/mlx_server/main.py deleted file mode 100644 index da5a8f1..0000000 --- a/mlx_server/main.py +++ /dev/null @@ -1,409 +0,0 @@ -"""OpenAI-compatible API server 
for local LLMs (Gemma 3, Qwen3, …) via MLX.""" - -from __future__ import annotations - -import argparse -import json -import logging -import time -import uuid - -import uvicorn -from fastapi import FastAPI, HTTPException -from fastapi.middleware.cors import CORSMiddleware -from sse_starlette.sse import EventSourceResponse - -from .engine import DEFAULT_MODEL, InferenceEngine, ModelManager -from .models import ( - ChatCompletionChunk, - ChatCompletionRequest, - ChatCompletionResponse, - Choice, - ChoiceMessage, - DeltaMessage, - ModelInfo, - ModelListResponse, - StreamChoice, - ToolCall, - FunctionCall, - UsageInfo, -) - -logger = logging.getLogger(__name__) - -app = FastAPI(title="MLX Server", description="OpenAI-compatible API for local LLMs on Apple Silicon") - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -manager: ModelManager | None = None - -# Number of recent messages to always preserve when summarizing -_KEEP_RECENT = 6 - - -def get_engine(requested_model: str | None = None): - if manager is None: - raise HTTPException(status_code=503, detail="Server not initialized") - return manager.get_engine(requested_model) - - -def _make_id() -> str: - return f"chatcmpl-{uuid.uuid4().hex[:12]}" - - -# ------------------------------------------------------------------ -# Context window management -# ------------------------------------------------------------------ - - -def _manage_context( - e: InferenceEngine, - messages: list[dict], - tools: list[dict] | None, - max_tokens: int, -) -> list[dict]: - """Check if messages fit in the context window; summarize if needed. - - Returns the (possibly summarized) message list. Raises HTTPException - with an OpenAI-compatible error if the conversation cannot fit. 
- """ - context_length = e.context_length - if context_length <= 0: - return messages # unknown context size, skip check - - prompt, _ = e.build_prompt(messages, tools) - prompt_tokens = e.count_tokens(prompt) - available = context_length - max_tokens - - if prompt_tokens <= available: - return messages - - # --- Need to summarize --- - logger.info( - "Context window pressure: %d prompt tokens + %d max_tokens = %d " - "(limit %d). Attempting summarization.", - prompt_tokens, max_tokens, prompt_tokens + max_tokens, context_length, - ) - - # Split messages: system | middle (summarizable) | recent (kept) - system_msgs = [m for m in messages if m.get("role") == "system"] - non_system = [m for m in messages if m.get("role") != "system"] - - if len(non_system) <= _KEEP_RECENT: - _raise_context_exceeded(prompt_tokens, max_tokens, context_length) - - recent = non_system[-_KEEP_RECENT:] - middle = non_system[:-_KEEP_RECENT] - - # Generate summary of the middle messages - summary_text = e.summarize_messages(middle) - - summary_msg = { - "role": "user", - "content": f"[Summary of earlier conversation]\n{summary_text}", - } - ack_msg = { - "role": "assistant", - "content": "Understood, I have the context from our earlier conversation.", - } - new_messages = system_msgs + [summary_msg, ack_msg] + recent - - # Re-check fit - new_prompt, _ = e.build_prompt(new_messages, tools) - new_prompt_tokens = e.count_tokens(new_prompt) - - if new_prompt_tokens + max_tokens > context_length: - logger.warning( - "Still over context limit after summarization: %d + %d = %d (limit %d)", - new_prompt_tokens, max_tokens, new_prompt_tokens + max_tokens, context_length, - ) - _raise_context_exceeded(new_prompt_tokens, max_tokens, context_length) - - logger.info( - "Summarization reduced prompt from %d to %d tokens (saved %d).", - prompt_tokens, new_prompt_tokens, prompt_tokens - new_prompt_tokens, - ) - return new_messages - - -def _raise_context_exceeded(prompt_tokens: int, max_tokens: int, 
context_length: int): - """Raise an OpenAI-compatible context_length_exceeded error.""" - raise HTTPException( - status_code=400, - detail={ - "error": { - "message": ( - f"This model's maximum context length is {context_length} tokens. " - f"However, your messages resulted in {prompt_tokens} tokens and " - f"{max_tokens} tokens were requested for the completion " - f"({prompt_tokens + max_tokens} total). " - f"Please reduce the length of the messages or completion." - ), - "type": "invalid_request_error", - "code": "context_length_exceeded", - } - }, - ) - - -# ------------------------------------------------------------------ -# Endpoints -# ------------------------------------------------------------------ - - -@app.get("/v1/models") -async def list_models() -> ModelListResponse: - if manager is None: - raise HTTPException(status_code=503, detail="Server not initialized") - return ModelListResponse( - data=[ - ModelInfo( - id=model_id, - context_window=manager.get_context_length(model_id), - ) - for model_id in manager.available_models - ] - ) - - -@app.post("/v1/chat/completions") -async def chat_completions(request: ChatCompletionRequest): - e = get_engine(request.model) - - # Convert pydantic messages to dicts - messages = [m.model_dump(exclude_none=True) for m in request.messages] - tools = None - if request.tools: - tools = [t.model_dump(exclude_none=True) for t in request.tools] - - stop = request.stop - if isinstance(stop, str): - stop = [stop] - - temperature = request.temperature if request.temperature is not None else 0.7 - top_p = request.top_p if request.top_p is not None else 0.9 - max_tokens = request.max_tokens if request.max_tokens is not None else 4096 - - # Context window management: summarize if needed, error if impossible - messages = _manage_context(e, messages, tools, max_tokens) - - prompt, images = e.build_prompt(messages, tools) - - if request.stream: - return EventSourceResponse( - _stream_response(e, prompt, images, max_tokens, 
temperature, top_p, stop, tools, request.model), - media_type="text/event-stream", - ) - - # Non-streaming - text, prompt_tokens, completion_tokens = e.generate( - prompt=prompt, - images=images or None, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - ) - - # Check for tool calls in the response - finish_reason = "stop" - tool_calls_parsed = None - if tools: - clean_text, parsed = e.parse_tool_calls(text, tools) - if parsed: - tool_calls_parsed = [ - ToolCall( - index=i, - id=tc["id"], - type="function", - function=FunctionCall( - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ), - ) - for i, tc in enumerate(parsed) - ] - text = clean_text if clean_text else None - finish_reason = "tool_calls" - - return ChatCompletionResponse( - id=_make_id(), - model=request.model, - choices=[ - Choice( - message=ChoiceMessage( - role="assistant", - content=text if not tool_calls_parsed else (text or None), - tool_calls=tool_calls_parsed, - ), - finish_reason=finish_reason, - ) - ], - usage=UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - - -async def _stream_response( - e, - prompt: str, - images: list[str] | None, - max_tokens: int, - temperature: float, - top_p: float, - stop: list[str] | None, - tools: list[dict] | None, - model_name: str, -): - request_id = _make_id() - created = int(time.time()) - - # Send initial chunk with role - initial_chunk = ChatCompletionChunk( - id=request_id, - created=created, - model=model_name, - choices=[StreamChoice(delta=DeltaMessage(role="assistant"))], - ) - yield {"data": initial_chunk.model_dump_json()} - - full_text = "" - prompt_tokens = 0 - gen_tokens = 0 - - # When tools are available we must buffer the full response before - # emitting content — otherwise raw tool-call markup (```tool_code``` - # or ) leaks into the streamed text. 
- buffer_for_tools = bool(tools) - - for token_text, is_final, pt, gt in e.stream_generate( - prompt=prompt, - images=images or None, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - ): - prompt_tokens = pt - gen_tokens = gt - full_text += token_text - - if not buffer_for_tools and not is_final and token_text: - chunk = ChatCompletionChunk( - id=request_id, - created=created, - model=model_name, - choices=[StreamChoice(delta=DeltaMessage(content=token_text))], - ) - yield {"data": chunk.model_dump_json()} - - # --- Post-generation: parse tool calls and emit clean content ------ - finish_reason = "stop" - tool_calls_parsed = [] - - if tools: - clean_text, parsed = e.parse_tool_calls(full_text, tools) - if parsed: - finish_reason = "tool_calls" - tool_calls_parsed = parsed - full_text = clean_text or "" - - # Emit buffered content (when tools were present, this is the cleaned - # text with tool-call markup stripped out) - if buffer_for_tools and full_text.strip(): - content_chunk = ChatCompletionChunk( - id=request_id, - created=created, - model=model_name, - choices=[StreamChoice(delta=DeltaMessage(content=full_text))], - ) - yield {"data": content_chunk.model_dump_json()} - - # Emit tool call chunks - for i, tc in enumerate(tool_calls_parsed): - tc_chunk = ChatCompletionChunk( - id=request_id, - created=created, - model=model_name, - choices=[ - StreamChoice( - delta=DeltaMessage( - tool_calls=[ - ToolCall( - index=i, - id=tc["id"], - type="function", - function=FunctionCall( - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ), - ) - ] - ) - ) - ], - ) - yield {"data": tc_chunk.model_dump_json()} - - # Final chunk with finish reason and usage - final_chunk = ChatCompletionChunk( - id=request_id, - created=created, - model=model_name, - choices=[StreamChoice(delta=DeltaMessage(), finish_reason=finish_reason)], - usage=UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=gen_tokens, - 
total_tokens=prompt_tokens + gen_tokens, - ), - ) - yield {"data": final_chunk.model_dump_json()} - yield {"data": "[DONE]"} - - -# ------------------------------------------------------------------ -# Health / utility -# ------------------------------------------------------------------ - - -@app.get("/health") -async def health(): - return {"status": "ok"} - - -# ------------------------------------------------------------------ -# Entrypoint -# ------------------------------------------------------------------ - - -def main(): - parser = argparse.ArgumentParser(description="MLX Server – OpenAI-compatible API") - parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="HuggingFace model path") - parser.add_argument("--host", type=str, default="127.0.0.1") - parser.add_argument("--port", type=int, default=1234) - parser.add_argument("--log-level", type=str, default="info") - args = parser.parse_args() - - logging.basicConfig( - level=getattr(logging, args.log_level.upper()), - format="%(asctime)s %(levelname)s %(name)s: %(message)s", - ) - - global manager - manager = ModelManager(default_model=args.model) - manager.preload(args.model) - - uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level) - - -if __name__ == "__main__": - main() diff --git a/mlx_server/models.py b/mlx_server/models.py deleted file mode 100644 index 781b31b..0000000 --- a/mlx_server/models.py +++ /dev/null @@ -1,145 +0,0 @@ -"""OpenAI API compatible request/response models.""" - -from __future__ import annotations - -import time -from typing import Any, Literal - -from pydantic import BaseModel, Field - - -# --- Request models --- - - -class FunctionDefinition(BaseModel): - name: str - description: str | None = None - parameters: dict[str, Any] | None = None - - -class ToolDefinition(BaseModel): - type: Literal["function"] = "function" - function: FunctionDefinition - - -class FunctionCall(BaseModel): - name: str - arguments: str # JSON string - - -class 
ToolCall(BaseModel): - index: int = 0 - id: str - type: Literal["function"] = "function" - function: FunctionCall - - -class ContentPartText(BaseModel): - type: Literal["text"] = "text" - text: str - - -class ImageURL(BaseModel): - url: str # Can be a URL or base64 data URI - detail: str | None = None - - -class ContentPartImage(BaseModel): - type: Literal["image_url"] = "image_url" - image_url: ImageURL - - -ContentPart = ContentPartText | ContentPartImage - - -class ChatMessage(BaseModel): - role: Literal["system", "user", "assistant", "tool"] - content: str | list[ContentPart] | None = None - name: str | None = None - tool_calls: list[ToolCall] | None = None - tool_call_id: str | None = None - - -class ChatCompletionRequest(BaseModel): - model: str = "gemma-3-4b-it" - messages: list[ChatMessage] - temperature: float | None = 0.7 - top_p: float | None = 0.9 - max_tokens: int | None = 4096 - stream: bool = False - stop: str | list[str] | None = None - tools: list[ToolDefinition] | None = None - tool_choice: str | dict | None = None - frequency_penalty: float | None = None - presence_penalty: float | None = None - n: int | None = 1 - - -# --- Response models --- - - -class UsageInfo(BaseModel): - prompt_tokens: int = 0 - completion_tokens: int = 0 - total_tokens: int = 0 - - -class ChoiceMessage(BaseModel): - role: str = "assistant" - content: str | None = None - tool_calls: list[ToolCall] | None = None - - -class Choice(BaseModel): - index: int = 0 - message: ChoiceMessage - finish_reason: str | None = "stop" - - -class ChatCompletionResponse(BaseModel): - id: str - object: str = "chat.completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[Choice] - usage: UsageInfo - - -# --- Streaming response models --- - - -class DeltaMessage(BaseModel): - role: str | None = None - content: str | None = None - tool_calls: list[ToolCall] | None = None - - -class StreamChoice(BaseModel): - index: int = 0 - delta: DeltaMessage - 
finish_reason: str | None = None - - -class ChatCompletionChunk(BaseModel): - id: str - object: str = "chat.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[StreamChoice] - usage: UsageInfo | None = None - - -# --- Model listing --- - - -class ModelInfo(BaseModel): - id: str - object: str = "model" - created: int = Field(default_factory=lambda: int(time.time())) - owned_by: str = "local" - context_window: int | None = None - - -class ModelListResponse(BaseModel): - object: str = "list" - data: list[ModelInfo] diff --git a/project.yml b/project.yml new file mode 100644 index 0000000..43b11d0 --- /dev/null +++ b/project.yml @@ -0,0 +1,45 @@ +name: MLXServer +options: + bundleIdPrefix: com.mlxserver + deploymentTarget: + macOS: "15.0" + xcodeVersion: "16.4" + minimumXcodeGenVersion: "2.40" + +packages: + mlx-swift-lm: + url: https://github.com/ml-explore/mlx-swift-lm + branch: main + MarkdownUI: + url: https://github.com/gonzalezreal/swift-markdown-ui + from: "2.4.0" + +targets: + MLXServer: + type: application + platform: macOS + sources: + - MLXServer + settings: + base: + PRODUCT_BUNDLE_IDENTIFIER: com.mlxserver.app + PRODUCT_NAME: MLX Server + MARKETING_VERSION: "1.0.0" + CURRENT_PROJECT_VERSION: "1" + SWIFT_VERSION: "6.0" + MACOSX_DEPLOYMENT_TARGET: "15.0" + GENERATE_INFOPLIST_FILE: "YES" + INFOPLIST_KEY_LSApplicationCategoryType: "public.app-category.developer-tools" + INFOPLIST_KEY_NSHumanReadableCopyright: "" + CODE_SIGN_ENTITLEMENTS: MLXServer/MLXServer.entitlements + CODE_SIGN_IDENTITY: "-" + CODE_SIGN_ALLOW_ENTITLEMENTS_MODIFICATION: "YES" + dependencies: + - package: mlx-swift-lm + product: MLXLLM + - package: mlx-swift-lm + product: MLXVLM + - package: mlx-swift-lm + product: MLXLMCommon + - package: MarkdownUI + product: MarkdownUI diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index d6b0e67..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[project] -name = 
"mlx-server" -version = "0.1.0" -description = "OpenAI-compatible API server for Gemma 3 4B via MLX" -requires-python = ">=3.11" -dependencies = [ - "fastapi>=0.115.0", - "uvicorn[standard]>=0.30.0", - "mlx>=0.22.0", - "mlx-lm>=0.22.0", - "mlx-vlm>=0.1.18", - "pydantic>=2.0.0", - "sse-starlette>=2.0.0", - "pillow>=10.0.0", - "httpx>=0.27.0", - "torchvision>=0.20.0", -] - -[project.scripts] -mlx-server = "mlx_server.main:main" diff --git a/run.sh b/run.sh deleted file mode 100755 index 9567dfe..0000000 --- a/run.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -cd "$SCRIPT_DIR" - -# Activate virtual environment -source .venv/bin/activate - -# --- Model selection --- -# Usage: ./run.sh [gemma|gemma3n|qwen] -# Or set MODEL env var directly for a custom model. - -MODEL_CHOICE="${1:-gemma}" - -if [[ -z "${MODEL:-}" ]]; then - case "$MODEL_CHOICE" in - gemma) - MODEL="mlx-community/gemma-3-4b-it-4bit" - ;; - gemma3n) - MODEL="mlx-community/gemma-3n-E4B-it-4bit" - ;; - qwen) - MODEL="mlx-community/Qwen3-VL-4B-Instruct-4bit" - ;; - *) - echo "Unknown model choice: $MODEL_CHOICE" - echo "Usage: $0 [gemma|gemma3n|qwen]" - exit 1 - ;; - esac -fi - -HOST="${HOST:-127.0.0.1}" -PORT="${PORT:-1234}" - -echo "Starting MLX Server..." 
-echo " Model: $MODEL" -echo " Endpoint: http://$HOST:$PORT" -echo "" - -exec python -m mlx_server.main \ - --model "$MODEL" \ - --host "$HOST" \ - --port "$PORT" \ - "${@:2}" diff --git a/test_server.py b/test_server.py deleted file mode 100644 index 88e1ae6..0000000 --- a/test_server.py +++ /dev/null @@ -1,296 +0,0 @@ -"""Test script for MLX Server – exercises chat, streaming, vision, and tool use.""" - -import base64 -import io -import json -import sys - -import httpx -from PIL import Image, ImageDraw - -BASE_URL = "http://127.0.0.1:1234/v1" -MODEL = "mlx-community/gemma-3-4b-it-4bit" - - -def test_models(): - """Test GET /v1/models.""" - print("=" * 60) - print("TEST: List models") - print("=" * 60) - r = httpx.get(f"{BASE_URL}/models") - r.raise_for_status() - data = r.json() - print(f"Models: {[m['id'] for m in data['data']]}") - print("PASS\n") - - -def test_chat_basic(): - """Test basic non-streaming chat.""" - print("=" * 60) - print("TEST: Basic chat (non-streaming)") - print("=" * 60) - r = httpx.post( - f"{BASE_URL}/chat/completions", - json={ - "model": MODEL, - "messages": [{"role": "user", "content": "Say exactly: 'The sky is blue.' 
Nothing else."}], - "max_tokens": 50, - "temperature": 0.1, - }, - timeout=120, - ) - r.raise_for_status() - data = r.json() - msg = data["choices"][0]["message"]["content"] - usage = data["usage"] - print(f"Response: {msg}") - print(f"Usage: {usage}") - print(f"Finish reason: {data['choices'][0]['finish_reason']}") - print("PASS\n") - - -def test_chat_streaming(): - """Test streaming chat.""" - print("=" * 60) - print("TEST: Streaming chat") - print("=" * 60) - collected = "" - with httpx.stream( - "POST", - f"{BASE_URL}/chat/completions", - json={ - "model": MODEL, - "messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}], - "max_tokens": 100, - "temperature": 0.1, - "stream": True, - }, - timeout=120, - ) as response: - response.raise_for_status() - for line in response.iter_lines(): - if not line.startswith("data: "): - continue - payload = line[len("data: "):] - if payload == "[DONE]": - break - chunk = json.loads(payload) - delta = chunk["choices"][0]["delta"] - if delta.get("content"): - collected += delta["content"] - print(delta["content"], end="", flush=True) - if chunk["choices"][0].get("finish_reason"): - print(f"\n[finish_reason: {chunk['choices'][0]['finish_reason']}]") - if chunk.get("usage") and chunk["usage"].get("total_tokens", 0) > 0: - print(f"[usage: {chunk['usage']}]") - print(f"Full collected: {collected!r}") - print("PASS\n") - - -def _make_test_image() -> str: - """Create a simple test image and return it as a base64 data URI.""" - img = Image.new("RGB", (200, 200), color=(135, 206, 235)) - draw = ImageDraw.Draw(img) - # Draw a red circle - draw.ellipse([50, 50, 150, 150], fill=(255, 0, 0), outline=(0, 0, 0), width=2) - # Draw a green triangle - draw.polygon([(100, 20), (60, 80), (140, 80)], fill=(0, 180, 0), outline=(0, 0, 0)) - # Draw yellow text area - draw.rectangle([10, 160, 190, 190], fill=(255, 255, 0)) - - buf = io.BytesIO() - img.save(buf, format="PNG") - b64 = base64.b64encode(buf.getvalue()).decode() 
- return f"data:image/png;base64,{b64}" - - -def test_vision(): - """Test vision with an image.""" - print("=" * 60) - print("TEST: Vision (image description)") - print("=" * 60) - image_uri = _make_test_image() - print(f"Image: 200x200 PNG with red circle, green triangle, yellow bar") - - r = httpx.post( - f"{BASE_URL}/chat/completions", - json={ - "model": MODEL, - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe what shapes and colors you see in this image. Be brief."}, - {"type": "image_url", "image_url": {"url": image_uri}}, - ], - } - ], - "max_tokens": 200, - "temperature": 0.1, - }, - timeout=120, - ) - r.raise_for_status() - data = r.json() - msg = data["choices"][0]["message"]["content"] - print(f"Response: {msg}") - print("PASS\n") - - -def test_tool_use(): - """Test tool calling.""" - print("=" * 60) - print("TEST: Tool use") - print("=" * 60) - - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather for a given city", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name, e.g. 'London'", - }, - "units": { - "type": "string", - "description": "Temperature units: 'celsius' or 'fahrenheit'", - }, - }, - "required": ["city"], - }, - }, - } - ] - - # Step 1: Ask the model to use the tool - print("Step 1: Asking model to get weather for Paris...") - r = httpx.post( - f"{BASE_URL}/chat/completions", - json={ - "model": MODEL, - "messages": [ - {"role": "user", "content": "What is the weather in Paris right now? 
Use the get_weather tool."}, - ], - "tools": tools, - "max_tokens": 300, - "temperature": 0.1, - }, - timeout=120, - ) - r.raise_for_status() - data = r.json() - choice = data["choices"][0] - print(f"Finish reason: {choice['finish_reason']}") - print(f"Content: {choice['message'].get('content')}") - print(f"Tool calls: {choice['message'].get('tool_calls')}") - - if choice["message"].get("tool_calls"): - tc = choice["message"]["tool_calls"][0] - print(f"\nTool call detected:") - print(f" ID: {tc['id']}") - print(f" Function: {tc['function']['name']}") - print(f" Arguments: {tc['function']['arguments']}") - - # Step 2: Send the tool result back - print("\nStep 2: Sending mock tool result back...") - r2 = httpx.post( - f"{BASE_URL}/chat/completions", - json={ - "model": MODEL, - "messages": [ - {"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."}, - { - "role": "assistant", - "content": choice["message"].get("content"), - "tool_calls": choice["message"]["tool_calls"], - }, - { - "role": "tool", - "tool_call_id": tc["id"], - "content": json.dumps({"temperature": 18, "condition": "Partly cloudy", "humidity": 65}), - }, - ], - "tools": tools, - "max_tokens": 300, - "temperature": 0.1, - }, - timeout=120, - ) - r2.raise_for_status() - data2 = r2.json() - msg2 = data2["choices"][0]["message"]["content"] - print(f"Final response: {msg2}") - else: - print("WARNING: Model did not produce a tool call. 
Raw response above.") - - print("PASS\n") - - -def test_multi_turn(): - """Test multi-turn conversation.""" - print("=" * 60) - print("TEST: Multi-turn conversation") - print("=" * 60) - messages = [ - {"role": "user", "content": "My name is Alice."}, - ] - r = httpx.post( - f"{BASE_URL}/chat/completions", - json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1}, - timeout=120, - ) - r.raise_for_status() - reply1 = r.json()["choices"][0]["message"]["content"] - print(f"Turn 1 reply: {reply1}") - - messages.append({"role": "assistant", "content": reply1}) - messages.append({"role": "user", "content": "What is my name?"}) - - r2 = httpx.post( - f"{BASE_URL}/chat/completions", - json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1}, - timeout=120, - ) - r2.raise_for_status() - reply2 = r2.json()["choices"][0]["message"]["content"] - print(f"Turn 2 reply: {reply2}") - assert "alice" in reply2.lower(), f"Expected 'Alice' in response, got: {reply2}" - print("PASS\n") - - -if __name__ == "__main__": - tests = [ - test_models, - test_chat_basic, - test_chat_streaming, - test_vision, - test_tool_use, - test_multi_turn, - ] - - # Allow running a single test by name - if len(sys.argv) > 1: - name = sys.argv[1] - tests = [t for t in tests if name in t.__name__] - if not tests: - print(f"No test matching '{name}'. Available: models, chat_basic, chat_streaming, vision, tool_use, multi_turn") - sys.exit(1) - - passed = 0 - failed = 0 - for test in tests: - try: - test() - passed += 1 - except Exception as e: - print(f"FAIL: {e}\n") - failed += 1 - - print("=" * 60) - print(f"Results: {passed} passed, {failed} failed") - print("=" * 60)