feat: complete rewrite to swift
This commit is contained in:
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,11 +1,5 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
.venv/
|
||||
.env
|
||||
*.log
|
||||
.DS_Store
|
||||
*.log
|
||||
settings.local.json
|
||||
xcuserdata/
|
||||
|
||||
55
CLAUDE.md
55
CLAUDE.md
@@ -1,50 +1,55 @@
|
||||
# MLX Server
|
||||
|
||||
OpenAI-compatible API server for local LLMs on Apple Silicon via MLX. Supports Gemma 3 4B and Qwen3 VL 4B (vision + tool use).
|
||||
Native macOS SwiftUI app for local LLMs on Apple Silicon via MLX. Provides a chat UI and an embedded OpenAI-compatible API server. Supports vision and tool use.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
# Build (requires xcodegen: brew install xcodegen)
|
||||
./build.sh
|
||||
|
||||
# Run with Gemma 3 (default)
|
||||
./run.sh
|
||||
|
||||
# Run with Qwen3
|
||||
./run.sh qwen
|
||||
|
||||
# Or directly:
|
||||
python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
|
||||
python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port 1234
|
||||
# Run
|
||||
open "build/Debug/MLX Server.app"
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
- `mlx_server/main.py` — FastAPI server, endpoints, CLI entrypoint
|
||||
- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm)
|
||||
- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types
|
||||
- `MLXServer/MLXServerApp.swift` — App entry point, GPU cache config
|
||||
- `MLXServer/ContentView.swift` — Main layout, toolbar, keyboard shortcuts
|
||||
- `MLXServer/Models/ModelConfig.swift` — Model definitions (alias, repoId, contextLength), resolution
|
||||
- `MLXServer/Models/ChatMessage.swift` — Chat message data model
|
||||
- `MLXServer/ViewModels/ModelManager.swift` — Model loading/switching via VLMModelFactory, offline-first resolution
|
||||
- `MLXServer/ViewModels/ChatViewModel.swift` — Chat state, ChatSession management, API server lifecycle
|
||||
- `MLXServer/Server/APIServer.swift` — NWListener HTTP server, SSE streaming, KV cache reuse, vision, tool call handling
|
||||
- `MLXServer/Server/APIModels.swift` — OpenAI-compatible Codable structs
|
||||
- `MLXServer/Server/ToolCallParser.swift` — Parses tool calls from model output (Gemma tool_code, Qwen XML tags)
|
||||
- `MLXServer/Server/ToolPromptBuilder.swift` — Model-specific tool prompt formatting
|
||||
- `MLXServer/Utilities/LocalModelResolver.swift` — Resolves HF repo IDs to ~/.cache/huggingface/hub/ snapshots
|
||||
- `MLXServer/Utilities/Preferences.swift` — UserDefaults wrapper
|
||||
- `project.yml` — xcodegen project spec
|
||||
- `build.sh` — Build script (xcodegen + xcodebuild)
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Alias | HuggingFace ID | Notes |
|
||||
|-------|---------------|-------|
|
||||
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks (128k context) |
|
||||
| `gemma3n` | `mlx-community/gemma-3n-E4B-it-4bit` | Vision/audio/video + tool use via `tool_code` blocks (32k context, ~1.5x faster) |
|
||||
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags (256k context) |
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load
|
||||
- Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags
|
||||
- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.)
|
||||
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
|
||||
- Context window size is read from each model's config at load time (Gemma 3 4B: 128k, Qwen3-VL 4B: 256k)
|
||||
- Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) as the inference backend — supports both text and vision in a single model load
|
||||
- Model-specific prompt formatting: Gemma uses `tool_code` blocks; Qwen uses `<tool_call>` XML tags
|
||||
- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), `LocalModelResolver` resolves the local snapshot path directly — no network requests
|
||||
- HTTP server built on `Network.framework` (`NWListener`) — no third-party server dependencies
|
||||
- KV cache reuse across API requests — reuses `ChatSession` when conversation history prefix matches
|
||||
- GPU cache limit set to 20 MB; cache cleared on model unload
|
||||
|
||||
## Dependencies
|
||||
|
||||
Managed via `uv` and `pyproject.toml`. Virtual environment in `.venv/`.
|
||||
Managed via Swift Package Manager (declared in `project.yml` for xcodegen).
|
||||
|
||||
```bash
|
||||
uv pip install -e "."
|
||||
```
|
||||
| Package | Products |
|
||||
|---------|----------|
|
||||
| `mlx-swift-lm` | `MLXLLM`, `MLXVLM`, `MLXLMCommon` |
|
||||
| `swift-markdown-ui` | `MarkdownUI` |
|
||||
|
||||
488
MLXServer.xcodeproj/project.pbxproj
Normal file
488
MLXServer.xcodeproj/project.pbxproj
Normal file
@@ -0,0 +1,488 @@
|
||||
// !$*UTF8*$!
|
||||
{
|
||||
archiveVersion = 1;
|
||||
classes = {
|
||||
};
|
||||
objectVersion = 77;
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
0168AEE16009097901363E16 /* ModelManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 922CBDC9206737BD04AF2874 /* ModelManager.swift */; };
|
||||
165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */ = {isa = PBXBuildFile; fileRef = 145B888FBDD4F931512C5473 /* Preferences.swift */; };
|
||||
189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */ = {isa = PBXBuildFile; fileRef = E73B165A1822729C907791AE /* ToolCallParser.swift */; };
|
||||
2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */; };
|
||||
4CB13DC1AC7A500DDBB443EC /* ChatInputView.swift in Sources */ = {isa = PBXBuildFile; fileRef = E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */; };
|
||||
50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C67742651DB486871CEF1612 /* MLXServerApp.swift */; };
|
||||
50DD129CCF2843482DEC3B96 /* APIServer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3D08828E16B17EF02C14243E /* APIServer.swift */; };
|
||||
5946258F1DE88CE904584E0B /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 944C699FBB76C734C9DF2F2E /* ContentView.swift */; };
|
||||
5C1E8FE1C521914CEF98D3AA /* ChatMessagesView.swift in Sources */ = {isa = PBXBuildFile; fileRef = DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */; };
|
||||
621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */; };
|
||||
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; };
|
||||
7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; };
|
||||
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; };
|
||||
84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */; };
|
||||
945474365D0B3E961811909A /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = D5E8E1C2DD8D8AABB4306193 /* MLXVLM */; };
|
||||
B5AA6E3B4BE21676226B342B /* ChatViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = B8BD93859F0291F1A3E09DA5 /* ChatViewModel.swift */; };
|
||||
B6D3662995B885C102876B4A /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = 9090667D4134056AE66DC2F1 /* MLXLMCommon */; };
|
||||
D666A311788375E8A061C832 /* SettingsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4147321383E94E9F17A0154E /* SettingsView.swift */; };
|
||||
D96DDE66F76FDDA642629E17 /* APIModels.swift in Sources */ = {isa = PBXBuildFile; fileRef = F1A52E2C9964ADA9D841A89B /* APIModels.swift */; };
|
||||
F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = A98257123539E9E738213BFA /* MarkdownUI */; };
|
||||
FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = A4B359324B5FD8D106C74338 /* ChatMessage.swift */; };
|
||||
FCD48F8C132A2B830A15EEB4 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = 3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
145B888FBDD4F931512C5473 /* Preferences.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Preferences.swift; sourceTree = "<group>"; };
|
||||
16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ToolPromptBuilder.swift; sourceTree = "<group>"; };
|
||||
38DFC212AF4359A45FBE22BA /* ModelConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelConfig.swift; sourceTree = "<group>"; };
|
||||
3AF462805202797F61422AEE /* MLXServer.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLXServer.entitlements; sourceTree = "<group>"; };
|
||||
3D08828E16B17EF02C14243E /* APIServer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServer.swift; sourceTree = "<group>"; };
|
||||
4147321383E94E9F17A0154E /* SettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsView.swift; sourceTree = "<group>"; };
|
||||
6EE59189918D06B8D2F588FC /* MLXServer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLXServer.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
922CBDC9206737BD04AF2874 /* ModelManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManager.swift; sourceTree = "<group>"; };
|
||||
944C699FBB76C734C9DF2F2E /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
|
||||
A4B359324B5FD8D106C74338 /* ChatMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessage.swift; sourceTree = "<group>"; };
|
||||
B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StatusBarView.swift; sourceTree = "<group>"; };
|
||||
B629DA084A9A40E54F8EA5FA /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
|
||||
B8BD93859F0291F1A3E09DA5 /* ChatViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatViewModel.swift; sourceTree = "<group>"; };
|
||||
C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelPickerView.swift; sourceTree = "<group>"; };
|
||||
C67742651DB486871CEF1612 /* MLXServerApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLXServerApp.swift; sourceTree = "<group>"; };
|
||||
D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolver.swift; sourceTree = "<group>"; };
|
||||
DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessagesView.swift; sourceTree = "<group>"; };
|
||||
E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatInputView.swift; sourceTree = "<group>"; };
|
||||
E73B165A1822729C907791AE /* ToolCallParser.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ToolCallParser.swift; sourceTree = "<group>"; };
|
||||
F1A52E2C9964ADA9D841A89B /* APIModels.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIModels.swift; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
A328B75C1B81B56CC7597F12 /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
FCD48F8C132A2B830A15EEB4 /* MLXLLM in Frameworks */,
|
||||
945474365D0B3E961811909A /* MLXVLM in Frameworks */,
|
||||
B6D3662995B885C102876B4A /* MLXLMCommon in Frameworks */,
|
||||
F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXFrameworksBuildPhase section */
|
||||
|
||||
/* Begin PBXGroup section */
|
||||
05B1BAE308E64D2FB2E73823 /* Utilities */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */,
|
||||
145B888FBDD4F931512C5473 /* Preferences.swift */,
|
||||
);
|
||||
path = Utilities;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
652987C2A419DBFC79E32CDE /* Products */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
6EE59189918D06B8D2F588FC /* MLXServer.app */,
|
||||
);
|
||||
name = Products;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
6816BF8EF7C92384DD7C9177 /* MLXServer */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
B629DA084A9A40E54F8EA5FA /* Assets.xcassets */,
|
||||
944C699FBB76C734C9DF2F2E /* ContentView.swift */,
|
||||
3AF462805202797F61422AEE /* MLXServer.entitlements */,
|
||||
C67742651DB486871CEF1612 /* MLXServerApp.swift */,
|
||||
BD0E350482D91238B4B59721 /* Models */,
|
||||
E13C1AAA0C49D0ED85EFD94D /* Server */,
|
||||
05B1BAE308E64D2FB2E73823 /* Utilities */,
|
||||
D7A641B0969293E838F9147A /* ViewModels */,
|
||||
7B3BAACD850CBB35C7F4FB6C /* Views */,
|
||||
);
|
||||
path = MLXServer;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
7B3BAACD850CBB35C7F4FB6C /* Views */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
E5E6AD02CDF23BDAB64700A7 /* ChatInputView.swift */,
|
||||
DB1A5E8B1C9F2BC4D262C53A /* ChatMessagesView.swift */,
|
||||
C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */,
|
||||
4147321383E94E9F17A0154E /* SettingsView.swift */,
|
||||
B0EAB35D7130D56B9E7484BA /* StatusBarView.swift */,
|
||||
);
|
||||
path = Views;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
BD0E350482D91238B4B59721 /* Models */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
A4B359324B5FD8D106C74338 /* ChatMessage.swift */,
|
||||
38DFC212AF4359A45FBE22BA /* ModelConfig.swift */,
|
||||
);
|
||||
path = Models;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
D7A641B0969293E838F9147A /* ViewModels */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
B8BD93859F0291F1A3E09DA5 /* ChatViewModel.swift */,
|
||||
922CBDC9206737BD04AF2874 /* ModelManager.swift */,
|
||||
);
|
||||
path = ViewModels;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E13C1AAA0C49D0ED85EFD94D /* Server */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
F1A52E2C9964ADA9D841A89B /* APIModels.swift */,
|
||||
3D08828E16B17EF02C14243E /* APIServer.swift */,
|
||||
E73B165A1822729C907791AE /* ToolCallParser.swift */,
|
||||
16AE82A64D1D07AE3CD8D33A /* ToolPromptBuilder.swift */,
|
||||
);
|
||||
path = Server;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E2540E47403820BAAFEF0560 = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
6816BF8EF7C92384DD7C9177 /* MLXServer */,
|
||||
652987C2A419DBFC79E32CDE /* Products */,
|
||||
);
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXGroup section */
|
||||
|
||||
/* Begin PBXNativeTarget section */
|
||||
BCD7107EE884C9B2F4C2C40E /* MLXServer */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = 732FBF6595F174F37E5F2835 /* Build configuration list for PBXNativeTarget "MLXServer" */;
|
||||
buildPhases = (
|
||||
BC03844286F51DFAEF96B823 /* Sources */,
|
||||
4668CBE7984322217309B525 /* Resources */,
|
||||
A328B75C1B81B56CC7597F12 /* Frameworks */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
);
|
||||
name = MLXServer;
|
||||
packageProductDependencies = (
|
||||
3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */,
|
||||
D5E8E1C2DD8D8AABB4306193 /* MLXVLM */,
|
||||
9090667D4134056AE66DC2F1 /* MLXLMCommon */,
|
||||
A98257123539E9E738213BFA /* MarkdownUI */,
|
||||
);
|
||||
productName = MLXServer;
|
||||
productReference = 6EE59189918D06B8D2F588FC /* MLXServer.app */;
|
||||
productType = "com.apple.product-type.application";
|
||||
};
|
||||
/* End PBXNativeTarget section */
|
||||
|
||||
/* Begin PBXProject section */
|
||||
938BC479816FCA8527B731F9 /* Project object */ = {
|
||||
isa = PBXProject;
|
||||
attributes = {
|
||||
BuildIndependentTargetsInParallel = YES;
|
||||
LastUpgradeCheck = 1640;
|
||||
TargetAttributes = {
|
||||
};
|
||||
};
|
||||
buildConfigurationList = 9281433C2FA3F9393F1D48E5 /* Build configuration list for PBXProject "MLXServer" */;
|
||||
developmentRegion = en;
|
||||
hasScannedForEncodings = 0;
|
||||
knownRegions = (
|
||||
Base,
|
||||
en,
|
||||
);
|
||||
mainGroup = E2540E47403820BAAFEF0560;
|
||||
minimizedProjectReferenceProxies = 1;
|
||||
packageReferences = (
|
||||
D402301668D113A49B6DD32D /* XCRemoteSwiftPackageReference "swift-markdown-ui" */,
|
||||
1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */,
|
||||
);
|
||||
preferredProjectObjectVersion = 77;
|
||||
productRefGroup = 652987C2A419DBFC79E32CDE /* Products */;
|
||||
projectDirPath = "";
|
||||
projectRoot = "";
|
||||
targets = (
|
||||
BCD7107EE884C9B2F4C2C40E /* MLXServer */,
|
||||
);
|
||||
};
|
||||
/* End PBXProject section */
|
||||
|
||||
/* Begin PBXResourcesBuildPhase section */
|
||||
4668CBE7984322217309B525 /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXResourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXSourcesBuildPhase section */
|
||||
BC03844286F51DFAEF96B823 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
D96DDE66F76FDDA642629E17 /* APIModels.swift in Sources */,
|
||||
50DD129CCF2843482DEC3B96 /* APIServer.swift in Sources */,
|
||||
4CB13DC1AC7A500DDBB443EC /* ChatInputView.swift in Sources */,
|
||||
FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */,
|
||||
5C1E8FE1C521914CEF98D3AA /* ChatMessagesView.swift in Sources */,
|
||||
B5AA6E3B4BE21676226B342B /* ChatViewModel.swift in Sources */,
|
||||
5946258F1DE88CE904584E0B /* ContentView.swift in Sources */,
|
||||
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */,
|
||||
50B6861FF8610B3ED4FFAD9D /* MLXServerApp.swift in Sources */,
|
||||
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */,
|
||||
0168AEE16009097901363E16 /* ModelManager.swift in Sources */,
|
||||
2CAAF7129F7CC45200FA9F6B /* ModelPickerView.swift in Sources */,
|
||||
165E8AB6ADAE1D59B1A86420 /* Preferences.swift in Sources */,
|
||||
D666A311788375E8A061C832 /* SettingsView.swift in Sources */,
|
||||
621B7E4382199AC1378F5F9C /* StatusBarView.swift in Sources */,
|
||||
189362AAE2CDE5D4B3428334 /* ToolCallParser.swift in Sources */,
|
||||
84D32315B418B5243E017350 /* ToolPromptBuilder.swift in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXSourcesBuildPhase section */
|
||||
|
||||
/* Begin XCBuildConfiguration section */
|
||||
6C0C08FC4653A138A768ECF0 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
|
||||
CLANG_CXX_LIBRARY = "libc++";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
|
||||
ENABLE_NS_ASSERTIONS = NO;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu11;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
MACOSX_DEPLOYMENT_TARGET = 15.0;
|
||||
MTL_ENABLE_DEBUG_INFO = NO;
|
||||
MTL_FAST_MATH = YES;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = macosx;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-O";
|
||||
SWIFT_VERSION = 5.0;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
B906BFE6DBE690935A1B8B6E /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
|
||||
CLANG_CXX_LIBRARY = "libc++";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = dwarf;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_TESTABILITY = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu11;
|
||||
GCC_DYNAMIC_NO_PIC = NO;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_OPTIMIZATION_LEVEL = 0;
|
||||
GCC_PREPROCESSOR_DEFINITIONS = (
|
||||
"$(inherited)",
|
||||
"DEBUG=1",
|
||||
);
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
MACOSX_DEPLOYMENT_TARGET = 15.0;
|
||||
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
|
||||
MTL_FAST_MATH = YES;
|
||||
ONLY_ACTIVE_ARCH = YES;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = macosx;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
SWIFT_VERSION = 5.0;
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
BE93D60EEB354E6E242C3EDB /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
CODE_SIGN_ALLOW_ENTITLEMENTS_MODIFICATION = YES;
|
||||
CODE_SIGN_ENTITLEMENTS = MLXServer/MLXServer.entitlements;
|
||||
CODE_SIGN_IDENTITY = "-";
|
||||
COMBINE_HIDPI_IMAGES = YES;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.developer-tools";
|
||||
INFOPLIST_KEY_NSHumanReadableCopyright = "";
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/../Frameworks",
|
||||
);
|
||||
MACOSX_DEPLOYMENT_TARGET = 15.0;
|
||||
MARKETING_VERSION = 1.0.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.mlxserver.app;
|
||||
PRODUCT_NAME = "MLX Server";
|
||||
SDKROOT = macosx;
|
||||
SWIFT_VERSION = 6.0;
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
D7A086319EEBE6664326B437 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
CODE_SIGN_ALLOW_ENTITLEMENTS_MODIFICATION = YES;
|
||||
CODE_SIGN_ENTITLEMENTS = MLXServer/MLXServer.entitlements;
|
||||
CODE_SIGN_IDENTITY = "-";
|
||||
COMBINE_HIDPI_IMAGES = YES;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.developer-tools";
|
||||
INFOPLIST_KEY_NSHumanReadableCopyright = "";
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/../Frameworks",
|
||||
);
|
||||
MACOSX_DEPLOYMENT_TARGET = 15.0;
|
||||
MARKETING_VERSION = 1.0.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.mlxserver.app;
|
||||
PRODUCT_NAME = "MLX Server";
|
||||
SDKROOT = macosx;
|
||||
SWIFT_VERSION = 6.0;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
/* End XCBuildConfiguration section */
|
||||
|
||||
/* Begin XCConfigurationList section */
|
||||
732FBF6595F174F37E5F2835 /* Build configuration list for PBXNativeTarget "MLXServer" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
BE93D60EEB354E6E242C3EDB /* Debug */,
|
||||
D7A086319EEBE6664326B437 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Debug;
|
||||
};
|
||||
9281433C2FA3F9393F1D48E5 /* Build configuration list for PBXProject "MLXServer" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
B906BFE6DBE690935A1B8B6E /* Debug */,
|
||||
6C0C08FC4653A138A768ECF0 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Debug;
|
||||
};
|
||||
/* End XCConfigurationList section */
|
||||
|
||||
/* Begin XCRemoteSwiftPackageReference section */
|
||||
1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */ = {
|
||||
isa = XCRemoteSwiftPackageReference;
|
||||
repositoryURL = "https://github.com/ml-explore/mlx-swift-lm";
|
||||
requirement = {
|
||||
branch = main;
|
||||
kind = branch;
|
||||
};
|
||||
};
|
||||
D402301668D113A49B6DD32D /* XCRemoteSwiftPackageReference "swift-markdown-ui" */ = {
|
||||
isa = XCRemoteSwiftPackageReference;
|
||||
repositoryURL = "https://github.com/gonzalezreal/swift-markdown-ui";
|
||||
requirement = {
|
||||
kind = upToNextMajorVersion;
|
||||
minimumVersion = 2.4.0;
|
||||
};
|
||||
};
|
||||
/* End XCRemoteSwiftPackageReference section */
|
||||
|
||||
/* Begin XCSwiftPackageProductDependency section */
|
||||
3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = 1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
|
||||
productName = MLXLLM;
|
||||
};
|
||||
9090667D4134056AE66DC2F1 /* MLXLMCommon */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = 1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
|
||||
productName = MLXLMCommon;
|
||||
};
|
||||
A98257123539E9E738213BFA /* MarkdownUI */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = D402301668D113A49B6DD32D /* XCRemoteSwiftPackageReference "swift-markdown-ui" */;
|
||||
productName = MarkdownUI;
|
||||
};
|
||||
D5E8E1C2DD8D8AABB4306193 /* MLXVLM */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = 1AA4C71F15847A241E418C0C /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
|
||||
productName = MLXVLM;
|
||||
};
|
||||
/* End XCSwiftPackageProductDependency section */
|
||||
};
|
||||
rootObject = 938BC479816FCA8527B731F9 /* Project object */;
|
||||
}
|
||||
7
MLXServer.xcodeproj/project.xcworkspace/contents.xcworkspacedata
generated
Normal file
7
MLXServer.xcodeproj/project.xcworkspace/contents.xcworkspacedata
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Workspace
|
||||
version = "1.0">
|
||||
<FileRef
|
||||
location = "self:">
|
||||
</FileRef>
|
||||
</Workspace>
|
||||
@@ -0,0 +1,159 @@
|
||||
{
|
||||
"originHash" : "418f7299ccb303e0e8992dfc960a3df5df98d527f18667aa162699027b29b6cd",
|
||||
"pins" : [
|
||||
{
|
||||
"identity" : "eventsource",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/mattt/EventSource.git",
|
||||
"state" : {
|
||||
"revision" : "a3a85a85214caf642abaa96ae664e4c772a59f6e",
|
||||
"version" : "1.4.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "mlx-swift",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/ml-explore/mlx-swift",
|
||||
"state" : {
|
||||
"revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d",
|
||||
"version" : "0.30.6"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "mlx-swift-lm",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/ml-explore/mlx-swift-lm",
|
||||
"state" : {
|
||||
"branch" : "main",
|
||||
"revision" : "bc3c20ef4644c86f2b347debcfe1efe4308712a6"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "networkimage",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/gonzalezreal/NetworkImage",
|
||||
"state" : {
|
||||
"revision" : "2849f5323265386e200484b0d0f896e73c3411b9",
|
||||
"version" : "6.0.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-asn1",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-asn1.git",
|
||||
"state" : {
|
||||
"revision" : "9f542610331815e29cc3821d3b6f488db8715517",
|
||||
"version" : "1.6.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-atomics",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-atomics.git",
|
||||
"state" : {
|
||||
"revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7",
|
||||
"version" : "1.3.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-cmark",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/swiftlang/swift-cmark",
|
||||
"state" : {
|
||||
"revision" : "5d9bdaa4228b381639fff09403e39a04926e2dbe",
|
||||
"version" : "0.7.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-collections",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-collections.git",
|
||||
"state" : {
|
||||
"revision" : "8d9834a6189db730f6264db7556a7ffb751e99ee",
|
||||
"version" : "1.4.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-crypto",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-crypto.git",
|
||||
"state" : {
|
||||
"revision" : "fa308c07a6fa04a727212d793e761460e41049c3",
|
||||
"version" : "4.3.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-huggingface",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/huggingface/swift-huggingface.git",
|
||||
"state" : {
|
||||
"revision" : "b721959445b617d0bf03910b2b4aced345fd93bf",
|
||||
"version" : "0.9.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-jinja",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/huggingface/swift-jinja.git",
|
||||
"state" : {
|
||||
"revision" : "f731f03bf746481d4fda07f817c3774390c4d5b9",
|
||||
"version" : "2.3.2"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-markdown-ui",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/gonzalezreal/swift-markdown-ui",
|
||||
"state" : {
|
||||
"revision" : "5f613358148239d0292c0cef674a3c2314737f9e",
|
||||
"version" : "2.4.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-nio",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-nio.git",
|
||||
"state" : {
|
||||
"revision" : "b31565862a8f39866af50bc6676160d8dda7de35",
|
||||
"version" : "2.96.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-numerics",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-numerics",
|
||||
"state" : {
|
||||
"revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
|
||||
"version" : "1.1.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-system",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-system.git",
|
||||
"state" : {
|
||||
"revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df",
|
||||
"version" : "1.6.4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-transformers",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/huggingface/swift-transformers",
|
||||
"state" : {
|
||||
"revision" : "eed7264ac5e4ec5dfa6165c6e5c5577364344fe4",
|
||||
"version" : "1.2.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "yyjson",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/ibireme/yyjson.git",
|
||||
"state" : {
|
||||
"revision" : "8b4a38dc994a110abaec8a400615567bd996105f",
|
||||
"version" : "0.12.0"
|
||||
}
|
||||
}
|
||||
],
|
||||
"version" : 3
|
||||
}
|
||||
11
MLXServer/Assets.xcassets/AccentColor.colorset/Contents.json
Normal file
11
MLXServer/Assets.xcassets/AccentColor.colorset/Contents.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"colors" : [
|
||||
{
|
||||
"idiom" : "universal"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
58
MLXServer/Assets.xcassets/AppIcon.appiconset/Contents.json
Normal file
58
MLXServer/Assets.xcassets/AppIcon.appiconset/Contents.json
Normal file
@@ -0,0 +1,58 @@
|
||||
{
|
||||
"images" : [
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "16x16"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "16x16"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "32x32"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "32x32"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "128x128"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "128x128"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "256x256"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "256x256"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "512x512"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "512x512"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
6
MLXServer/Assets.xcassets/Contents.json
Normal file
6
MLXServer/Assets.xcassets/Contents.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
115
MLXServer/ContentView.swift
Normal file
115
MLXServer/ContentView.swift
Normal file
@@ -0,0 +1,115 @@
|
||||
import SwiftUI
|
||||
|
||||
struct ContentView: View {
|
||||
@Environment(ModelManager.self) private var modelManager
|
||||
@State private var chatVM: ChatViewModel?
|
||||
@State private var showLoadError = false
|
||||
|
||||
var body: some View {
|
||||
Group {
|
||||
if let chatVM {
|
||||
ChatView(viewModel: chatVM)
|
||||
} else {
|
||||
ProgressView("Initializing…")
|
||||
}
|
||||
}
|
||||
.navigationTitle(modelManager.currentModel?.displayName ?? "MLX Server")
|
||||
.onAppear {
|
||||
if chatVM == nil {
|
||||
chatVM = ChatViewModel(modelManager: modelManager)
|
||||
// Auto-start API server if configured
|
||||
if Preferences.apiAutoStart {
|
||||
chatVM?.startAPIServer()
|
||||
}
|
||||
}
|
||||
}
|
||||
.onChange(of: modelManager.currentModel) {
|
||||
chatVM?.resetSession()
|
||||
// Persist last used model
|
||||
if let id = modelManager.currentModel?.id {
|
||||
Preferences.lastModelId = id
|
||||
}
|
||||
}
|
||||
.onChange(of: modelManager.errorMessage) {
|
||||
showLoadError = modelManager.errorMessage != nil
|
||||
}
|
||||
.alert("Model Error", isPresented: $showLoadError) {
|
||||
Button("Retry") {
|
||||
if let config = modelManager.currentModel ?? ModelConfig.availableModels.first {
|
||||
Task { await modelManager.loadModel(config) }
|
||||
}
|
||||
}
|
||||
Button("Cancel", role: .cancel) {
|
||||
modelManager.errorMessage = nil
|
||||
}
|
||||
} message: {
|
||||
Text(modelManager.errorMessage ?? "Unknown error loading model.")
|
||||
}
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .principal) {
|
||||
ModelPickerView()
|
||||
}
|
||||
|
||||
ToolbarItemGroup(placement: .primaryAction) {
|
||||
// API server toggle
|
||||
Button {
|
||||
if let chatVM {
|
||||
if chatVM.apiServer.isRunning {
|
||||
chatVM.stopAPIServer()
|
||||
} else {
|
||||
chatVM.startAPIServer()
|
||||
}
|
||||
}
|
||||
} label: {
|
||||
// Running → solid globe (green tint), click to stop
|
||||
// Stopped → slashed globe, click to start
|
||||
Label(
|
||||
chatVM?.apiServer.isRunning == true ? "Stop API" : "Start API",
|
||||
systemImage: chatVM?.apiServer.isRunning == true ? "network" : "network.slash"
|
||||
)
|
||||
.foregroundStyle(chatVM?.apiServer.isRunning == true ? .green : .secondary)
|
||||
}
|
||||
.help(chatVM?.apiServer.isRunning == true ? "API server running on port \(Preferences.apiPort) — click to stop" : "Click to start API server")
|
||||
|
||||
// New conversation
|
||||
Button {
|
||||
chatVM?.newConversation()
|
||||
} label: {
|
||||
Label("New Chat", systemImage: "plus.message")
|
||||
}
|
||||
.keyboardShortcut("n", modifiers: .command)
|
||||
}
|
||||
}
|
||||
// Cmd+1/2/3 model switching
|
||||
.background {
|
||||
modelSwitchShortcuts
|
||||
}
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private var modelSwitchShortcuts: some View {
|
||||
ForEach(Array(ModelConfig.availableModels.enumerated()), id: \.element.id) { index, config in
|
||||
if index < 9 {
|
||||
Button("") {
|
||||
Task { await modelManager.loadModel(config) }
|
||||
}
|
||||
.keyboardShortcut(KeyEquivalent(Character(String(index + 1))), modifiers: .command)
|
||||
.hidden()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The main chat layout: messages + input area + status bar.
|
||||
struct ChatView: View {
|
||||
@Bindable var viewModel: ChatViewModel
|
||||
|
||||
var body: some View {
|
||||
VStack(spacing: 0) {
|
||||
ChatMessagesView(viewModel: viewModel)
|
||||
Divider()
|
||||
ChatInputView(viewModel: viewModel)
|
||||
StatusBarView(viewModel: viewModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
14
MLXServer/MLXServer.entitlements
Normal file
14
MLXServer/MLXServer.entitlements
Normal file
@@ -0,0 +1,14 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>com.apple.security.app-sandbox</key>
|
||||
<true/>
|
||||
<key>com.apple.security.network.client</key>
|
||||
<true/>
|
||||
<key>com.apple.security.network.server</key>
|
||||
<true/>
|
||||
<key>com.apple.security.files.user-selected.read-only</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
33
MLXServer/MLXServerApp.swift
Normal file
33
MLXServer/MLXServerApp.swift
Normal file
@@ -0,0 +1,33 @@
|
||||
import SwiftUI
|
||||
import MLX
|
||||
|
||||
@main
|
||||
struct MLXServerApp: App {
|
||||
@State private var modelManager = ModelManager()
|
||||
|
||||
init() {
|
||||
MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
|
||||
}
|
||||
|
||||
var body: some Scene {
|
||||
WindowGroup {
|
||||
ContentView()
|
||||
.environment(modelManager)
|
||||
.task {
|
||||
// Auto-load last used model (or default)
|
||||
let modelId = Preferences.lastModelId ?? ModelConfig.default.id
|
||||
if let config = ModelConfig.availableModels.first(where: { $0.id == modelId }) {
|
||||
await modelManager.loadModel(config)
|
||||
}
|
||||
}
|
||||
}
|
||||
.windowStyle(.titleBar)
|
||||
.defaultSize(width: 800, height: 700)
|
||||
|
||||
#if os(macOS)
|
||||
Settings {
|
||||
SettingsView()
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
60
MLXServer/Models/ChatMessage.swift
Normal file
60
MLXServer/Models/ChatMessage.swift
Normal file
@@ -0,0 +1,60 @@
|
||||
import AppKit
|
||||
import Foundation
|
||||
|
||||
/// A single message in the chat conversation.
|
||||
struct ChatMessage: Identifiable {
|
||||
let id = UUID()
|
||||
let role: Role
|
||||
var content: String
|
||||
var images: [NSImage]
|
||||
var isStreaming: Bool
|
||||
let timestamp: Date
|
||||
|
||||
enum Role: String {
|
||||
case system
|
||||
case user
|
||||
case assistant
|
||||
}
|
||||
|
||||
init(role: Role, content: String, images: [NSImage] = [], isStreaming: Bool = false) {
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.images = images
|
||||
self.isStreaming = isStreaming
|
||||
self.timestamp = Date()
|
||||
}
|
||||
}
|
||||
|
||||
/// Observable conversation state holding all messages.
|
||||
@Observable
|
||||
@MainActor
|
||||
final class Conversation {
|
||||
var messages: [ChatMessage] = []
|
||||
|
||||
func addUserMessage(_ text: String, images: [NSImage] = []) {
|
||||
messages.append(ChatMessage(role: .user, content: text, images: images))
|
||||
}
|
||||
|
||||
/// Adds an empty assistant message (to be filled via streaming) and returns its index.
|
||||
func addAssistantMessage() -> Int {
|
||||
let msg = ChatMessage(role: .assistant, content: "", isStreaming: true)
|
||||
messages.append(msg)
|
||||
return messages.count - 1
|
||||
}
|
||||
|
||||
/// Appends a text chunk to the assistant message at the given index.
|
||||
func appendToMessage(at index: Int, chunk: String) {
|
||||
guard index < messages.count else { return }
|
||||
messages[index].content += chunk
|
||||
}
|
||||
|
||||
/// Marks the assistant message at the given index as done streaming.
|
||||
func finalizeMessage(at index: Int) {
|
||||
guard index < messages.count else { return }
|
||||
messages[index].isStreaming = false
|
||||
}
|
||||
|
||||
func clear() {
|
||||
messages.removeAll()
|
||||
}
|
||||
}
|
||||
56
MLXServer/Models/ModelConfig.swift
Normal file
56
MLXServer/Models/ModelConfig.swift
Normal file
@@ -0,0 +1,56 @@
|
||||
import Foundation
|
||||
import MLXLMCommon
|
||||
|
||||
/// Defines a supported model with its metadata.
|
||||
struct ModelConfig: Identifiable, Hashable {
|
||||
let id: String // alias: "gemma", "gemma3n", "qwen"
|
||||
let repoId: String // HuggingFace ID
|
||||
let displayName: String
|
||||
let contextLength: Int
|
||||
|
||||
/// All models supported by the app.
|
||||
static let availableModels: [ModelConfig] = [
|
||||
ModelConfig(
|
||||
id: "gemma",
|
||||
repoId: "mlx-community/gemma-3-4b-it-4bit",
|
||||
displayName: "Gemma 3 4B",
|
||||
contextLength: 128_000
|
||||
),
|
||||
ModelConfig(
|
||||
id: "qwen",
|
||||
repoId: "mlx-community/Qwen3-VL-4B-Instruct-4bit",
|
||||
displayName: "Qwen3 VL 4B",
|
||||
contextLength: 256_000
|
||||
),
|
||||
]
|
||||
|
||||
static let `default` = availableModels[0]
|
||||
|
||||
/// Whether this model is cached locally (no download needed).
|
||||
var isLocal: Bool {
|
||||
LocalModelResolver.isAvailable(repoId: repoId)
|
||||
}
|
||||
|
||||
/// Build a ModelConfiguration for mlx-swift-lm from this config.
|
||||
var modelConfiguration: ModelConfiguration {
|
||||
ModelConfiguration(id: repoId)
|
||||
}
|
||||
|
||||
/// Resolve a model string (alias, full repo ID, or partial match) to a ModelConfig.
|
||||
/// Mirrors the Python server's `ModelManager.resolve_model()`.
|
||||
static func resolve(_ requested: String) -> ModelConfig? {
|
||||
// Exact alias match
|
||||
if let config = availableModels.first(where: { $0.id == requested }) {
|
||||
return config
|
||||
}
|
||||
// Exact repo ID match
|
||||
if let config = availableModels.first(where: { $0.repoId == requested }) {
|
||||
return config
|
||||
}
|
||||
// Partial match (e.g. "gemma-3-4b-it" matches the gemma entry)
|
||||
if let config = availableModels.first(where: { requested.contains($0.id) || $0.repoId.contains(requested) || requested.contains($0.repoId) }) {
|
||||
return config
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
237
MLXServer/Server/APIModels.swift
Normal file
237
MLXServer/Server/APIModels.swift
Normal file
@@ -0,0 +1,237 @@
|
||||
import Foundation
|
||||
|
||||
// MARK: - Request models
|
||||
|
||||
struct APIFunctionDefinition: Codable {
|
||||
let name: String
|
||||
let description: String?
|
||||
let parameters: [String: AnyCodable]?
|
||||
}
|
||||
|
||||
struct APIToolDefinition: Codable {
|
||||
let type: String // "function"
|
||||
let function: APIFunctionDefinition
|
||||
}
|
||||
|
||||
struct APIFunctionCall: Codable {
|
||||
let name: String
|
||||
let arguments: String // JSON string
|
||||
}
|
||||
|
||||
struct APIToolCall: Codable {
|
||||
let index: Int
|
||||
let id: String
|
||||
let type: String // "function"
|
||||
let function: APIFunctionCall
|
||||
|
||||
init(index: Int = 0, id: String, type: String = "function", function: APIFunctionCall) {
|
||||
self.index = index
|
||||
self.id = id
|
||||
self.type = type
|
||||
self.function = function
|
||||
}
|
||||
}
|
||||
|
||||
struct APIImageURL: Codable {
|
||||
let url: String
|
||||
let detail: String?
|
||||
}
|
||||
|
||||
struct APIContentPart: Codable {
|
||||
let type: String // "text" or "image_url"
|
||||
let text: String?
|
||||
let image_url: APIImageURL?
|
||||
}
|
||||
|
||||
struct APIChatMessage: Codable {
|
||||
let role: String
|
||||
let content: MessageContent?
|
||||
let name: String?
|
||||
let tool_calls: [APIToolCall]?
|
||||
let tool_call_id: String?
|
||||
|
||||
enum MessageContent: Codable {
|
||||
case text(String)
|
||||
case parts([APIContentPart])
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.singleValueContainer()
|
||||
if let text = try? container.decode(String.self) {
|
||||
self = .text(text)
|
||||
} else if let parts = try? container.decode([APIContentPart].self) {
|
||||
self = .parts(parts)
|
||||
} else {
|
||||
self = .text("")
|
||||
}
|
||||
}
|
||||
|
||||
func encode(to encoder: Encoder) throws {
|
||||
var container = encoder.singleValueContainer()
|
||||
switch self {
|
||||
case .text(let text):
|
||||
try container.encode(text)
|
||||
case .parts(let parts):
|
||||
try container.encode(parts)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract plain text content.
|
||||
var textContent: String {
|
||||
switch self {
|
||||
case .text(let t): return t
|
||||
case .parts(let parts):
|
||||
return parts.compactMap { $0.text }.joined()
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract image URLs/base64 data URIs.
|
||||
var imageURLs: [String] {
|
||||
switch self {
|
||||
case .text: return []
|
||||
case .parts(let parts):
|
||||
return parts.compactMap { $0.image_url?.url }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct APIChatCompletionRequest: Codable {
|
||||
let model: String?
|
||||
let messages: [APIChatMessage]
|
||||
let temperature: Double?
|
||||
let top_p: Double?
|
||||
let max_tokens: Int?
|
||||
let stream: Bool?
|
||||
let stop: StopSequence?
|
||||
let tools: [APIToolDefinition]?
|
||||
let tool_choice: AnyCodable?
|
||||
let frequency_penalty: Double?
|
||||
let presence_penalty: Double?
|
||||
let n: Int?
|
||||
|
||||
enum StopSequence: Codable {
|
||||
case single(String)
|
||||
case multiple([String])
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.singleValueContainer()
|
||||
if let s = try? container.decode(String.self) {
|
||||
self = .single(s)
|
||||
} else if let arr = try? container.decode([String].self) {
|
||||
self = .multiple(arr)
|
||||
} else {
|
||||
self = .multiple([])
|
||||
}
|
||||
}
|
||||
|
||||
func encode(to encoder: Encoder) throws {
|
||||
var container = encoder.singleValueContainer()
|
||||
switch self {
|
||||
case .single(let s): try container.encode(s)
|
||||
case .multiple(let arr): try container.encode(arr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Response models
|
||||
|
||||
struct APIUsageInfo: Codable {
|
||||
let prompt_tokens: Int
|
||||
let completion_tokens: Int
|
||||
let total_tokens: Int
|
||||
}
|
||||
|
||||
struct APIChoiceMessage: Codable {
|
||||
let role: String
|
||||
let content: String?
|
||||
let tool_calls: [APIToolCall]?
|
||||
}
|
||||
|
||||
struct APIChoice: Codable {
|
||||
let index: Int
|
||||
let message: APIChoiceMessage
|
||||
let finish_reason: String?
|
||||
}
|
||||
|
||||
struct APIChatCompletionResponse: Codable {
|
||||
let id: String
|
||||
let object: String
|
||||
let created: Int
|
||||
let model: String
|
||||
let choices: [APIChoice]
|
||||
let usage: APIUsageInfo
|
||||
}
|
||||
|
||||
// MARK: - Streaming response models
|
||||
|
||||
struct APIDeltaMessage: Codable {
|
||||
let role: String?
|
||||
let content: String?
|
||||
let tool_calls: [APIToolCall]?
|
||||
}
|
||||
|
||||
struct APIStreamChoice: Codable {
|
||||
let index: Int
|
||||
let delta: APIDeltaMessage
|
||||
let finish_reason: String?
|
||||
}
|
||||
|
||||
struct APIChatCompletionChunk: Codable {
|
||||
let id: String
|
||||
let object: String
|
||||
let created: Int
|
||||
let model: String
|
||||
let choices: [APIStreamChoice]
|
||||
let usage: APIUsageInfo?
|
||||
}
|
||||
|
||||
// MARK: - Model listing
|
||||
|
||||
struct APIModelInfo: Codable {
|
||||
let id: String
|
||||
let object: String
|
||||
let created: Int
|
||||
let owned_by: String
|
||||
let context_window: Int?
|
||||
}
|
||||
|
||||
struct APIModelListResponse: Codable {
|
||||
let object: String
|
||||
let data: [APIModelInfo]
|
||||
}
|
||||
|
||||
// MARK: - Utility: type-erased Codable
|
||||
|
||||
struct AnyCodable: Codable {
|
||||
let value: Any
|
||||
|
||||
init(_ value: Any) {
|
||||
self.value = value
|
||||
}
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.singleValueContainer()
|
||||
if let intVal = try? container.decode(Int.self) { value = intVal }
|
||||
else if let doubleVal = try? container.decode(Double.self) { value = doubleVal }
|
||||
else if let boolVal = try? container.decode(Bool.self) { value = boolVal }
|
||||
else if let stringVal = try? container.decode(String.self) { value = stringVal }
|
||||
else if let arrayVal = try? container.decode([AnyCodable].self) { value = arrayVal.map(\.value) }
|
||||
else if let dictVal = try? container.decode([String: AnyCodable].self) {
|
||||
value = dictVal.mapValues(\.value)
|
||||
} else { value = NSNull() }
|
||||
}
|
||||
|
||||
func encode(to encoder: Encoder) throws {
|
||||
var container = encoder.singleValueContainer()
|
||||
switch value {
|
||||
case let v as Int: try container.encode(v)
|
||||
case let v as Double: try container.encode(v)
|
||||
case let v as Bool: try container.encode(v)
|
||||
case let v as String: try container.encode(v)
|
||||
case let v as [Any]: try container.encode(v.map { AnyCodable($0) })
|
||||
case let v as [String: Any]: try container.encode(v.mapValues { AnyCodable($0) })
|
||||
default: try container.encodeNil()
|
||||
}
|
||||
}
|
||||
}
|
||||
814
MLXServer/Server/APIServer.swift
Normal file
814
MLXServer/Server/APIServer.swift
Normal file
@@ -0,0 +1,814 @@
|
||||
import AppKit
|
||||
import Foundation
|
||||
import MLXLMCommon
|
||||
import Network
|
||||
|
||||
/// Lightweight HTTP server that exposes OpenAI-compatible endpoints.
|
||||
/// Runs entirely in-process using NWListener (Network.framework, no third-party deps).
|
||||
@Observable
|
||||
@MainActor
|
||||
final class APIServer {
|
||||
var isRunning = false
|
||||
var port: Int = 1234
|
||||
var requestCount: Int = 0
|
||||
|
||||
private var listener: NWListener?
|
||||
private var modelManager: ModelManager?
|
||||
|
||||
// Persistent ChatSession for KV cache reuse across requests
|
||||
private var cachedSession: ChatSession?
|
||||
private var cachedMessages: [Chat.Message]?
|
||||
private var cachedModelId: String?
|
||||
|
||||
func start(modelManager: ModelManager, port: Int = 1234) {
|
||||
guard !isRunning else { return }
|
||||
self.modelManager = modelManager
|
||||
self.port = port
|
||||
|
||||
do {
|
||||
let params = NWParameters.tcp
|
||||
params.allowLocalEndpointReuse = true
|
||||
listener = try NWListener(using: params, on: NWEndpoint.Port(integerLiteral: UInt16(port)))
|
||||
|
||||
listener?.stateUpdateHandler = { [weak self] state in
|
||||
Task { @MainActor in
|
||||
switch state {
|
||||
case .ready:
|
||||
self?.isRunning = true
|
||||
print("[APIServer] Listening on port \(port)")
|
||||
case .failed(let error):
|
||||
self?.isRunning = false
|
||||
print("[APIServer] Failed: \(error)")
|
||||
case .cancelled:
|
||||
self?.isRunning = false
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
listener?.newConnectionHandler = { [weak self] connection in
|
||||
Task { @MainActor in
|
||||
self?.handleConnection(connection)
|
||||
}
|
||||
}
|
||||
|
||||
listener?.start(queue: .global(qos: .userInitiated))
|
||||
} catch {
|
||||
print("[APIServer] Failed to start: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
func stop() {
|
||||
listener?.cancel()
|
||||
listener = nil
|
||||
isRunning = false
|
||||
cachedSession = nil
|
||||
cachedMessages = nil
|
||||
cachedModelId = nil
|
||||
}
|
||||
|
||||
// MARK: - Connection handling
|
||||
|
||||
private func handleConnection(_ connection: NWConnection) {
|
||||
connection.start(queue: .global(qos: .userInitiated))
|
||||
receiveFullHTTPRequest(connection: connection, accumulated: Data())
|
||||
}
|
||||
|
||||
/// Receive the full HTTP request, accumulating data until we have the complete body.
|
||||
/// This handles large POST bodies (e.g. base64 images) that arrive in multiple chunks.
|
||||
private func receiveFullHTTPRequest(connection: NWConnection, accumulated: Data) {
|
||||
connection.receive(minimumIncompleteLength: 1, maximumLength: 4_194_304) {
|
||||
[weak self] data, _, isComplete, error in
|
||||
guard let self else { connection.cancel(); return }
|
||||
|
||||
var buffer = accumulated
|
||||
if let data { buffer.append(data) }
|
||||
|
||||
// Try to determine if we have the full request
|
||||
if let request = HTTPRequest.parse(buffer) {
|
||||
// Check if we have enough body data based on Content-Length
|
||||
if let clHeader = request.headers["content-length"],
|
||||
let contentLength = Int(clHeader),
|
||||
(request.body?.count ?? 0) < contentLength {
|
||||
// Need more data
|
||||
if isComplete {
|
||||
// Connection closed before we got all data
|
||||
Task { @MainActor in
|
||||
self.sendResponse(connection: connection, status: 400, body: #"{"error":"Incomplete request body"}"#)
|
||||
}
|
||||
} else {
|
||||
self.receiveFullHTTPRequest(connection: connection, accumulated: buffer)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
self.requestCount += 1
|
||||
await self.processHTTPRequest(request: request, connection: connection)
|
||||
}
|
||||
} else if isComplete {
|
||||
Task { @MainActor in
|
||||
self.sendResponse(connection: connection, status: 400, body: #"{"error":"Bad Request"}"#)
|
||||
}
|
||||
} else {
|
||||
// Keep accumulating
|
||||
self.receiveFullHTTPRequest(connection: connection, accumulated: buffer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func processHTTPRequest(request: HTTPRequest, connection: NWConnection) async {
|
||||
// CORS preflight
|
||||
if request.method == "OPTIONS" {
|
||||
sendResponse(connection: connection, status: 200, body: "", extraHeaders: corsHeaders())
|
||||
return
|
||||
}
|
||||
|
||||
switch (request.method, request.path) {
|
||||
case ("GET", "/health"):
|
||||
sendResponse(connection: connection, status: 200, body: #"{"status":"ok"}"#)
|
||||
|
||||
case ("GET", "/v1/models"):
|
||||
await handleListModels(connection: connection)
|
||||
|
||||
case ("POST", "/v1/chat/completions"):
|
||||
await handleChatCompletions(connection: connection, body: request.body)
|
||||
|
||||
default:
|
||||
sendResponse(connection: connection, status: 404, body: #"{"error":"Not Found"}"#)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - GET /v1/models
|
||||
|
||||
private func handleListModels(connection: NWConnection) async {
|
||||
let models = ModelConfig.availableModels.map { config in
|
||||
APIModelInfo(
|
||||
id: config.repoId,
|
||||
object: "model",
|
||||
created: Int(Date().timeIntervalSince1970),
|
||||
owned_by: "local",
|
||||
context_window: config.contextLength
|
||||
)
|
||||
}
|
||||
let response = APIModelListResponse(object: "list", data: models)
|
||||
|
||||
if let json = try? JSONEncoder().encode(response) {
|
||||
sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - POST /v1/chat/completions
|
||||
|
||||
private func handleChatCompletions(connection: NWConnection, body: Data?) async {
|
||||
guard let body, let request = try? JSONDecoder().decode(APIChatCompletionRequest.self, from: body) else {
|
||||
sendResponse(connection: connection, status: 400, body: #"{"error":"Invalid request body"}"#)
|
||||
return
|
||||
}
|
||||
|
||||
guard let modelManager else {
|
||||
sendResponse(connection: connection, status: 503, body: #"{"error":"No model loaded"}"#)
|
||||
return
|
||||
}
|
||||
|
||||
// Model swapping: if the request specifies a different model, swap to it
|
||||
if let requestedModel = request.model, !requestedModel.isEmpty {
|
||||
if let targetConfig = ModelConfig.resolve(requestedModel) {
|
||||
if modelManager.currentModel?.id != targetConfig.id {
|
||||
print("[APIServer] Swapping model: \(modelManager.currentModel?.repoId ?? "none") -> \(targetConfig.repoId)")
|
||||
cachedSession = nil
|
||||
cachedMessages = nil
|
||||
cachedModelId = nil
|
||||
await modelManager.loadModel(targetConfig)
|
||||
}
|
||||
}
|
||||
// If we can't resolve the model, continue with whatever is loaded
|
||||
}
|
||||
|
||||
guard modelManager.isReady, let container = modelManager.modelContainer else {
|
||||
sendResponse(connection: connection, status: 503, body: #"{"error":"No model loaded"}"#)
|
||||
return
|
||||
}
|
||||
|
||||
let isStream = request.stream ?? false
|
||||
let temperature = request.temperature ?? 0.7
|
||||
let topP = request.top_p ?? 1.0
|
||||
let maxTokens = request.max_tokens ?? 4096
|
||||
let requestId = "chatcmpl-\(UUID().uuidString.prefix(12).lowercased())"
|
||||
let created = Int(Date().timeIntervalSince1970)
|
||||
let modelName = request.model ?? modelManager.currentModel?.repoId ?? "unknown"
|
||||
let contextLength = modelManager.currentModel?.contextLength ?? 0
|
||||
|
||||
// Convert API messages to Chat.Message, extracting images from content parts
|
||||
var chatMessages: [Chat.Message] = []
|
||||
var images: [UserInput.Image] = []
|
||||
let currentModelRepoId = modelManager.currentModel?.repoId ?? modelName
|
||||
|
||||
// Inject tool definitions into the system prompt if tools are provided
|
||||
if let tools = request.tools, !tools.isEmpty {
|
||||
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
|
||||
|
||||
// Check if there's already a system message
|
||||
let hasSystem = request.messages.contains { $0.role == "system" }
|
||||
if hasSystem {
|
||||
// Append tool prompt to existing system message (handled below during conversion)
|
||||
} else {
|
||||
// For Gemma: inject as user message (Gemma doesn't support system role natively)
|
||||
// For Qwen: inject as system message
|
||||
if currentModelRepoId.lowercased().contains("qwen") {
|
||||
chatMessages.append(Chat.Message(role: .system, content: toolSystemPrompt))
|
||||
} else {
|
||||
chatMessages.append(Chat.Message(role: .user, content: toolSystemPrompt))
|
||||
chatMessages.append(Chat.Message(role: .assistant, content: "Understood. I will use the provided tools when appropriate."))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let toolsForInjection = request.tools
|
||||
let isQwen = currentModelRepoId.lowercased().contains("qwen")
|
||||
|
||||
for msg in request.messages {
|
||||
let role: Chat.Message.Role = switch msg.role {
|
||||
case "system": .system
|
||||
case "assistant": .assistant
|
||||
case "tool": .user
|
||||
default: .user
|
||||
}
|
||||
|
||||
var text = msg.content?.textContent ?? ""
|
||||
|
||||
// If this is a system message and tools are provided, append tool definitions
|
||||
if msg.role == "system", let tools = toolsForInjection, !tools.isEmpty {
|
||||
let toolSystemPrompt = ToolPromptBuilder.buildSystemPrompt(tools: tools, modelId: currentModelRepoId)
|
||||
text = text + "\n\n" + toolSystemPrompt
|
||||
}
|
||||
|
||||
// Format tool_call_id responses as tool_output for the model
|
||||
if msg.role == "tool" {
|
||||
if isQwen {
|
||||
// Qwen expects tool results as-is in a user message
|
||||
// (the role is already mapped to .user above)
|
||||
} else {
|
||||
// Gemma expects tool results wrapped in ```tool_output``` blocks
|
||||
text = "```tool_output\n\(text)\n```"
|
||||
}
|
||||
}
|
||||
|
||||
// Format assistant tool_calls back into model-native format
|
||||
if msg.role == "assistant", let toolCalls = msg.tool_calls, !toolCalls.isEmpty {
|
||||
let formattedCalls: String
|
||||
if isQwen {
|
||||
formattedCalls = ToolPromptBuilder.formatQwenToolCalls(toolCalls)
|
||||
} else {
|
||||
formattedCalls = ToolPromptBuilder.formatGemmaToolCalls(toolCalls)
|
||||
}
|
||||
text = (text.isEmpty ? "" : text + "\n") + formattedCalls
|
||||
}
|
||||
|
||||
// Extract base64 images from content parts
|
||||
let imageURLs = msg.content?.imageURLs ?? []
|
||||
var messageImages: [UserInput.Image] = []
|
||||
for urlString in imageURLs {
|
||||
if let image = decodeBase64Image(urlString) {
|
||||
messageImages.append(image)
|
||||
}
|
||||
}
|
||||
|
||||
// Attach images to this specific message
|
||||
chatMessages.append(Chat.Message(role: role, content: text, images: messageImages))
|
||||
images.append(contentsOf: messageImages)
|
||||
}
|
||||
|
||||
// Context window check: estimate token count and reject if over limit
|
||||
if contextLength > 0 {
|
||||
let totalChars = chatMessages.reduce(0) { $0 + $1.content.count }
|
||||
let estimatedTokens = totalChars * 10 / 35 // ~3.5 chars per token
|
||||
let needed = estimatedTokens + maxTokens
|
||||
if needed > contextLength {
|
||||
let errorBody = """
|
||||
{"error":{"message":"This model's maximum context length is \(contextLength) tokens. \
|
||||
However, your messages resulted in approximately \(estimatedTokens) tokens and \
|
||||
\(maxTokens) tokens were requested for the completion (\(needed) total). \
|
||||
Please reduce the length of the messages or completion.",\
|
||||
"type":"invalid_request_error","code":"context_length_exceeded"}}
|
||||
"""
|
||||
sendResponse(connection: connection, status: 400, body: errorBody)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
let generateParams = GenerateParameters(
|
||||
maxTokens: maxTokens,
|
||||
temperature: Float(temperature),
|
||||
topP: Float(topP)
|
||||
)
|
||||
|
||||
// Feed all messages except the last as history, then send the last as the prompt
|
||||
let allButLast = Array(chatMessages.dropLast())
|
||||
let lastMessage = chatMessages.last ?? Chat.Message(role: .user, content: "")
|
||||
|
||||
// KV cache reuse: check if the cached session's history matches
|
||||
let currentModelId = modelManager.currentModel?.id
|
||||
let canReuse = cachedSession != nil
|
||||
&& cachedModelId == currentModelId
|
||||
&& cachedMessages != nil
|
||||
&& messagesMatch(cachedMessages!, allButLast)
|
||||
|
||||
let session: ChatSession
|
||||
if canReuse {
|
||||
print("[APIServer] Reusing cached session (\(allButLast.count) history messages)")
|
||||
session = cachedSession!
|
||||
session.generateParameters = generateParams
|
||||
} else {
|
||||
if cachedSession != nil {
|
||||
print("[APIServer] History diverged, creating fresh session")
|
||||
}
|
||||
if !allButLast.isEmpty {
|
||||
session = ChatSession(
|
||||
container,
|
||||
history: allButLast,
|
||||
generateParameters: generateParams
|
||||
)
|
||||
} else {
|
||||
session = ChatSession(
|
||||
container,
|
||||
generateParameters: generateParams
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract images from the last message only (ChatSession.streamDetails takes images separately)
|
||||
let lastImages = lastMessage.images
|
||||
|
||||
if isStream {
|
||||
await handleStreamingResponse(
|
||||
connection: connection,
|
||||
session: session,
|
||||
prompt: lastMessage.content,
|
||||
images: lastImages,
|
||||
tools: request.tools,
|
||||
requestId: requestId,
|
||||
created: created,
|
||||
modelName: modelName
|
||||
)
|
||||
} else {
|
||||
await handleNonStreamingResponse(
|
||||
connection: connection,
|
||||
session: session,
|
||||
prompt: lastMessage.content,
|
||||
images: lastImages,
|
||||
tools: request.tools,
|
||||
requestId: requestId,
|
||||
created: created,
|
||||
modelName: modelName
|
||||
)
|
||||
}
|
||||
|
||||
// Cache the session for reuse on next request
|
||||
// allButLast + lastMessage (user) + assistant response = new cached history
|
||||
cachedSession = session
|
||||
cachedMessages = chatMessages // full messages including the one just sent
|
||||
cachedModelId = currentModelId
|
||||
}
|
||||
|
||||
/// Decode a base64 data URI (data:image/png;base64,...) into a UserInput.Image.
|
||||
private func decodeBase64Image(_ urlString: String) -> UserInput.Image? {
|
||||
// Handle data URIs: data:image/png;base64,<data>
|
||||
let base64String: String
|
||||
if urlString.hasPrefix("data:") {
|
||||
guard let commaIndex = urlString.firstIndex(of: ",") else { return nil }
|
||||
base64String = String(urlString[urlString.index(after: commaIndex)...])
|
||||
} else {
|
||||
// Could be a plain base64 string
|
||||
base64String = urlString
|
||||
}
|
||||
|
||||
guard let data = Data(base64Encoded: base64String),
|
||||
let nsImage = NSImage(data: data),
|
||||
let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return .ciImage(CIImage(cgImage: cgImage))
|
||||
}
|
||||
|
||||
// MARK: - Non-streaming response
|
||||
|
||||
private func handleNonStreamingResponse(
|
||||
connection: NWConnection,
|
||||
session: ChatSession,
|
||||
prompt: String,
|
||||
images: [UserInput.Image],
|
||||
tools: [APIToolDefinition]?,
|
||||
requestId: String,
|
||||
created: Int,
|
||||
modelName: String
|
||||
) async {
|
||||
do {
|
||||
var fullText = ""
|
||||
var promptTokens = 0
|
||||
var completionTokens = 0
|
||||
var frameworkToolCalls: [MLXLMCommon.ToolCall] = []
|
||||
|
||||
let stream = session.streamDetails(
|
||||
to: prompt,
|
||||
images: images,
|
||||
videos: []
|
||||
)
|
||||
|
||||
for try await generation in stream {
|
||||
switch generation {
|
||||
case .chunk(let text):
|
||||
fullText += text
|
||||
case .info(let info):
|
||||
promptTokens = info.promptTokenCount
|
||||
completionTokens = info.generationTokenCount
|
||||
case .toolCall(let call):
|
||||
frameworkToolCalls.append(call)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tool calls: first check framework-detected ones, then our own text parser
|
||||
var finishReason = "stop"
|
||||
var responseContent: String? = fullText
|
||||
var apiToolCalls: [APIToolCall]? = nil
|
||||
|
||||
if !frameworkToolCalls.isEmpty {
|
||||
// Framework natively detected tool calls (e.g. Qwen)
|
||||
finishReason = "tool_calls"
|
||||
apiToolCalls = frameworkToolCalls.enumerated().map { i, tc in
|
||||
let argsJSON: String
|
||||
let argsDict = tc.function.arguments.mapValues { $0.anyValue }
|
||||
if let data = try? JSONSerialization.data(withJSONObject: argsDict),
|
||||
let str = String(data: data, encoding: .utf8) {
|
||||
argsJSON = str
|
||||
} else {
|
||||
argsJSON = "{}"
|
||||
}
|
||||
let callId = String(format: "call_%d_%08d", i, abs(tc.function.name.hashValue) % 100_000_000)
|
||||
return APIToolCall(
|
||||
index: i,
|
||||
id: callId,
|
||||
type: "function",
|
||||
function: APIFunctionCall(name: tc.function.name, arguments: argsJSON)
|
||||
)
|
||||
}
|
||||
responseContent = fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty ? nil : fullText
|
||||
} else if let tools, !tools.isEmpty {
|
||||
// Try our own text parser (e.g. Gemma tool_code blocks)
|
||||
let (cleanText, parsedCalls) = ToolCallParser.parse(text: fullText, tools: tools)
|
||||
if !parsedCalls.isEmpty {
|
||||
finishReason = "tool_calls"
|
||||
apiToolCalls = parsedCalls.enumerated().map { i, tc in
|
||||
APIToolCall(
|
||||
index: i,
|
||||
id: tc.id,
|
||||
type: "function",
|
||||
function: APIFunctionCall(name: tc.name, arguments: tc.arguments)
|
||||
)
|
||||
}
|
||||
responseContent = cleanText.isEmpty ? nil : cleanText
|
||||
}
|
||||
}
|
||||
|
||||
let response = APIChatCompletionResponse(
|
||||
id: requestId,
|
||||
object: "chat.completion",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [
|
||||
APIChoice(
|
||||
index: 0,
|
||||
message: APIChoiceMessage(
|
||||
role: "assistant",
|
||||
content: responseContent,
|
||||
tool_calls: apiToolCalls
|
||||
),
|
||||
finish_reason: finishReason
|
||||
)
|
||||
],
|
||||
usage: APIUsageInfo(
|
||||
prompt_tokens: promptTokens,
|
||||
completion_tokens: completionTokens,
|
||||
total_tokens: promptTokens + completionTokens
|
||||
)
|
||||
)
|
||||
|
||||
if let json = try? JSONEncoder().encode(response) {
|
||||
sendResponse(connection: connection, status: 200, body: String(data: json, encoding: .utf8) ?? "{}")
|
||||
}
|
||||
} catch {
|
||||
sendResponse(connection: connection, status: 500, body: #"{"error":"\#(error.localizedDescription)"}"#)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Streaming (SSE) response
|
||||
|
||||
private func handleStreamingResponse(
|
||||
connection: NWConnection,
|
||||
session: ChatSession,
|
||||
prompt: String,
|
||||
images: [UserInput.Image],
|
||||
tools: [APIToolDefinition]?,
|
||||
requestId: String,
|
||||
created: Int,
|
||||
modelName: String
|
||||
) async {
|
||||
// Send SSE headers
|
||||
let header = [
|
||||
"HTTP/1.1 200 OK",
|
||||
"Content-Type: text/event-stream",
|
||||
"Cache-Control: no-cache",
|
||||
"Connection: keep-alive",
|
||||
"Access-Control-Allow-Origin: *",
|
||||
"",
|
||||
"",
|
||||
].joined(separator: "\r\n")
|
||||
|
||||
let headerSent = await withCheckedContinuation { continuation in
|
||||
connection.send(content: header.data(using: .utf8), completion: .contentProcessed({ _ in
|
||||
continuation.resume(returning: true)
|
||||
}))
|
||||
}
|
||||
guard headerSent else { return }
|
||||
|
||||
// Send initial role chunk
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: "assistant", content: nil, tool_calls: nil), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
|
||||
// When tools are available, buffer full response to parse tool calls
|
||||
// (otherwise raw tool-call markup leaks into streamed text)
|
||||
let bufferForTools = tools != nil && !(tools?.isEmpty ?? true)
|
||||
|
||||
var promptTokens = 0
|
||||
var completionTokens = 0
|
||||
var fullText = ""
|
||||
var frameworkToolCalls: [MLXLMCommon.ToolCall] = []
|
||||
|
||||
do {
|
||||
let stream = session.streamDetails(
|
||||
to: prompt,
|
||||
images: images,
|
||||
videos: []
|
||||
)
|
||||
|
||||
for try await generation in stream {
|
||||
switch generation {
|
||||
case .chunk(let text):
|
||||
completionTokens += 1
|
||||
fullText += text
|
||||
|
||||
if !bufferForTools {
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: text, tool_calls: nil), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
|
||||
case .info(let info):
|
||||
promptTokens = info.promptTokenCount
|
||||
completionTokens = info.generationTokenCount
|
||||
|
||||
case .toolCall(let call):
|
||||
frameworkToolCalls.append(call)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
let errorEvent = "data: {\"error\":\"\(error.localizedDescription)\"}\n\n"
|
||||
connection.send(content: errorEvent.data(using: .utf8), completion: .contentProcessed({ _ in }))
|
||||
}
|
||||
|
||||
// Post-generation: handle tool calls (framework-detected or text-parsed)
|
||||
var finishReason = "stop"
|
||||
|
||||
if !frameworkToolCalls.isEmpty {
|
||||
// Framework natively detected tool calls (e.g. Qwen)
|
||||
finishReason = "tool_calls"
|
||||
|
||||
// Emit any buffered text content
|
||||
if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
|
||||
// Emit tool call chunks
|
||||
for (i, tc) in frameworkToolCalls.enumerated() {
|
||||
let argsDict = tc.function.arguments.mapValues { $0.anyValue }
|
||||
let argsJSON: String
|
||||
if let data = try? JSONSerialization.data(withJSONObject: argsDict),
|
||||
let str = String(data: data, encoding: .utf8) {
|
||||
argsJSON = str
|
||||
} else {
|
||||
argsJSON = "{}"
|
||||
}
|
||||
let callId = String(format: "call_%d_%08d", i, abs(tc.function.name.hashValue) % 100_000_000)
|
||||
let apiToolCall = APIToolCall(
|
||||
index: i,
|
||||
id: callId,
|
||||
type: "function",
|
||||
function: APIFunctionCall(name: tc.function.name, arguments: argsJSON)
|
||||
)
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: nil, tool_calls: [apiToolCall]), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
} else if bufferForTools {
|
||||
// Text-parsed tool calls (e.g. Gemma tool_code blocks)
|
||||
let (cleanText, parsed) = ToolCallParser.parse(text: fullText, tools: tools)
|
||||
if !parsed.isEmpty {
|
||||
finishReason = "tool_calls"
|
||||
fullText = cleanText
|
||||
}
|
||||
|
||||
// Emit buffered content (cleaned of tool-call markup)
|
||||
if !fullText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: fullText, tool_calls: nil), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
|
||||
// Emit tool call chunks
|
||||
for (i, tc) in parsed.enumerated() {
|
||||
let apiToolCall = APIToolCall(
|
||||
index: i,
|
||||
id: tc.id,
|
||||
type: "function",
|
||||
function: APIFunctionCall(name: tc.name, arguments: tc.arguments)
|
||||
)
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: nil, tool_calls: [apiToolCall]), finish_reason: nil)],
|
||||
usage: nil
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// Final chunk with finish_reason and usage
|
||||
sendSSEEvent(connection: connection, chunk: APIChatCompletionChunk(
|
||||
id: requestId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: modelName,
|
||||
choices: [APIStreamChoice(index: 0, delta: APIDeltaMessage(role: nil, content: nil, tool_calls: nil), finish_reason: finishReason)],
|
||||
usage: APIUsageInfo(
|
||||
prompt_tokens: promptTokens,
|
||||
completion_tokens: completionTokens,
|
||||
total_tokens: promptTokens + completionTokens
|
||||
)
|
||||
))
|
||||
|
||||
// Send [DONE] and close
|
||||
let done = "data: [DONE]\n\n"
|
||||
connection.send(content: done.data(using: .utf8), completion: .contentProcessed({ _ in
|
||||
connection.cancel()
|
||||
}))
|
||||
}
|
||||
|
||||
private func sendSSEEvent(connection: NWConnection, chunk: APIChatCompletionChunk) {
|
||||
guard let json = try? JSONEncoder().encode(chunk),
|
||||
let jsonString = String(data: json, encoding: .utf8) else { return }
|
||||
let event = "data: \(jsonString)\n\n"
|
||||
connection.send(content: event.data(using: .utf8), completion: .contentProcessed({ _ in }))
|
||||
}
|
||||
|
||||
// MARK: - HTTP helpers
|
||||
|
||||
private func sendResponse(
|
||||
connection: NWConnection,
|
||||
status: Int,
|
||||
body: String,
|
||||
extraHeaders: [String] = []
|
||||
) {
|
||||
let statusText = switch status {
|
||||
case 200: "OK"
|
||||
case 400: "Bad Request"
|
||||
case 404: "Not Found"
|
||||
case 500: "Internal Server Error"
|
||||
case 503: "Service Unavailable"
|
||||
default: "Error"
|
||||
}
|
||||
|
||||
var headers = [
|
||||
"HTTP/1.1 \(status) \(statusText)",
|
||||
"Content-Type: application/json",
|
||||
"Access-Control-Allow-Origin: *",
|
||||
"Access-Control-Allow-Methods: GET, POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers: Content-Type, Authorization",
|
||||
"Content-Length: \(body.utf8.count)",
|
||||
]
|
||||
headers.append(contentsOf: extraHeaders)
|
||||
headers.append("")
|
||||
headers.append("")
|
||||
|
||||
let response = headers.joined(separator: "\r\n") + body
|
||||
connection.send(content: response.data(using: .utf8), completion: .contentProcessed({ _ in
|
||||
connection.cancel()
|
||||
}))
|
||||
}
|
||||
|
||||
private func corsHeaders() -> [String] {
|
||||
[
|
||||
"Access-Control-Allow-Methods: GET, POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers: Content-Type, Authorization",
|
||||
"Access-Control-Max-Age: 86400",
|
||||
]
|
||||
}
|
||||
|
||||
/// Check if cached messages are a prefix of new messages (for KV cache reuse).
|
||||
/// The cached messages include the full history from the previous request.
|
||||
/// For cache reuse, all but the last message of the new request must match
|
||||
/// all but the last message of the cached messages (the cached last was the
|
||||
/// previous user prompt, which is now part of the history).
|
||||
private func messagesMatch(_ cached: [Chat.Message], _ newHistory: [Chat.Message]) -> Bool {
|
||||
// The cached messages are the full chatMessages from the previous request.
|
||||
// For the cache to be reusable, the new history (allButLast) must match
|
||||
// exactly what the session has already processed.
|
||||
// After a request, the session has seen: cachedMessages' history + prompt + response.
|
||||
// So on the next request, if newHistory == cachedMessages, the session already
|
||||
// contains all of those turns and we can just send the new last message.
|
||||
guard cached.count == newHistory.count else { return false }
|
||||
for (a, b) in zip(cached, newHistory) {
|
||||
if a.role != b.role || a.content != b.content { return false }
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - HTTP request parser
|
||||
|
||||
private struct HTTPRequest {
|
||||
let method: String
|
||||
let path: String
|
||||
let headers: [String: String]
|
||||
let body: Data?
|
||||
|
||||
/// Parse raw HTTP data into a structured request.
|
||||
/// Uses raw Data operations to correctly handle binary body content.
|
||||
static func parse(_ data: Data) -> HTTPRequest? {
|
||||
// Find \r\n\r\n boundary between headers and body
|
||||
let separator: [UInt8] = [0x0D, 0x0A, 0x0D, 0x0A] // \r\n\r\n
|
||||
guard let separatorRange = data.firstRange(of: Data(separator)) else {
|
||||
// No complete header yet — might need more data
|
||||
// But if data is large enough, treat as malformed
|
||||
return data.count > 65536 ? nil : nil
|
||||
}
|
||||
|
||||
let headerData = data[data.startIndex..<separatorRange.lowerBound]
|
||||
guard let headerString = String(data: headerData, encoding: .utf8) else { return nil }
|
||||
|
||||
let lines = headerString.components(separatedBy: "\r\n")
|
||||
guard let requestLine = lines.first else { return nil }
|
||||
|
||||
let parts = requestLine.split(separator: " ", maxSplits: 2)
|
||||
guard parts.count >= 2 else { return nil }
|
||||
|
||||
let method = String(parts[0])
|
||||
let fullPath = String(parts[1])
|
||||
let path = fullPath.components(separatedBy: "?").first ?? fullPath
|
||||
|
||||
var headers: [String: String] = [:]
|
||||
for line in lines.dropFirst() {
|
||||
let kv = line.split(separator: ":", maxSplits: 1)
|
||||
if kv.count == 2 {
|
||||
headers[String(kv[0]).trimmingCharacters(in: .whitespaces).lowercased()] =
|
||||
String(kv[1]).trimmingCharacters(in: .whitespaces)
|
||||
}
|
||||
}
|
||||
|
||||
// Body is everything after \r\n\r\n
|
||||
let bodyStart = separatorRange.upperBound
|
||||
let body: Data? = bodyStart < data.endIndex ? data[bodyStart..<data.endIndex] : nil
|
||||
|
||||
return HTTPRequest(method: method, path: path, headers: headers, body: body)
|
||||
}
|
||||
}
|
||||
190
MLXServer/Server/ToolCallParser.swift
Normal file
190
MLXServer/Server/ToolCallParser.swift
Normal file
@@ -0,0 +1,190 @@
|
||||
import Foundation
|
||||
|
||||
/// Parses tool calls from model output text.
|
||||
/// Supports both Gemma's ```tool_code``` blocks and Qwen's <tool_call> XML tags.
|
||||
enum ToolCallParser {
|
||||
|
||||
struct ParsedToolCall {
|
||||
let id: String
|
||||
let name: String
|
||||
let arguments: String // JSON string
|
||||
}
|
||||
|
||||
/// Parse tool calls from model output. Returns (cleanText, toolCalls).
|
||||
static func parse(text: String, tools: [APIToolDefinition]?) -> (String, [ParsedToolCall]) {
|
||||
// Try Qwen-style first (<tool_call> tags)
|
||||
let (qwenClean, qwenCalls) = parseQwen(text: text)
|
||||
if !qwenCalls.isEmpty {
|
||||
return (qwenClean, qwenCalls)
|
||||
}
|
||||
|
||||
// Try Gemma-style (```tool_code``` blocks)
|
||||
let (gemmaClean, gemmaCalls) = parseGemma(text: text, tools: tools)
|
||||
if !gemmaCalls.isEmpty {
|
||||
return (gemmaClean, gemmaCalls)
|
||||
}
|
||||
|
||||
return (text, [])
|
||||
}
|
||||
|
||||
// MARK: - Gemma: ```tool_code``` blocks
|
||||
|
||||
/// Parse Gemma's tool_code blocks: ```tool_code\nfunc_name(arg="value")\n```
|
||||
private static func parseGemma(text: String, tools: [APIToolDefinition]?) -> (String, [ParsedToolCall]) {
|
||||
let pattern = #"```tool_code\s*(.*?)\s*```"#
|
||||
guard let regex = try? NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else {
|
||||
return (text, [])
|
||||
}
|
||||
|
||||
let nsText = text as NSString
|
||||
let matches = regex.matches(in: text, range: NSRange(location: 0, length: nsText.length))
|
||||
guard !matches.isEmpty else { return (text, []) }
|
||||
|
||||
var toolCalls: [ParsedToolCall] = []
|
||||
|
||||
// Build tool definitions map for parameter inference
|
||||
var toolDefs: [String: [String]] = [:]
|
||||
if let tools {
|
||||
for tool in tools {
|
||||
let paramNames = tool.function.parameters?["properties"]?.value as? [String: Any]
|
||||
toolDefs[tool.function.name] = paramNames.map { Array($0.keys).sorted() } ?? []
|
||||
}
|
||||
}
|
||||
|
||||
for (i, match) in matches.enumerated() {
|
||||
let callStr = nsText.substring(with: match.range(at: 1)).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
if let (name, args) = parsePythonCall(callStr, toolDefs: toolDefs) {
|
||||
let argsJSON = (try? JSONSerialization.data(withJSONObject: args))
|
||||
.flatMap { String(data: $0, encoding: .utf8) } ?? "{}"
|
||||
let callId = String(format: "call_%d_%08d", i, abs(callStr.hashValue) % 100_000_000)
|
||||
toolCalls.append(ParsedToolCall(id: callId, name: name, arguments: argsJSON))
|
||||
}
|
||||
}
|
||||
|
||||
// Remove tool_code blocks from text
|
||||
let cleanText = regex.stringByReplacingMatches(
|
||||
in: text, range: NSRange(location: 0, length: nsText.length),
|
||||
withTemplate: ""
|
||||
).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
return (cleanText, toolCalls)
|
||||
}
|
||||
|
||||
/// Parse a Python-style function call: func_name(arg1="value", arg2=42)
|
||||
private static func parsePythonCall(_ callStr: String, toolDefs: [String: [String]]) -> (String, [String: Any])? {
|
||||
// Match: func_name(args...)
|
||||
let pattern = #"^(\w+)\s*\((.*)\)\s*$"#
|
||||
guard let regex = try? NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
let nsCall = callStr as NSString
|
||||
let match = regex.firstMatch(in: callStr, range: NSRange(location: 0, length: nsCall.length))
|
||||
guard let match else {
|
||||
// Shell-style: "func_name arg1"
|
||||
let parts = callStr.split(separator: " ", maxSplits: 1)
|
||||
guard let name = parts.first.map(String.init), !name.isEmpty else { return nil }
|
||||
if parts.count > 1 {
|
||||
let argValue = String(parts[1]).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let paramNames = toolDefs[name] ?? []
|
||||
let key = paramNames.first ?? "arg0"
|
||||
return (name, [key: argValue])
|
||||
}
|
||||
return (name, [:])
|
||||
}
|
||||
|
||||
let name = nsCall.substring(with: match.range(at: 1))
|
||||
let argsStr = nsCall.substring(with: match.range(at: 2)).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
if argsStr.isEmpty {
|
||||
return (name, [:])
|
||||
}
|
||||
|
||||
// Parse keyword arguments: key="value", key2=42
|
||||
var args: [String: Any] = [:]
|
||||
let kwPattern = #"(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|'(?:[^'\\]|\\.)*'|[^,]+)"#
|
||||
if let kwRegex = try? NSRegularExpression(pattern: kwPattern, options: []) {
|
||||
let kwMatches = kwRegex.matches(in: argsStr, range: NSRange(location: 0, length: (argsStr as NSString).length))
|
||||
for kwMatch in kwMatches {
|
||||
let key = (argsStr as NSString).substring(with: kwMatch.range(at: 1))
|
||||
var val = (argsStr as NSString).substring(with: kwMatch.range(at: 2)).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
// Strip quotes
|
||||
if (val.hasPrefix("\"") && val.hasSuffix("\"")) || (val.hasPrefix("'") && val.hasSuffix("'")) {
|
||||
val = String(val.dropFirst().dropLast())
|
||||
}
|
||||
// Try to parse as number/bool
|
||||
if let intVal = Int(val) { args[key] = intVal }
|
||||
else if let doubleVal = Double(val) { args[key] = doubleVal }
|
||||
else if val == "True" || val == "true" { args[key] = true }
|
||||
else if val == "False" || val == "false" { args[key] = false }
|
||||
else { args[key] = val }
|
||||
}
|
||||
}
|
||||
|
||||
// If no keyword args found, try positional
|
||||
if args.isEmpty {
|
||||
let paramNames = toolDefs[name] ?? []
|
||||
// Try splitting by comma for positional args
|
||||
let positionals = argsStr.split(separator: ",").map {
|
||||
$0.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
}
|
||||
for (i, pos) in positionals.enumerated() {
|
||||
var val = pos
|
||||
if (val.hasPrefix("\"") && val.hasSuffix("\"")) || (val.hasPrefix("'") && val.hasSuffix("'")) {
|
||||
val = String(val.dropFirst().dropLast())
|
||||
}
|
||||
let key = i < paramNames.count ? paramNames[i] : "arg\(i)"
|
||||
args[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
return (name, args)
|
||||
}
|
||||
|
||||
// MARK: - Qwen: <tool_call> XML tags
|
||||
|
||||
/// Parse Qwen's tool_call tags: <tool_call>{"name":"func","arguments":{...}}</tool_call>
|
||||
private static func parseQwen(text: String) -> (String, [ParsedToolCall]) {
|
||||
let pattern = #"<tool_call>\s*(.*?)\s*</tool_call>"#
|
||||
guard let regex = try? NSRegularExpression(pattern: pattern, options: .dotMatchesLineSeparators) else {
|
||||
return (text, [])
|
||||
}
|
||||
|
||||
let nsText = text as NSString
|
||||
let matches = regex.matches(in: text, range: NSRange(location: 0, length: nsText.length))
|
||||
guard !matches.isEmpty else { return (text, []) }
|
||||
|
||||
var toolCalls: [ParsedToolCall] = []
|
||||
|
||||
for (i, match) in matches.enumerated() {
|
||||
let jsonStr = nsText.substring(with: match.range(at: 1)).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
guard let data = jsonStr.data(using: .utf8),
|
||||
let obj = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
|
||||
let name = obj["name"] as? String else {
|
||||
continue
|
||||
}
|
||||
|
||||
var argsJSON = "{}"
|
||||
if let args = obj["arguments"] {
|
||||
if let argsDict = args as? [String: Any],
|
||||
let argsData = try? JSONSerialization.data(withJSONObject: argsDict) {
|
||||
argsJSON = String(data: argsData, encoding: .utf8) ?? "{}"
|
||||
} else if let argsStr = args as? String {
|
||||
argsJSON = argsStr
|
||||
}
|
||||
}
|
||||
|
||||
let callId = String(format: "call_%d_%08d", i, abs(jsonStr.hashValue) % 100_000_000)
|
||||
toolCalls.append(ParsedToolCall(id: callId, name: name, arguments: argsJSON))
|
||||
}
|
||||
|
||||
let cleanText = regex.stringByReplacingMatches(
|
||||
in: text, range: NSRange(location: 0, length: nsText.length),
|
||||
withTemplate: ""
|
||||
).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
return (cleanText, toolCalls)
|
||||
}
|
||||
}
|
||||
199
MLXServer/Server/ToolPromptBuilder.swift
Normal file
199
MLXServer/Server/ToolPromptBuilder.swift
Normal file
@@ -0,0 +1,199 @@
|
||||
import Foundation
|
||||
|
||||
/// Builds model-specific system prompts that inform the model about available tools.
|
||||
/// Mirrors the Python server's `_build_tool_system_prompt()` and `_build_qwen_tool_system_prompt()`.
|
||||
enum ToolPromptBuilder {
|
||||
|
||||
/// Build a tool system prompt appropriate for the current model.
|
||||
/// - Parameters:
|
||||
/// - tools: OpenAI-format tool definitions
|
||||
/// - modelId: The model's repo ID (to determine format)
|
||||
/// - Returns: A system prompt string describing the available tools
|
||||
static func buildSystemPrompt(tools: [APIToolDefinition], modelId: String) -> String {
|
||||
if modelId.lowercased().contains("qwen") {
|
||||
return buildQwenToolPrompt(tools: tools)
|
||||
} else {
|
||||
return buildGemmaToolPrompt(tools: tools)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Gemma (tool_code format)
|
||||
|
||||
/// Build the tool system prompt using Google's Gemma 3 tool_code convention.
|
||||
private static func buildGemmaToolPrompt(tools: [APIToolDefinition]) -> String {
|
||||
let funcDefs = tools.map { toolToPythonSignature($0.function) }
|
||||
let functionsBlock = funcDefs.joined(separator: "\n\n")
|
||||
|
||||
return """
|
||||
At each turn, if you decide to invoke any of the function(s), \
|
||||
it should be wrapped with ```tool_code```. \
|
||||
The python methods described below are imported and available, \
|
||||
you can only use defined methods. \
|
||||
The generated code should be readable and efficient. \
|
||||
The response to a method will be wrapped in ```tool_output``` \
|
||||
use it to call more tools or generate a helpful, friendly response.
|
||||
|
||||
\(functionsBlock)
|
||||
"""
|
||||
}
|
||||
|
||||
/// Convert an OpenAI function definition to a Python function signature with docstring.
|
||||
private static func toolToPythonSignature(_ func: APIFunctionDefinition) -> String {
|
||||
let name = `func`.name
|
||||
let desc = `func`.description ?? ""
|
||||
let properties = `func`.parameters?["properties"]?.value as? [String: Any] ?? [:]
|
||||
let requiredArr = `func`.parameters?["required"]?.value as? [String] ?? []
|
||||
let required = Set(requiredArr)
|
||||
|
||||
var paramParts: [String] = []
|
||||
var docArgs: [String] = []
|
||||
|
||||
// Sort keys for deterministic output
|
||||
for pname in properties.keys.sorted() {
|
||||
guard let pinfo = properties[pname] as? [String: Any] else { continue }
|
||||
let ptype = jsonTypeToPython(pinfo["type"] as? String ?? "str")
|
||||
let pdesc = pinfo["description"] as? String ?? ""
|
||||
|
||||
if required.contains(pname) {
|
||||
paramParts.append("\(pname): \(ptype)")
|
||||
} else {
|
||||
let defaultVal = jsonTypeDefault(pinfo["type"] as? String ?? "str")
|
||||
paramParts.append("\(pname): \(ptype) = \(defaultVal)")
|
||||
}
|
||||
docArgs.append(pdesc.isEmpty ? " \(pname)" : " \(pname): \(pdesc)")
|
||||
}
|
||||
|
||||
let sig = "def \(name)(\(paramParts.joined(separator: ", "))):"
|
||||
var docLines = [" \"\"\"\(desc)"]
|
||||
if !docArgs.isEmpty {
|
||||
docLines.append("")
|
||||
docLines.append(" Args:")
|
||||
docLines.append(contentsOf: docArgs)
|
||||
}
|
||||
docLines.append(" \"\"\"")
|
||||
|
||||
return sig + "\n" + docLines.joined(separator: "\n")
|
||||
}
|
||||
|
||||
private static func jsonTypeToPython(_ type: String) -> String {
|
||||
switch type {
|
||||
case "string": return "str"
|
||||
case "integer": return "int"
|
||||
case "number": return "float"
|
||||
case "boolean": return "bool"
|
||||
case "array": return "list"
|
||||
case "object": return "dict"
|
||||
default: return "str"
|
||||
}
|
||||
}
|
||||
|
||||
private static func jsonTypeDefault(_ type: String) -> String {
|
||||
switch type {
|
||||
case "string": return "None"
|
||||
case "integer": return "0"
|
||||
case "number": return "0.0"
|
||||
case "boolean": return "False"
|
||||
case "array": return "[]"
|
||||
case "object": return "{}"
|
||||
default: return "None"
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Qwen (<tool_call> format)
|
||||
|
||||
/// Build the tool system prompt for Qwen3 using its native JSON format.
|
||||
private static func buildQwenToolPrompt(tools: [APIToolDefinition]) -> String {
|
||||
var toolDescs: [[String: Any]] = []
|
||||
for tool in tools {
|
||||
var funcDict: [String: Any] = [
|
||||
"name": tool.function.name,
|
||||
"description": tool.function.description ?? "",
|
||||
]
|
||||
if let params = tool.function.parameters {
|
||||
funcDict["parameters"] = params.mapValues(\.value)
|
||||
}
|
||||
toolDescs.append([
|
||||
"type": "function",
|
||||
"function": funcDict,
|
||||
])
|
||||
}
|
||||
|
||||
let toolsJSON: String
|
||||
if let data = try? JSONSerialization.data(withJSONObject: toolDescs, options: [.prettyPrinted, .sortedKeys]),
|
||||
let str = String(data: data, encoding: .utf8) {
|
||||
toolsJSON = str
|
||||
} else {
|
||||
toolsJSON = "[]"
|
||||
}
|
||||
|
||||
return """
|
||||
# Tools
|
||||
|
||||
You are a helpful assistant with access to the following tools. \
|
||||
Use them when appropriate by responding with a JSON tool call.
|
||||
|
||||
## Available Tools
|
||||
|
||||
\(toolsJSON)
|
||||
|
||||
## Tool Call Format
|
||||
|
||||
When you need to call a tool, respond with:
|
||||
<tool_call>
|
||||
{"name": "<function_name>", "arguments": {<args>}}
|
||||
</tool_call>
|
||||
"""
|
||||
}
|
||||
|
||||
// MARK: - Format tool calls back into model-specific format for prompt history
|
||||
|
||||
/// Format OpenAI-style tool calls back into Gemma's tool_code blocks for prompt history.
|
||||
static func formatGemmaToolCalls(_ toolCalls: [APIToolCall]) -> String {
|
||||
var parts: [String] = []
|
||||
for tc in toolCalls {
|
||||
let name = tc.function.name
|
||||
let argsStr = tc.function.arguments
|
||||
if let data = argsStr.data(using: .utf8),
|
||||
let args = try? JSONSerialization.jsonObject(with: data) as? [String: Any] {
|
||||
let argParts = args.keys.sorted().map { key -> String in
|
||||
let val = args[key]!
|
||||
return "\(key)=\(pythonRepr(val))"
|
||||
}
|
||||
let callStr = "\(name)(\(argParts.joined(separator: ", ")))"
|
||||
parts.append("```tool_code\n\(callStr)\n```")
|
||||
} else {
|
||||
parts.append("```tool_code\n\(name)()\n```")
|
||||
}
|
||||
}
|
||||
return parts.joined(separator: "\n")
|
||||
}
|
||||
|
||||
/// Format OpenAI-style tool calls back into Qwen's <tool_call> tags for prompt history.
|
||||
static func formatQwenToolCalls(_ toolCalls: [APIToolCall]) -> String {
|
||||
var parts: [String] = []
|
||||
for tc in toolCalls {
|
||||
let name = tc.function.name
|
||||
let argsStr = tc.function.arguments
|
||||
var callObj: [String: Any] = ["name": name]
|
||||
if let data = argsStr.data(using: .utf8),
|
||||
let args = try? JSONSerialization.jsonObject(with: data) {
|
||||
callObj["arguments"] = args
|
||||
}
|
||||
if let data = try? JSONSerialization.data(withJSONObject: callObj),
|
||||
let str = String(data: data, encoding: .utf8) {
|
||||
parts.append("<tool_call>\n\(str)\n</tool_call>")
|
||||
}
|
||||
}
|
||||
return parts.joined(separator: "\n")
|
||||
}
|
||||
|
||||
private static func pythonRepr(_ value: Any) -> String {
|
||||
switch value {
|
||||
case let s as String: return "\"\(s)\""
|
||||
case let i as Int: return "\(i)"
|
||||
case let d as Double: return "\(d)"
|
||||
case let b as Bool: return b ? "True" : "False"
|
||||
default: return "\"\(value)\""
|
||||
}
|
||||
}
|
||||
}
|
||||
47
MLXServer/Utilities/LocalModelResolver.swift
Normal file
47
MLXServer/Utilities/LocalModelResolver.swift
Normal file
@@ -0,0 +1,47 @@
|
||||
import Foundation
|
||||
|
||||
/// Resolves HuggingFace model repos to local snapshot directories,
|
||||
/// matching the cache layout used by Python's `huggingface_hub`.
|
||||
///
|
||||
/// Cache structure:
|
||||
/// ~/.cache/huggingface/hub/models--{org}--{name}/snapshots/{hash}/
|
||||
enum LocalModelResolver {
|
||||
|
||||
/// The standard HuggingFace cache directory used by Python's `huggingface_hub`.
|
||||
private static let cacheBase: URL = {
|
||||
FileManager.default.homeDirectoryForCurrentUser
|
||||
.appendingPathComponent(".cache/huggingface/hub", isDirectory: true)
|
||||
}()
|
||||
|
||||
/// Resolve a HuggingFace repo ID (e.g. "mlx-community/gemma-3-4b-it-4bit")
|
||||
/// to its local snapshot directory, if it exists.
|
||||
///
|
||||
/// Returns `nil` if the model hasn't been downloaded yet.
|
||||
static func resolve(repoId: String) -> URL? {
|
||||
// Convert "mlx-community/gemma-3-4b-it-4bit" → "models--mlx-community--gemma-3-4b-it-4bit"
|
||||
let dirName = "models--" + repoId.replacingOccurrences(of: "/", with: "--")
|
||||
let snapshotsDir = cacheBase
|
||||
.appendingPathComponent(dirName, isDirectory: true)
|
||||
.appendingPathComponent("snapshots", isDirectory: true)
|
||||
|
||||
// Find the first (usually only) snapshot hash directory
|
||||
guard let contents = try? FileManager.default.contentsOfDirectory(
|
||||
at: snapshotsDir,
|
||||
includingPropertiesForKeys: [.isDirectoryKey],
|
||||
options: [.skipsHiddenFiles]
|
||||
) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the most recent snapshot (last alphabetically = latest hash)
|
||||
return contents
|
||||
.filter { (try? $0.resourceValues(forKeys: [.isDirectoryKey]).isDirectory) == true }
|
||||
.sorted(by: { $0.lastPathComponent < $1.lastPathComponent })
|
||||
.last
|
||||
}
|
||||
|
||||
/// Check if a model is available locally.
|
||||
static func isAvailable(repoId: String) -> Bool {
|
||||
resolve(repoId: repoId) != nil
|
||||
}
|
||||
}
|
||||
42
MLXServer/Utilities/Preferences.swift
Normal file
42
MLXServer/Utilities/Preferences.swift
Normal file
@@ -0,0 +1,42 @@
|
||||
import Foundation
|
||||
|
||||
/// Persisted app preferences via UserDefaults.
|
||||
enum Preferences {
|
||||
nonisolated(unsafe) private static let defaults = UserDefaults.standard
|
||||
|
||||
// MARK: - Last used model
|
||||
|
||||
private static let lastModelKey = "lastModelId"
|
||||
|
||||
static var lastModelId: String? {
|
||||
get { defaults.string(forKey: lastModelKey) }
|
||||
set { defaults.set(newValue, forKey: lastModelKey) }
|
||||
}
|
||||
|
||||
// MARK: - System prompt
|
||||
|
||||
private static let systemPromptKey = "systemPrompt"
|
||||
|
||||
static var systemPrompt: String {
|
||||
get { defaults.string(forKey: systemPromptKey) ?? "" }
|
||||
set { defaults.set(newValue, forKey: systemPromptKey) }
|
||||
}
|
||||
|
||||
// MARK: - API server
|
||||
|
||||
private static let apiPortKey = "apiPort"
|
||||
private static let apiAutoStartKey = "apiAutoStart"
|
||||
|
||||
static var apiPort: Int {
|
||||
get {
|
||||
let val = defaults.integer(forKey: apiPortKey)
|
||||
return val > 0 ? val : 1234
|
||||
}
|
||||
set { defaults.set(newValue, forKey: apiPortKey) }
|
||||
}
|
||||
|
||||
static var apiAutoStart: Bool {
|
||||
get { defaults.bool(forKey: apiAutoStartKey) }
|
||||
set { defaults.set(newValue, forKey: apiAutoStartKey) }
|
||||
}
|
||||
}
|
||||
158
MLXServer/ViewModels/ChatViewModel.swift
Normal file
158
MLXServer/ViewModels/ChatViewModel.swift
Normal file
@@ -0,0 +1,158 @@
|
||||
import AppKit
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXLMCommon
|
||||
import MLXVLM
|
||||
|
||||
/// Drives the chat UI: sending messages, streaming responses, managing images.
|
||||
@Observable
|
||||
@MainActor
|
||||
final class ChatViewModel {
|
||||
var conversation = Conversation()
|
||||
var inputText = ""
|
||||
var attachedImages: [NSImage] = []
|
||||
var isGenerating = false
|
||||
var tokensPerSecond: Double = 0
|
||||
var promptTokens: Int = 0
|
||||
var generationTokens: Int = 0
|
||||
|
||||
private var generationTask: Task<Void, Never>?
|
||||
private var chatSession: ChatSession?
|
||||
|
||||
let modelManager: ModelManager
|
||||
let apiServer = APIServer()
|
||||
|
||||
init(modelManager: ModelManager) {
|
||||
self.modelManager = modelManager
|
||||
}
|
||||
|
||||
/// Ensure a ChatSession exists for the current model.
|
||||
private func ensureSession() {
|
||||
guard let container = modelManager.modelContainer else { return }
|
||||
if chatSession == nil {
|
||||
let systemPrompt = Preferences.systemPrompt
|
||||
chatSession = ChatSession(
|
||||
container,
|
||||
instructions: systemPrompt.isEmpty ? nil : systemPrompt,
|
||||
generateParameters: GenerateParameters(temperature: 0.7)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func send() {
|
||||
let text = inputText.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !text.isEmpty, modelManager.isReady else { return }
|
||||
|
||||
ensureSession()
|
||||
guard let session = chatSession else { return }
|
||||
|
||||
let images = attachedImages
|
||||
inputText = ""
|
||||
attachedImages = []
|
||||
|
||||
conversation.addUserMessage(text, images: images)
|
||||
let assistantIndex = conversation.addAssistantMessage()
|
||||
|
||||
isGenerating = true
|
||||
tokensPerSecond = 0
|
||||
promptTokens = 0
|
||||
generationTokens = 0
|
||||
|
||||
// Convert NSImages to UserInput.Image
|
||||
let inputImages: [UserInput.Image] = images.compactMap { nsImage in
|
||||
guard let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
|
||||
return nil
|
||||
}
|
||||
return .ciImage(CIImage(cgImage: cgImage))
|
||||
}
|
||||
|
||||
generationTask = Task {
|
||||
do {
|
||||
let stream = session.streamDetails(
|
||||
to: text,
|
||||
images: inputImages,
|
||||
videos: []
|
||||
)
|
||||
|
||||
var tokenCount = 0
|
||||
let startTime = Date()
|
||||
|
||||
for try await generation in stream {
|
||||
if Task.isCancelled { break }
|
||||
|
||||
switch generation {
|
||||
case .chunk(let text):
|
||||
conversation.appendToMessage(at: assistantIndex, chunk: text)
|
||||
tokenCount += 1
|
||||
let elapsed = Date().timeIntervalSince(startTime)
|
||||
if elapsed > 0 {
|
||||
tokensPerSecond = Double(tokenCount) / elapsed
|
||||
}
|
||||
generationTokens = tokenCount
|
||||
|
||||
case .info(let info):
|
||||
promptTokens = info.promptTokenCount
|
||||
if info.tokensPerSecond > 0 {
|
||||
tokensPerSecond = info.tokensPerSecond
|
||||
}
|
||||
|
||||
case .toolCall:
|
||||
break
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
if !Task.isCancelled {
|
||||
conversation.appendToMessage(
|
||||
at: assistantIndex,
|
||||
chunk: "\n\n[Error: \(error.localizedDescription)]"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
conversation.finalizeMessage(at: assistantIndex)
|
||||
isGenerating = false
|
||||
generationTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
func stop() {
|
||||
generationTask?.cancel()
|
||||
generationTask = nil
|
||||
isGenerating = false
|
||||
|
||||
if let last = conversation.messages.indices.last,
|
||||
conversation.messages[last].isStreaming {
|
||||
conversation.finalizeMessage(at: last)
|
||||
}
|
||||
}
|
||||
|
||||
func attachImage(_ image: NSImage) {
|
||||
attachedImages.append(image)
|
||||
}
|
||||
|
||||
func removeImage(at index: Int) {
|
||||
guard attachedImages.indices.contains(index) else { return }
|
||||
attachedImages.remove(at: index)
|
||||
}
|
||||
|
||||
func newConversation() {
|
||||
stop()
|
||||
conversation.clear()
|
||||
resetSession()
|
||||
}
|
||||
|
||||
/// Reset the chat session (e.g. on model switch or new conversation).
|
||||
func resetSession() {
|
||||
chatSession = nil
|
||||
}
|
||||
|
||||
// MARK: - API Server
|
||||
|
||||
func startAPIServer() {
|
||||
apiServer.start(modelManager: modelManager, port: Preferences.apiPort)
|
||||
}
|
||||
|
||||
func stopAPIServer() {
|
||||
apiServer.stop()
|
||||
}
|
||||
}
|
||||
71
MLXServer/ViewModels/ModelManager.swift
Normal file
71
MLXServer/ViewModels/ModelManager.swift
Normal file
@@ -0,0 +1,71 @@
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXLMCommon
|
||||
import MLXVLM
|
||||
|
||||
/// Manages model loading, switching, and generation.
|
||||
@Observable
|
||||
@MainActor
|
||||
final class ModelManager {
|
||||
var currentModel: ModelConfig?
|
||||
var modelContainer: ModelContainer?
|
||||
var isLoading = false
|
||||
var downloadProgress: Double = 0
|
||||
var loadingModelName: String = ""
|
||||
var errorMessage: String?
|
||||
|
||||
/// Load a model, unloading the current one first.
|
||||
/// Prefers the local snapshot from ~/.cache/huggingface/hub/ (shared with the Python server).
|
||||
/// Only downloads if the model isn't cached locally.
|
||||
func loadModel(_ config: ModelConfig) async {
|
||||
if currentModel?.id == config.id && modelContainer != nil {
|
||||
return // already loaded
|
||||
}
|
||||
|
||||
unloadModel()
|
||||
isLoading = true
|
||||
downloadProgress = 0
|
||||
loadingModelName = config.displayName
|
||||
errorMessage = nil
|
||||
|
||||
do {
|
||||
let container: ModelContainer
|
||||
let progressHandler: @Sendable (Progress) -> Void = { progress in
|
||||
Task { @MainActor in
|
||||
self.downloadProgress = progress.fractionCompleted
|
||||
}
|
||||
}
|
||||
|
||||
let configuration: ModelConfiguration
|
||||
if let localDir = LocalModelResolver.resolve(repoId: config.repoId) {
|
||||
configuration = ModelConfiguration(directory: localDir)
|
||||
} else {
|
||||
configuration = config.modelConfiguration
|
||||
}
|
||||
|
||||
container = try await VLMModelFactory.shared.loadContainer(
|
||||
configuration: configuration,
|
||||
progressHandler: progressHandler
|
||||
)
|
||||
|
||||
self.modelContainer = container
|
||||
self.currentModel = config
|
||||
} catch {
|
||||
self.errorMessage = "Failed to load model: \(error.localizedDescription)"
|
||||
}
|
||||
|
||||
isLoading = false
|
||||
}
|
||||
|
||||
/// Unload the current model and free GPU memory.
|
||||
func unloadModel() {
|
||||
modelContainer = nil
|
||||
currentModel = nil
|
||||
MLX.GPU.clearCache()
|
||||
}
|
||||
|
||||
/// Whether a model is ready for generation.
|
||||
var isReady: Bool {
|
||||
modelContainer != nil && !isLoading
|
||||
}
|
||||
}
|
||||
128
MLXServer/Views/ChatInputView.swift
Normal file
128
MLXServer/Views/ChatInputView.swift
Normal file
@@ -0,0 +1,128 @@
|
||||
import SwiftUI
|
||||
import UniformTypeIdentifiers
|
||||
|
||||
struct ChatInputView: View {
|
||||
@Bindable var viewModel: ChatViewModel
|
||||
|
||||
var body: some View {
|
||||
VStack(spacing: 8) {
|
||||
// Image preview strip
|
||||
if !viewModel.attachedImages.isEmpty {
|
||||
ScrollView(.horizontal, showsIndicators: false) {
|
||||
HStack(spacing: 8) {
|
||||
ForEach(Array(viewModel.attachedImages.enumerated()), id: \.offset) { index, image in
|
||||
ZStack(alignment: .topTrailing) {
|
||||
Image(nsImage: image)
|
||||
.resizable()
|
||||
.aspectRatio(contentMode: .fill)
|
||||
.frame(width: 60, height: 60)
|
||||
.clipShape(RoundedRectangle(cornerRadius: 8))
|
||||
|
||||
Button {
|
||||
viewModel.removeImage(at: index)
|
||||
} label: {
|
||||
Image(systemName: "xmark.circle.fill")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.white)
|
||||
.background(Circle().fill(.black.opacity(0.5)))
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.offset(x: 4, y: -4)
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 12)
|
||||
}
|
||||
}
|
||||
|
||||
// Input row
|
||||
HStack(alignment: .bottom, spacing: 8) {
|
||||
// Image attach button
|
||||
Button {
|
||||
pickImage()
|
||||
} label: {
|
||||
Image(systemName: "photo.badge.plus")
|
||||
.font(.title3)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.disabled(!viewModel.modelManager.isReady)
|
||||
|
||||
// Text field
|
||||
TextField("Message…", text: $viewModel.inputText, axis: .vertical)
|
||||
.textFieldStyle(.plain)
|
||||
.lineLimit(1...8)
|
||||
.onSubmit {
|
||||
if !NSEvent.modifierFlags.contains(.shift) {
|
||||
viewModel.send()
|
||||
}
|
||||
}
|
||||
|
||||
// Send or Stop button
|
||||
if viewModel.isGenerating {
|
||||
Button {
|
||||
viewModel.stop()
|
||||
} label: {
|
||||
Image(systemName: "stop.circle.fill")
|
||||
.font(.title2)
|
||||
.foregroundStyle(.red)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.keyboardShortcut(.escape, modifiers: [])
|
||||
} else {
|
||||
Button {
|
||||
viewModel.send()
|
||||
} label: {
|
||||
Image(systemName: "arrow.up.circle.fill")
|
||||
.font(.title2)
|
||||
.foregroundStyle(Color.accentColor)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.disabled(viewModel.inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty || !viewModel.modelManager.isReady)
|
||||
.keyboardShortcut(.return, modifiers: .command)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 10)
|
||||
}
|
||||
.padding(.top, 4)
|
||||
.onDrop(of: [.image], isTargeted: nil) { providers in
|
||||
for provider in providers {
|
||||
_ = provider.loadObject(ofClass: NSImage.self) { image, _ in
|
||||
if let image = image as? NSImage {
|
||||
Task { @MainActor in
|
||||
viewModel.attachImage(image)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Cmd+V paste for images
|
||||
.onPasteCommand(of: [.image, .png, .jpeg, .tiff]) { providers in
|
||||
for provider in providers {
|
||||
_ = provider.loadObject(ofClass: NSImage.self) { image, _ in
|
||||
if let image = image as? NSImage {
|
||||
Task { @MainActor in
|
||||
viewModel.attachImage(image)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func pickImage() {
|
||||
let panel = NSOpenPanel()
|
||||
panel.allowedContentTypes = [.image]
|
||||
panel.allowsMultipleSelection = true
|
||||
panel.canChooseDirectories = false
|
||||
|
||||
if panel.runModal() == .OK {
|
||||
for url in panel.urls {
|
||||
if let image = NSImage(contentsOf: url) {
|
||||
viewModel.attachImage(image)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
105
MLXServer/Views/ChatMessagesView.swift
Normal file
105
MLXServer/Views/ChatMessagesView.swift
Normal file
@@ -0,0 +1,105 @@
|
||||
import MarkdownUI
|
||||
import SwiftUI
|
||||
|
||||
struct ChatMessagesView: View {
|
||||
let viewModel: ChatViewModel
|
||||
|
||||
var body: some View {
|
||||
ScrollViewReader { proxy in
|
||||
ScrollView {
|
||||
LazyVStack(alignment: .leading, spacing: 12) {
|
||||
if viewModel.conversation.messages.isEmpty {
|
||||
emptyState
|
||||
} else {
|
||||
ForEach(viewModel.conversation.messages) { message in
|
||||
MessageBubbleView(message: message)
|
||||
.id(message.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding()
|
||||
}
|
||||
.onChange(of: viewModel.conversation.messages.last?.content) {
|
||||
scrollToBottom(proxy: proxy)
|
||||
}
|
||||
.onChange(of: viewModel.conversation.messages.count) {
|
||||
scrollToBottom(proxy: proxy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var emptyState: some View {
|
||||
VStack(spacing: 8) {
|
||||
Spacer()
|
||||
Image(systemName: "message")
|
||||
.font(.system(size: 40))
|
||||
.foregroundStyle(.secondary)
|
||||
Text("Start a conversation")
|
||||
.font(.title3)
|
||||
.foregroundStyle(.secondary)
|
||||
if viewModel.modelManager.currentModel == nil {
|
||||
Text("Select a model from the toolbar to begin")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.tertiary)
|
||||
}
|
||||
Spacer()
|
||||
}
|
||||
.frame(maxWidth: .infinity, minHeight: 300)
|
||||
}
|
||||
|
||||
private func scrollToBottom(proxy: ScrollViewProxy) {
|
||||
if let lastId = viewModel.conversation.messages.last?.id {
|
||||
withAnimation(.easeOut(duration: 0.2)) {
|
||||
proxy.scrollTo(lastId, anchor: .bottom)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MessageBubbleView: View {
|
||||
let message: ChatMessage
|
||||
|
||||
var body: some View {
|
||||
HStack {
|
||||
if message.role == .user { Spacer(minLength: 60) }
|
||||
|
||||
VStack(alignment: message.role == .user ? .trailing : .leading, spacing: 6) {
|
||||
// Show attached images
|
||||
if !message.images.isEmpty {
|
||||
HStack(spacing: 4) {
|
||||
ForEach(Array(message.images.enumerated()), id: \.offset) { _, image in
|
||||
Image(nsImage: image)
|
||||
.resizable()
|
||||
.aspectRatio(contentMode: .fill)
|
||||
.frame(width: 80, height: 80)
|
||||
.clipShape(RoundedRectangle(cornerRadius: 8))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message content
|
||||
if !message.content.isEmpty || message.isStreaming {
|
||||
Group {
|
||||
if message.role == .assistant {
|
||||
Markdown(message.content + (message.isStreaming ? " ●" : ""))
|
||||
.textSelection(.enabled)
|
||||
} else {
|
||||
Text(message.content)
|
||||
.textSelection(.enabled)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 8)
|
||||
.background(
|
||||
message.role == .user
|
||||
? Color.accentColor.opacity(0.15)
|
||||
: Color.secondary.opacity(0.1)
|
||||
)
|
||||
.clipShape(RoundedRectangle(cornerRadius: 12))
|
||||
}
|
||||
}
|
||||
|
||||
if message.role == .assistant { Spacer(minLength: 60) }
|
||||
}
|
||||
}
|
||||
}
|
||||
32
MLXServer/Views/ModelPickerView.swift
Normal file
32
MLXServer/Views/ModelPickerView.swift
Normal file
@@ -0,0 +1,32 @@
|
||||
import SwiftUI
|
||||
|
||||
struct ModelPickerView: View {
|
||||
@Environment(ModelManager.self) private var modelManager
|
||||
|
||||
var body: some View {
|
||||
HStack(spacing: 8) {
|
||||
Picker("Model", selection: selectedModelBinding) {
|
||||
ForEach(ModelConfig.availableModels) { config in
|
||||
Label(
|
||||
config.displayName,
|
||||
systemImage: config.isLocal ? "checkmark.circle.fill" : "arrow.down.circle"
|
||||
).tag(config.id)
|
||||
}
|
||||
}
|
||||
.frame(width: 160)
|
||||
.disabled(modelManager.isLoading)
|
||||
}
|
||||
}
|
||||
|
||||
private var selectedModelBinding: Binding<String> {
|
||||
Binding(
|
||||
get: { modelManager.currentModel?.id ?? ModelConfig.default.id },
|
||||
set: { newId in
|
||||
guard let config = ModelConfig.availableModels.first(where: { $0.id == newId }) else { return }
|
||||
Task {
|
||||
await modelManager.loadModel(config)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
44
MLXServer/Views/SettingsView.swift
Normal file
44
MLXServer/Views/SettingsView.swift
Normal file
@@ -0,0 +1,44 @@
|
||||
import SwiftUI
|
||||
|
||||
struct SettingsView: View {
|
||||
@State private var systemPrompt: String = Preferences.systemPrompt
|
||||
@State private var apiPort: String = String(Preferences.apiPort)
|
||||
@State private var apiAutoStart: Bool = Preferences.apiAutoStart
|
||||
|
||||
var body: some View {
|
||||
Form {
|
||||
Section("System Prompt") {
|
||||
TextEditor(text: $systemPrompt)
|
||||
.font(.body.monospaced())
|
||||
.frame(minHeight: 80)
|
||||
.onChange(of: systemPrompt) {
|
||||
Preferences.systemPrompt = systemPrompt
|
||||
}
|
||||
|
||||
Text("Applied to new conversations. Leave empty for no system prompt.")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
Section("API Server") {
|
||||
HStack {
|
||||
Text("Port")
|
||||
TextField("1234", text: $apiPort)
|
||||
.frame(width: 80)
|
||||
.onChange(of: apiPort) {
|
||||
if let port = Int(apiPort), port > 0, port < 65536 {
|
||||
Preferences.apiPort = port
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Toggle("Start API server automatically", isOn: $apiAutoStart)
|
||||
.onChange(of: apiAutoStart) {
|
||||
Preferences.apiAutoStart = apiAutoStart
|
||||
}
|
||||
}
|
||||
}
|
||||
.formStyle(.grouped)
|
||||
.frame(width: 450, height: 300)
|
||||
}
|
||||
}
|
||||
77
MLXServer/Views/StatusBarView.swift
Normal file
77
MLXServer/Views/StatusBarView.swift
Normal file
@@ -0,0 +1,77 @@
|
||||
import MLX
|
||||
import SwiftUI
|
||||
|
||||
struct StatusBarView: View {
|
||||
let viewModel: ChatViewModel
|
||||
@Environment(ModelManager.self) private var modelManager
|
||||
|
||||
var body: some View {
|
||||
HStack(spacing: 16) {
|
||||
// Model info
|
||||
if modelManager.isLoading {
|
||||
let pct = Int(modelManager.downloadProgress * 100)
|
||||
Text("Loading \(modelManager.loadingModelName)… \(pct)%")
|
||||
.font(.caption.monospacedDigit())
|
||||
.foregroundStyle(.orange)
|
||||
} else if let model = modelManager.currentModel {
|
||||
Label(model.displayName, systemImage: "cpu")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
|
||||
Text("\(model.contextLength / 1000)k ctx")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.tertiary)
|
||||
} else {
|
||||
Text("No model loaded")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
// GPU memory
|
||||
let activeMB = Double(MLX.GPU.activeMemory) / 1_048_576
|
||||
if activeMB > 0 {
|
||||
Text(String(format: "GPU: %.0f MB", activeMB))
|
||||
.font(.caption.monospacedDigit())
|
||||
.foregroundStyle(.tertiary)
|
||||
}
|
||||
|
||||
// Token generation speed
|
||||
if viewModel.isGenerating {
|
||||
Text(String(format: "%.1f tok/s", viewModel.tokensPerSecond))
|
||||
.font(.caption.monospacedDigit())
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
// Token counts
|
||||
if viewModel.promptTokens > 0 || viewModel.generationTokens > 0 {
|
||||
Text("\(viewModel.promptTokens)→\(viewModel.generationTokens) tok")
|
||||
.font(.caption.monospacedDigit())
|
||||
.foregroundStyle(.tertiary)
|
||||
}
|
||||
|
||||
// API server status
|
||||
if viewModel.apiServer.isRunning {
|
||||
Label("API :\(viewModel.apiServer.port)", systemImage: "network")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.green)
|
||||
} else {
|
||||
Label("API off", systemImage: "network.slash")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.tertiary)
|
||||
}
|
||||
|
||||
// Error
|
||||
if let error = modelManager.errorMessage {
|
||||
Label(error, systemImage: "exclamationmark.triangle")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.red)
|
||||
.lineLimit(1)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 4)
|
||||
.background(.bar)
|
||||
}
|
||||
}
|
||||
105
README.md
105
README.md
@@ -1,63 +1,57 @@
|
||||
# MLX Server
|
||||
|
||||
OpenAI-compatible API server for running local LLMs on Apple Silicon via [MLX](https://github.com/ml-explore/mlx). Supports vision and tool use with automatic model swapping — only one model is loaded in memory at a time, switched on demand based on the request's `model` field.
|
||||
Native macOS app for running local LLMs on Apple Silicon via [MLX](https://github.com/ml-explore/mlx). Built with SwiftUI, it provides both a **chat UI** and an embedded **OpenAI-compatible API server**. Supports vision and tool use with automatic model swapping.
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Alias | Model | Context | Capabilities |
|
||||
|-------|-------|---------|-------------|
|
||||
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | 128k | Vision, tool use (`tool_code` blocks) |
|
||||
| `gemma3n` | `mlx-community/gemma-3n-E4B-it-4bit` | 32k | Vision/audio/video, tool use (`tool_code` blocks), ~1.5x faster |
|
||||
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | 256k | Vision, tool use (`<tool_call>` tags) |
|
||||
|
||||
## Quick Start
|
||||
|
||||
Requires macOS 15+, Xcode 16.4+, and `xcodegen` (`brew install xcodegen`).
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
|
||||
# Start with Gemma 3 (default)
|
||||
./run.sh
|
||||
|
||||
# Start with Qwen3
|
||||
./run.sh qwen
|
||||
|
||||
# Or directly
|
||||
python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
|
||||
./build.sh # Debug build
|
||||
open "build/Debug/MLX Server.app"
|
||||
```
|
||||
|
||||
The server starts at `http://127.0.0.1:1234`.
|
||||
## App Features
|
||||
|
||||
## API
|
||||
- **Chat interface** with markdown rendering, image attachments (file picker, drag & drop, clipboard paste)
|
||||
- **Model picker** in toolbar with local/download status indicators
|
||||
- **Streaming responses** with live token display
|
||||
- **Status bar** showing model name, context window, tokens/sec, token counts, GPU memory, API server status
|
||||
- **Keyboard shortcuts**: `Cmd+N` (new chat), `Cmd+Return` (send), `Escape` (stop), `Cmd+1/2/3` (switch models)
|
||||
- **Settings** (`Cmd+,`): system prompt, API port, API auto-start
|
||||
|
||||
Standard OpenAI-compatible endpoints:
|
||||
## API Server
|
||||
|
||||
- `GET /v1/models` — lists all available models with `context_window` sizes
|
||||
The embedded API server (toggle in toolbar) runs on port 1234 by default. Standard OpenAI-compatible endpoints:
|
||||
|
||||
- `GET /v1/models` — lists available models with `context_window` sizes
|
||||
- `POST /v1/chat/completions` — chat completions (streaming and non-streaming)
|
||||
- `GET /health` — health check
|
||||
|
||||
### Model Swapping
|
||||
|
||||
Send any available model ID (or alias) in the `model` field. If it differs from the currently loaded model, the server unloads the old one and loads the new one automatically:
|
||||
Send any model ID or alias in the `model` field. If it differs from the currently loaded model, the server swaps automatically:
|
||||
|
||||
```bash
|
||||
# Uses Gemma
|
||||
curl http://localhost:1234/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "mlx-community/gemma-3-4b-it-4bit", "messages": [{"role": "user", "content": "Hello"}]}'
|
||||
|
||||
# Swaps to Qwen
|
||||
curl http://localhost:1234/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "mlx-community/Qwen3-VL-4B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}]}'
|
||||
-d '{"model": "gemma", "messages": [{"role": "user", "content": "Hello"}]}'
|
||||
```
|
||||
|
||||
### Vision
|
||||
|
||||
Pass images as base64 data URIs or URLs in the `image_url` content part:
|
||||
Pass images as base64 data URIs in the `image_url` content part:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "mlx-community/gemma-3-4b-it-4bit",
|
||||
"model": "gemma",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -68,37 +62,50 @@ Pass images as base64 data URIs or URLs in the `image_url` content part:
|
||||
}
|
||||
```
|
||||
|
||||
### Context Window Management
|
||||
|
||||
Each model's context window is read from its HuggingFace config (`max_position_embeddings`) and reported in `/v1/models` via the `context_window` field. Clients can use this to manage conversation length proactively.
|
||||
|
||||
If a request exceeds the context window, the server:
|
||||
|
||||
1. Automatically summarizes older messages (keeping system messages and the last 6 messages intact)
|
||||
2. Retries with the compressed conversation
|
||||
3. Returns an OpenAI-compatible `context_length_exceeded` error if it still doesn't fit
|
||||
|
||||
### Tool Use
|
||||
|
||||
Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting and parses tool calls from the output automatically.
|
||||
|
||||
## Installation
|
||||
|
||||
Requires Python 3.11+ and Apple Silicon.
|
||||
|
||||
```bash
|
||||
uv pip install -e "."
|
||||
```
|
||||
Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting (Gemma `tool_code` blocks, Qwen `<tool_call>` XML tags) and parses tool calls from output automatically. When tools are present during streaming, output is buffered to strip tool-call markup before sending to the client.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
mlx_server/
|
||||
main.py — FastAPI server, endpoints, CLI entrypoint
|
||||
engine.py — Model loading, prompt building, generation (mlx_vlm)
|
||||
models.py — Pydantic models for OpenAI API types
|
||||
MLXServer/
|
||||
├── MLXServerApp.swift — App entry point, GPU cache config
|
||||
├── ContentView.swift — Main layout, toolbar, keyboard shortcuts
|
||||
├── Models/
|
||||
│ ├── ModelConfig.swift — Model definitions, alias/repoId resolution
|
||||
│ └── ChatMessage.swift — Chat message data model
|
||||
├── ViewModels/
|
||||
│ ├── ModelManager.swift — Model loading/switching via VLMModelFactory
|
||||
│ └── ChatViewModel.swift — Chat state, ChatSession, API server lifecycle
|
||||
├── Views/
|
||||
│ ├── ModelPickerView.swift — Toolbar model selector
|
||||
│ ├── ChatMessagesView.swift — Scrollable message list with markdown
|
||||
│ ├── ChatInputView.swift — Text input + image attach
|
||||
│ ├── StatusBarView.swift — Model info, tok/s, GPU memory, API status
|
||||
│ └── SettingsView.swift — System prompt + API settings
|
||||
├── Server/
|
||||
│ ├── APIServer.swift — NWListener HTTP server, SSE streaming, KV cache reuse
|
||||
│ ├── APIModels.swift — OpenAI-compatible Codable structs
|
||||
│ ├── ToolCallParser.swift — Parses tool calls from model output
|
||||
│ └── ToolPromptBuilder.swift — Model-specific tool prompt formatting
|
||||
└── Utilities/
|
||||
├── LocalModelResolver.swift — Offline-first HuggingFace cache resolution
|
||||
└── Preferences.swift — UserDefaults wrapper
|
||||
|
||||
project.yml — xcodegen project spec (dependencies, settings, deployment target)
|
||||
build.sh — One-command build script (xcodegen + xcodebuild)
|
||||
```
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- Uses `mlx-swift-lm` (`MLXVLM` / `VLMModelFactory`) for inference — supports both text and vision in a single model load
|
||||
- **Offline-first**: `LocalModelResolver` checks `~/.cache/huggingface/hub/` for locally-cached snapshots before downloading
|
||||
- **KV cache reuse** across API requests — reuses `ChatSession` when conversation history prefix matches
|
||||
- HTTP server built on `Network.framework` (`NWListener`) — no third-party server dependencies
|
||||
- Model-specific prompt formatting: Gemma uses `tool_code` blocks, Qwen uses `<tool_call>` XML tags
|
||||
- GPU cache limit set to 20 MB; cache cleared on model unload
|
||||
|
||||
## Design Notes
|
||||
|
||||
- Uses `mlx_vlm` (not `mlx_lm`) as the backend — supports both text and vision in a single model load
|
||||
|
||||
41
build.sh
Executable file
41
build.sh
Executable file
@@ -0,0 +1,41 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
PROJECT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
BUILD_DIR="$PROJECT_DIR/build"
|
||||
CONFIG="${1:-Debug}"
|
||||
APP_NAME="MLX Server"
|
||||
|
||||
echo "==> Building $APP_NAME ($CONFIG)"
|
||||
|
||||
# Regenerate Xcode project from project.yml (picks up any new/removed files)
|
||||
if command -v xcodegen &>/dev/null; then
|
||||
xcodegen generate --spec "$PROJECT_DIR/project.yml" --project "$PROJECT_DIR" 2>&1 | grep -v '^$'
|
||||
fi
|
||||
|
||||
# Build — filter to show only app source compilation, errors, and result
|
||||
xcodebuild \
|
||||
-project "$PROJECT_DIR/MLXServer.xcodeproj" \
|
||||
-scheme MLXServer \
|
||||
-destination 'platform=macOS' \
|
||||
-configuration "$CONFIG" \
|
||||
SYMROOT="$BUILD_DIR" \
|
||||
build 2>&1 | \
|
||||
grep -E "(CompileSwift .* 'MLXServer'|error:|warning:.*MLXServer/|BUILD )" | \
|
||||
sed "s|.*CompileSwift normal arm64 Compiling ||" | \
|
||||
sed "s| (in target 'MLXServer' from project 'MLXServer')||"
|
||||
|
||||
APP_PATH="$BUILD_DIR/$CONFIG/$APP_NAME.app"
|
||||
|
||||
if [ -d "$APP_PATH" ] && [ -f "$APP_PATH/Contents/MacOS/$APP_NAME" ]; then
|
||||
echo ""
|
||||
echo "==> Build succeeded"
|
||||
echo " $APP_PATH"
|
||||
echo ""
|
||||
echo " Run: open \"$APP_PATH\""
|
||||
echo " Or: \"$APP_PATH/Contents/MacOS/$APP_NAME\""
|
||||
else
|
||||
echo ""
|
||||
echo "==> Build failed"
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,3 +0,0 @@
|
||||
from mlx_server.main import main
|
||||
|
||||
main()
|
||||
1120
mlx_server/engine.py
1120
mlx_server/engine.py
File diff suppressed because it is too large
Load Diff
@@ -1,409 +0,0 @@
|
||||
"""OpenAI-compatible API server for local LLMs (Gemma 3, Qwen3, …) via MLX."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from .engine import DEFAULT_MODEL, InferenceEngine, ModelManager
|
||||
from .models import (
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
Choice,
|
||||
ChoiceMessage,
|
||||
DeltaMessage,
|
||||
ModelInfo,
|
||||
ModelListResponse,
|
||||
StreamChoice,
|
||||
ToolCall,
|
||||
FunctionCall,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for local LLMs on Apple Silicon")
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
manager: ModelManager | None = None
|
||||
|
||||
# Number of recent messages to always preserve when summarizing
|
||||
_KEEP_RECENT = 6
|
||||
|
||||
|
||||
def get_engine(requested_model: str | None = None):
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||
return manager.get_engine(requested_model)
|
||||
|
||||
|
||||
def _make_id() -> str:
|
||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Context window management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _manage_context(
|
||||
e: InferenceEngine,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None,
|
||||
max_tokens: int,
|
||||
) -> list[dict]:
|
||||
"""Check if messages fit in the context window; summarize if needed.
|
||||
|
||||
Returns the (possibly summarized) message list. Raises HTTPException
|
||||
with an OpenAI-compatible error if the conversation cannot fit.
|
||||
"""
|
||||
context_length = e.context_length
|
||||
if context_length <= 0:
|
||||
return messages # unknown context size, skip check
|
||||
|
||||
prompt, _ = e.build_prompt(messages, tools)
|
||||
prompt_tokens = e.count_tokens(prompt)
|
||||
available = context_length - max_tokens
|
||||
|
||||
if prompt_tokens <= available:
|
||||
return messages
|
||||
|
||||
# --- Need to summarize ---
|
||||
logger.info(
|
||||
"Context window pressure: %d prompt tokens + %d max_tokens = %d "
|
||||
"(limit %d). Attempting summarization.",
|
||||
prompt_tokens, max_tokens, prompt_tokens + max_tokens, context_length,
|
||||
)
|
||||
|
||||
# Split messages: system | middle (summarizable) | recent (kept)
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
non_system = [m for m in messages if m.get("role") != "system"]
|
||||
|
||||
if len(non_system) <= _KEEP_RECENT:
|
||||
_raise_context_exceeded(prompt_tokens, max_tokens, context_length)
|
||||
|
||||
recent = non_system[-_KEEP_RECENT:]
|
||||
middle = non_system[:-_KEEP_RECENT]
|
||||
|
||||
# Generate summary of the middle messages
|
||||
summary_text = e.summarize_messages(middle)
|
||||
|
||||
summary_msg = {
|
||||
"role": "user",
|
||||
"content": f"[Summary of earlier conversation]\n{summary_text}",
|
||||
}
|
||||
ack_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Understood, I have the context from our earlier conversation.",
|
||||
}
|
||||
new_messages = system_msgs + [summary_msg, ack_msg] + recent
|
||||
|
||||
# Re-check fit
|
||||
new_prompt, _ = e.build_prompt(new_messages, tools)
|
||||
new_prompt_tokens = e.count_tokens(new_prompt)
|
||||
|
||||
if new_prompt_tokens + max_tokens > context_length:
|
||||
logger.warning(
|
||||
"Still over context limit after summarization: %d + %d = %d (limit %d)",
|
||||
new_prompt_tokens, max_tokens, new_prompt_tokens + max_tokens, context_length,
|
||||
)
|
||||
_raise_context_exceeded(new_prompt_tokens, max_tokens, context_length)
|
||||
|
||||
logger.info(
|
||||
"Summarization reduced prompt from %d to %d tokens (saved %d).",
|
||||
prompt_tokens, new_prompt_tokens, prompt_tokens - new_prompt_tokens,
|
||||
)
|
||||
return new_messages
|
||||
|
||||
|
||||
def _raise_context_exceeded(prompt_tokens: int, max_tokens: int, context_length: int):
|
||||
"""Raise an OpenAI-compatible context_length_exceeded error."""
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": {
|
||||
"message": (
|
||||
f"This model's maximum context length is {context_length} tokens. "
|
||||
f"However, your messages resulted in {prompt_tokens} tokens and "
|
||||
f"{max_tokens} tokens were requested for the completion "
|
||||
f"({prompt_tokens + max_tokens} total). "
|
||||
f"Please reduce the length of the messages or completion."
|
||||
),
|
||||
"type": "invalid_request_error",
|
||||
"code": "context_length_exceeded",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models() -> ModelListResponse:
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||
return ModelListResponse(
|
||||
data=[
|
||||
ModelInfo(
|
||||
id=model_id,
|
||||
context_window=manager.get_context_length(model_id),
|
||||
)
|
||||
for model_id in manager.available_models
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest):
|
||||
e = get_engine(request.model)
|
||||
|
||||
# Convert pydantic messages to dicts
|
||||
messages = [m.model_dump(exclude_none=True) for m in request.messages]
|
||||
tools = None
|
||||
if request.tools:
|
||||
tools = [t.model_dump(exclude_none=True) for t in request.tools]
|
||||
|
||||
stop = request.stop
|
||||
if isinstance(stop, str):
|
||||
stop = [stop]
|
||||
|
||||
temperature = request.temperature if request.temperature is not None else 0.7
|
||||
top_p = request.top_p if request.top_p is not None else 0.9
|
||||
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
|
||||
|
||||
# Context window management: summarize if needed, error if impossible
|
||||
messages = _manage_context(e, messages, tools, max_tokens)
|
||||
|
||||
prompt, images = e.build_prompt(messages, tools)
|
||||
|
||||
if request.stream:
|
||||
return EventSourceResponse(
|
||||
_stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Non-streaming
|
||||
text, prompt_tokens, completion_tokens = e.generate(
|
||||
prompt=prompt,
|
||||
images=images or None,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
# Check for tool calls in the response
|
||||
finish_reason = "stop"
|
||||
tool_calls_parsed = None
|
||||
if tools:
|
||||
clean_text, parsed = e.parse_tool_calls(text, tools)
|
||||
if parsed:
|
||||
tool_calls_parsed = [
|
||||
ToolCall(
|
||||
index=i,
|
||||
id=tc["id"],
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
for i, tc in enumerate(parsed)
|
||||
]
|
||||
text = clean_text if clean_text else None
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=_make_id(),
|
||||
model=request.model,
|
||||
choices=[
|
||||
Choice(
|
||||
message=ChoiceMessage(
|
||||
role="assistant",
|
||||
content=text if not tool_calls_parsed else (text or None),
|
||||
tool_calls=tool_calls_parsed,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _stream_response(
|
||||
e,
|
||||
prompt: str,
|
||||
images: list[str] | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
stop: list[str] | None,
|
||||
tools: list[dict] | None,
|
||||
model_name: str,
|
||||
):
|
||||
request_id = _make_id()
|
||||
created = int(time.time())
|
||||
|
||||
# Send initial chunk with role
|
||||
initial_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[StreamChoice(delta=DeltaMessage(role="assistant"))],
|
||||
)
|
||||
yield {"data": initial_chunk.model_dump_json()}
|
||||
|
||||
full_text = ""
|
||||
prompt_tokens = 0
|
||||
gen_tokens = 0
|
||||
|
||||
# When tools are available we must buffer the full response before
|
||||
# emitting content — otherwise raw tool-call markup (```tool_code```
|
||||
# or <tool_call>) leaks into the streamed text.
|
||||
buffer_for_tools = bool(tools)
|
||||
|
||||
for token_text, is_final, pt, gt in e.stream_generate(
|
||||
prompt=prompt,
|
||||
images=images or None,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
):
|
||||
prompt_tokens = pt
|
||||
gen_tokens = gt
|
||||
full_text += token_text
|
||||
|
||||
if not buffer_for_tools and not is_final and token_text:
|
||||
chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[StreamChoice(delta=DeltaMessage(content=token_text))],
|
||||
)
|
||||
yield {"data": chunk.model_dump_json()}
|
||||
|
||||
# --- Post-generation: parse tool calls and emit clean content ------
|
||||
finish_reason = "stop"
|
||||
tool_calls_parsed = []
|
||||
|
||||
if tools:
|
||||
clean_text, parsed = e.parse_tool_calls(full_text, tools)
|
||||
if parsed:
|
||||
finish_reason = "tool_calls"
|
||||
tool_calls_parsed = parsed
|
||||
full_text = clean_text or ""
|
||||
|
||||
# Emit buffered content (when tools were present, this is the cleaned
|
||||
# text with tool-call markup stripped out)
|
||||
if buffer_for_tools and full_text.strip():
|
||||
content_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[StreamChoice(delta=DeltaMessage(content=full_text))],
|
||||
)
|
||||
yield {"data": content_chunk.model_dump_json()}
|
||||
|
||||
# Emit tool call chunks
|
||||
for i, tc in enumerate(tool_calls_parsed):
|
||||
tc_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[
|
||||
StreamChoice(
|
||||
delta=DeltaMessage(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
index=i,
|
||||
id=tc["id"],
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
yield {"data": tc_chunk.model_dump_json()}
|
||||
|
||||
# Final chunk with finish reason and usage
|
||||
final_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[StreamChoice(delta=DeltaMessage(), finish_reason=finish_reason)],
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=gen_tokens,
|
||||
total_tokens=prompt_tokens + gen_tokens,
|
||||
),
|
||||
)
|
||||
yield {"data": final_chunk.model_dump_json()}
|
||||
yield {"data": "[DONE]"}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health / utility
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Entrypoint
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MLX Server – OpenAI-compatible API")
|
||||
parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="HuggingFace model path")
|
||||
parser.add_argument("--host", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=1234)
|
||||
parser.add_argument("--log-level", type=str, default="info")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level.upper()),
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
global manager
|
||||
manager = ModelManager(default_model=args.model)
|
||||
manager.preload(args.model)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,145 +0,0 @@
|
||||
"""OpenAI API compatible request/response models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# --- Request models ---
|
||||
|
||||
|
||||
class FunctionDefinition(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionDefinition
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
name: str
|
||||
arguments: str # JSON string
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
index: int = 0
|
||||
id: str
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
class ContentPartText(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
class ImageURL(BaseModel):
|
||||
url: str # Can be a URL or base64 data URI
|
||||
detail: str | None = None
|
||||
|
||||
|
||||
class ContentPartImage(BaseModel):
|
||||
type: Literal["image_url"] = "image_url"
|
||||
image_url: ImageURL
|
||||
|
||||
|
||||
ContentPart = ContentPartText | ContentPartImage
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant", "tool"]
|
||||
content: str | list[ContentPart] | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
tool_call_id: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str = "gemma-3-4b-it"
|
||||
messages: list[ChatMessage]
|
||||
temperature: float | None = 0.7
|
||||
top_p: float | None = 0.9
|
||||
max_tokens: int | None = 4096
|
||||
stream: bool = False
|
||||
stop: str | list[str] | None = None
|
||||
tools: list[ToolDefinition] | None = None
|
||||
tool_choice: str | dict | None = None
|
||||
frequency_penalty: float | None = None
|
||||
presence_penalty: float | None = None
|
||||
n: int | None = 1
|
||||
|
||||
|
||||
# --- Response models ---
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class ChoiceMessage(BaseModel):
|
||||
role: str = "assistant"
|
||||
content: str | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
index: int = 0
|
||||
message: ChoiceMessage
|
||||
finish_reason: str | None = "stop"
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[Choice]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
# --- Streaming response models ---
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: str | None = None
|
||||
content: str | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
|
||||
|
||||
class StreamChoice(BaseModel):
|
||||
index: int = 0
|
||||
delta: DeltaMessage
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionChunk(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[StreamChoice]
|
||||
usage: UsageInfo | None = None
|
||||
|
||||
|
||||
# --- Model listing ---
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "local"
|
||||
context_window: int | None = None
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
object: str = "list"
|
||||
data: list[ModelInfo]
|
||||
45
project.yml
Normal file
45
project.yml
Normal file
@@ -0,0 +1,45 @@
|
||||
name: MLXServer
|
||||
options:
|
||||
bundleIdPrefix: com.mlxserver
|
||||
deploymentTarget:
|
||||
macOS: "15.0"
|
||||
xcodeVersion: "16.4"
|
||||
minimumXcodeGenVersion: "2.40"
|
||||
|
||||
packages:
|
||||
mlx-swift-lm:
|
||||
url: https://github.com/ml-explore/mlx-swift-lm
|
||||
branch: main
|
||||
MarkdownUI:
|
||||
url: https://github.com/gonzalezreal/swift-markdown-ui
|
||||
from: "2.4.0"
|
||||
|
||||
targets:
|
||||
MLXServer:
|
||||
type: application
|
||||
platform: macOS
|
||||
sources:
|
||||
- MLXServer
|
||||
settings:
|
||||
base:
|
||||
PRODUCT_BUNDLE_IDENTIFIER: com.mlxserver.app
|
||||
PRODUCT_NAME: MLX Server
|
||||
MARKETING_VERSION: "1.0.0"
|
||||
CURRENT_PROJECT_VERSION: "1"
|
||||
SWIFT_VERSION: "6.0"
|
||||
MACOSX_DEPLOYMENT_TARGET: "15.0"
|
||||
GENERATE_INFOPLIST_FILE: "YES"
|
||||
INFOPLIST_KEY_LSApplicationCategoryType: "public.app-category.developer-tools"
|
||||
INFOPLIST_KEY_NSHumanReadableCopyright: ""
|
||||
CODE_SIGN_ENTITLEMENTS: MLXServer/MLXServer.entitlements
|
||||
CODE_SIGN_IDENTITY: "-"
|
||||
CODE_SIGN_ALLOW_ENTITLEMENTS_MODIFICATION: "YES"
|
||||
dependencies:
|
||||
- package: mlx-swift-lm
|
||||
product: MLXLLM
|
||||
- package: mlx-swift-lm
|
||||
product: MLXVLM
|
||||
- package: mlx-swift-lm
|
||||
product: MLXLMCommon
|
||||
- package: MarkdownUI
|
||||
product: MarkdownUI
|
||||
@@ -1,20 +0,0 @@
|
||||
[project]
|
||||
name = "mlx-server"
|
||||
version = "0.1.0"
|
||||
description = "OpenAI-compatible API server for Gemma 3 4B via MLX"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.30.0",
|
||||
"mlx>=0.22.0",
|
||||
"mlx-lm>=0.22.0",
|
||||
"mlx-vlm>=0.1.18",
|
||||
"pydantic>=2.0.0",
|
||||
"sse-starlette>=2.0.0",
|
||||
"pillow>=10.0.0",
|
||||
"httpx>=0.27.0",
|
||||
"torchvision>=0.20.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
mlx-server = "mlx_server.main:main"
|
||||
47
run.sh
47
run.sh
@@ -1,47 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
|
||||
# --- Model selection ---
|
||||
# Usage: ./run.sh [gemma|gemma3n|qwen]
|
||||
# Or set MODEL env var directly for a custom model.
|
||||
|
||||
MODEL_CHOICE="${1:-gemma}"
|
||||
|
||||
if [[ -z "${MODEL:-}" ]]; then
|
||||
case "$MODEL_CHOICE" in
|
||||
gemma)
|
||||
MODEL="mlx-community/gemma-3-4b-it-4bit"
|
||||
;;
|
||||
gemma3n)
|
||||
MODEL="mlx-community/gemma-3n-E4B-it-4bit"
|
||||
;;
|
||||
qwen)
|
||||
MODEL="mlx-community/Qwen3-VL-4B-Instruct-4bit"
|
||||
;;
|
||||
*)
|
||||
echo "Unknown model choice: $MODEL_CHOICE"
|
||||
echo "Usage: $0 [gemma|gemma3n|qwen]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
HOST="${HOST:-127.0.0.1}"
|
||||
PORT="${PORT:-1234}"
|
||||
|
||||
echo "Starting MLX Server..."
|
||||
echo " Model: $MODEL"
|
||||
echo " Endpoint: http://$HOST:$PORT"
|
||||
echo ""
|
||||
|
||||
exec python -m mlx_server.main \
|
||||
--model "$MODEL" \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
"${@:2}"
|
||||
296
test_server.py
296
test_server.py
@@ -1,296 +0,0 @@
|
||||
"""Test script for MLX Server – exercises chat, streaming, vision, and tool use."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
|
||||
import httpx
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
BASE_URL = "http://127.0.0.1:1234/v1"
|
||||
MODEL = "mlx-community/gemma-3-4b-it-4bit"
|
||||
|
||||
|
||||
def test_models():
|
||||
"""Test GET /v1/models."""
|
||||
print("=" * 60)
|
||||
print("TEST: List models")
|
||||
print("=" * 60)
|
||||
r = httpx.get(f"{BASE_URL}/models")
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
print(f"Models: {[m['id'] for m in data['data']]}")
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def test_chat_basic():
|
||||
"""Test basic non-streaming chat."""
|
||||
print("=" * 60)
|
||||
print("TEST: Basic chat (non-streaming)")
|
||||
print("=" * 60)
|
||||
r = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [{"role": "user", "content": "Say exactly: 'The sky is blue.' Nothing else."}],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
msg = data["choices"][0]["message"]["content"]
|
||||
usage = data["usage"]
|
||||
print(f"Response: {msg}")
|
||||
print(f"Usage: {usage}")
|
||||
print(f"Finish reason: {data['choices'][0]['finish_reason']}")
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def test_chat_streaming():
|
||||
"""Test streaming chat."""
|
||||
print("=" * 60)
|
||||
print("TEST: Streaming chat")
|
||||
print("=" * 60)
|
||||
collected = ""
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}],
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.1,
|
||||
"stream": True,
|
||||
},
|
||||
timeout=120,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
for line in response.iter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = line[len("data: "):]
|
||||
if payload == "[DONE]":
|
||||
break
|
||||
chunk = json.loads(payload)
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if delta.get("content"):
|
||||
collected += delta["content"]
|
||||
print(delta["content"], end="", flush=True)
|
||||
if chunk["choices"][0].get("finish_reason"):
|
||||
print(f"\n[finish_reason: {chunk['choices'][0]['finish_reason']}]")
|
||||
if chunk.get("usage") and chunk["usage"].get("total_tokens", 0) > 0:
|
||||
print(f"[usage: {chunk['usage']}]")
|
||||
print(f"Full collected: {collected!r}")
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def _make_test_image() -> str:
|
||||
"""Create a simple test image and return it as a base64 data URI."""
|
||||
img = Image.new("RGB", (200, 200), color=(135, 206, 235))
|
||||
draw = ImageDraw.Draw(img)
|
||||
# Draw a red circle
|
||||
draw.ellipse([50, 50, 150, 150], fill=(255, 0, 0), outline=(0, 0, 0), width=2)
|
||||
# Draw a green triangle
|
||||
draw.polygon([(100, 20), (60, 80), (140, 80)], fill=(0, 180, 0), outline=(0, 0, 0))
|
||||
# Draw yellow text area
|
||||
draw.rectangle([10, 160, 190, 190], fill=(255, 255, 0))
|
||||
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def test_vision():
|
||||
"""Test vision with an image."""
|
||||
print("=" * 60)
|
||||
print("TEST: Vision (image description)")
|
||||
print("=" * 60)
|
||||
image_uri = _make_test_image()
|
||||
print(f"Image: 200x200 PNG with red circle, green triangle, yellow bar")
|
||||
|
||||
r = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe what shapes and colors you see in this image. Be brief."},
|
||||
{"type": "image_url", "image_url": {"url": image_uri}},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": 200,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
msg = data["choices"][0]["message"]["content"]
|
||||
print(f"Response: {msg}")
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def test_tool_use():
|
||||
"""Test tool calling."""
|
||||
print("=" * 60)
|
||||
print("TEST: Tool use")
|
||||
print("=" * 60)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name, e.g. 'London'",
|
||||
},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"description": "Temperature units: 'celsius' or 'fahrenheit'",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Step 1: Ask the model to use the tool
|
||||
print("Step 1: Asking model to get weather for Paris...")
|
||||
r = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."},
|
||||
],
|
||||
"tools": tools,
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
choice = data["choices"][0]
|
||||
print(f"Finish reason: {choice['finish_reason']}")
|
||||
print(f"Content: {choice['message'].get('content')}")
|
||||
print(f"Tool calls: {choice['message'].get('tool_calls')}")
|
||||
|
||||
if choice["message"].get("tool_calls"):
|
||||
tc = choice["message"]["tool_calls"][0]
|
||||
print(f"\nTool call detected:")
|
||||
print(f" ID: {tc['id']}")
|
||||
print(f" Function: {tc['function']['name']}")
|
||||
print(f" Arguments: {tc['function']['arguments']}")
|
||||
|
||||
# Step 2: Send the tool result back
|
||||
print("\nStep 2: Sending mock tool result back...")
|
||||
r2 = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": choice["message"].get("content"),
|
||||
"tool_calls": choice["message"]["tool_calls"],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc["id"],
|
||||
"content": json.dumps({"temperature": 18, "condition": "Partly cloudy", "humidity": 65}),
|
||||
},
|
||||
],
|
||||
"tools": tools,
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
r2.raise_for_status()
|
||||
data2 = r2.json()
|
||||
msg2 = data2["choices"][0]["message"]["content"]
|
||||
print(f"Final response: {msg2}")
|
||||
else:
|
||||
print("WARNING: Model did not produce a tool call. Raw response above.")
|
||||
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def test_multi_turn():
|
||||
"""Test multi-turn conversation."""
|
||||
print("=" * 60)
|
||||
print("TEST: Multi-turn conversation")
|
||||
print("=" * 60)
|
||||
messages = [
|
||||
{"role": "user", "content": "My name is Alice."},
|
||||
]
|
||||
r = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1},
|
||||
timeout=120,
|
||||
)
|
||||
r.raise_for_status()
|
||||
reply1 = r.json()["choices"][0]["message"]["content"]
|
||||
print(f"Turn 1 reply: {reply1}")
|
||||
|
||||
messages.append({"role": "assistant", "content": reply1})
|
||||
messages.append({"role": "user", "content": "What is my name?"})
|
||||
|
||||
r2 = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1},
|
||||
timeout=120,
|
||||
)
|
||||
r2.raise_for_status()
|
||||
reply2 = r2.json()["choices"][0]["message"]["content"]
|
||||
print(f"Turn 2 reply: {reply2}")
|
||||
assert "alice" in reply2.lower(), f"Expected 'Alice' in response, got: {reply2}"
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [
|
||||
test_models,
|
||||
test_chat_basic,
|
||||
test_chat_streaming,
|
||||
test_vision,
|
||||
test_tool_use,
|
||||
test_multi_turn,
|
||||
]
|
||||
|
||||
# Allow running a single test by name
|
||||
if len(sys.argv) > 1:
|
||||
name = sys.argv[1]
|
||||
tests = [t for t in tests if name in t.__name__]
|
||||
if not tests:
|
||||
print(f"No test matching '{name}'. Available: models, chat_basic, chat_streaming, vision, tool_use, multi_turn")
|
||||
sys.exit(1)
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
for test in tests:
|
||||
try:
|
||||
test()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"FAIL: {e}\n")
|
||||
failed += 1
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Results: {passed} passed, {failed} failed")
|
||||
print("=" * 60)
|
||||
Reference in New Issue
Block a user