initial commit
This commit is contained in:
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
.venv/
|
||||||
|
.env
|
||||||
|
*.log
|
||||||
|
.DS_Store
|
||||||
38
CLAUDE.md
Normal file
38
CLAUDE.md
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# MLX Server
|
||||||
|
|
||||||
|
OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon via MLX.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Run the server (downloads model on first run)
|
||||||
|
./run.sh
|
||||||
|
|
||||||
|
# Or directly:
|
||||||
|
python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
- `mlx_server/main.py` — FastAPI server, endpoints, CLI entrypoint
|
||||||
|
- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm)
|
||||||
|
- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types
|
||||||
|
|
||||||
|
## Key Design Decisions
|
||||||
|
|
||||||
|
- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load
|
||||||
|
- Gemma 3 has no system role — system messages are converted to user/assistant pairs
|
||||||
|
- Tool use is prompt-engineered: tools are injected into the system prompt with `<tool_call>` XML tags, and parsed from model output
|
||||||
|
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
|
||||||
|
- 128k context window supported via the model's native capabilities
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
Managed via `uv` and `pyproject.toml`. Virtual environment in `.venv/`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install -e "."
|
||||||
|
```
|
||||||
0
mlx_server/__init__.py
Normal file
0
mlx_server/__init__.py
Normal file
3
mlx_server/__main__.py
Normal file
3
mlx_server/__main__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from mlx_server.main import main
|
||||||
|
|
||||||
|
main()
|
||||||
576
mlx_server/engine.py
Normal file
576
mlx_server/engine.py
Normal file
@@ -0,0 +1,576 @@
|
|||||||
|
"""Model loading and inference engine using mlx_vlm (supports both text and vision)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from collections.abc import Generator
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx_vlm
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers for Gemma 3 tool_code format
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
_JSON_TO_PYTHON_TYPE = {
|
||||||
|
"string": "str",
|
||||||
|
"integer": "int",
|
||||||
|
"number": "float",
|
||||||
|
"boolean": "bool",
|
||||||
|
"array": "list",
|
||||||
|
"object": "dict",
|
||||||
|
}
|
||||||
|
|
||||||
|
_JSON_TYPE_DEFAULTS = {
|
||||||
|
"string": '""',
|
||||||
|
"integer": "0",
|
||||||
|
"number": "0.0",
|
||||||
|
"boolean": "False",
|
||||||
|
"array": "[]",
|
||||||
|
"object": "{}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _json_type_to_python(json_type: str) -> str:
|
||||||
|
return _JSON_TO_PYTHON_TYPE.get(json_type, "str")
|
||||||
|
|
||||||
|
|
||||||
|
def _json_type_default(json_type: str) -> str:
|
||||||
|
return _JSON_TYPE_DEFAULTS.get(json_type, "None")
|
||||||
|
|
||||||
|
|
||||||
|
def _python_repr(value) -> str:
|
||||||
|
"""Produce a Python-repr-style string for a value."""
|
||||||
|
if isinstance(value, str):
|
||||||
|
return repr(value)
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return "True" if value else "False"
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return str(value)
|
||||||
|
return repr(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_python_call(call_str: str, tool_defs: dict[str, dict] | None = None) -> tuple[str, dict]:
|
||||||
|
"""Parse a function call string into (name, args_dict).
|
||||||
|
|
||||||
|
Handles multiple formats:
|
||||||
|
1. Python-style: func_name(arg1="value1", arg2=42)
|
||||||
|
2. Shell-style: func_name arg1 arg2 (common with small LLMs)
|
||||||
|
3. Mixed: func_name("value") (positional args)
|
||||||
|
|
||||||
|
tool_defs maps function names to their parameter schemas, used to
|
||||||
|
infer which parameter a positional/shell-style argument maps to.
|
||||||
|
"""
|
||||||
|
import ast
|
||||||
|
|
||||||
|
call_str = call_str.strip()
|
||||||
|
|
||||||
|
# Try Python-style: function_name(...)
|
||||||
|
m = re.match(r"(\w+)\s*\((.*)\)\s*$", call_str, re.DOTALL)
|
||||||
|
if m:
|
||||||
|
name = m.group(1)
|
||||||
|
args_str = m.group(2).strip()
|
||||||
|
|
||||||
|
if not args_str:
|
||||||
|
return name, {}
|
||||||
|
|
||||||
|
# Try parsing as a Python function call via dict()
|
||||||
|
try:
|
||||||
|
tree = ast.parse(f"dict({args_str})", mode="eval")
|
||||||
|
call_node = tree.body
|
||||||
|
args = {}
|
||||||
|
# Handle keyword arguments: func(key="val")
|
||||||
|
for kw in call_node.keywords:
|
||||||
|
args[kw.arg] = ast.literal_eval(kw.value)
|
||||||
|
# Handle positional arguments: func("val1", "val2")
|
||||||
|
if call_node.args and not args:
|
||||||
|
param_names = _get_param_names(name, tool_defs)
|
||||||
|
for i, arg_node in enumerate(call_node.args):
|
||||||
|
val = ast.literal_eval(arg_node)
|
||||||
|
if i < len(param_names):
|
||||||
|
args[param_names[i]] = val
|
||||||
|
else:
|
||||||
|
args[f"arg{i}"] = val
|
||||||
|
if args:
|
||||||
|
return name, args
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: regex-based key=value parsing
|
||||||
|
args = {}
|
||||||
|
for pair_match in re.finditer(r"(\w+)\s*=\s*(.+?)(?:,\s*(?=\w+\s*=)|$)", args_str, re.DOTALL):
|
||||||
|
key = pair_match.group(1)
|
||||||
|
val_str = pair_match.group(2).strip()
|
||||||
|
try:
|
||||||
|
args[key] = ast.literal_eval(val_str)
|
||||||
|
except Exception:
|
||||||
|
args[key] = val_str
|
||||||
|
return name, args
|
||||||
|
|
||||||
|
# Shell-style: "func_name arg1 arg2" or "func_name some/path"
|
||||||
|
# Also handles: "func_name -flag arg" (common with shell tools)
|
||||||
|
parts = call_str.split(None, 1)
|
||||||
|
if parts and re.match(r"^\w+$", parts[0]):
|
||||||
|
name = parts[0]
|
||||||
|
if len(parts) == 1:
|
||||||
|
return name, {}
|
||||||
|
|
||||||
|
rest = parts[1].strip()
|
||||||
|
param_names = _get_param_names(name, tool_defs)
|
||||||
|
first_param = param_names[0] if param_names else "input"
|
||||||
|
return name, {first_param: rest}
|
||||||
|
|
||||||
|
# Last resort: treat the entire block as a command for the first
|
||||||
|
# known tool that looks like a shell/command tool, or just fail
|
||||||
|
raise ValueError(f"Cannot parse as function call: {call_str!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_param_names(func_name: str, tool_defs: dict[str, dict] | None) -> list[str]:
|
||||||
|
"""Get ordered parameter names for a function from tool definitions."""
|
||||||
|
if not tool_defs or func_name not in tool_defs:
|
||||||
|
return []
|
||||||
|
params = tool_defs[func_name].get("parameters", {})
|
||||||
|
properties = params.get("properties", {})
|
||||||
|
required = params.get("required", [])
|
||||||
|
# Required params first, then optional
|
||||||
|
optional = [k for k in properties if k not in required]
|
||||||
|
return list(required) + optional
|
||||||
|
|
||||||
|
|
||||||
|
class PromptCache:
|
||||||
|
"""Manages KV cache reuse across requests with shared prompt prefixes."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._cache = None
|
||||||
|
self._cached_token_ids: list[int] | None = None
|
||||||
|
|
||||||
|
def get_reusable_length(self, new_token_ids: list[int]) -> int:
|
||||||
|
"""Find how many leading tokens match the cached prefix."""
|
||||||
|
if self._cached_token_ids is None or self._cache is None:
|
||||||
|
return 0
|
||||||
|
max_match = min(len(self._cached_token_ids), len(new_token_ids))
|
||||||
|
match_len = 0
|
||||||
|
for i in range(max_match):
|
||||||
|
if self._cached_token_ids[i] != new_token_ids[i]:
|
||||||
|
break
|
||||||
|
match_len = i + 1
|
||||||
|
return match_len
|
||||||
|
|
||||||
|
def update(self, cache, token_ids: list[int]) -> None:
|
||||||
|
"""Store cache and the token IDs it was built from."""
|
||||||
|
self._cache = cache
|
||||||
|
self._cached_token_ids = list(token_ids)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._cache = None
|
||||||
|
self._cached_token_ids = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache(self):
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceEngine:
|
||||||
|
"""Manages model loading and text/vision generation."""
|
||||||
|
|
||||||
|
def __init__(self, model_path: str = DEFAULT_MODEL):
|
||||||
|
self.model_path = model_path
|
||||||
|
self.model = None
|
||||||
|
self.processor = None
|
||||||
|
self.config = None
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._prompt_cache = PromptCache()
|
||||||
|
|
||||||
|
def load(self) -> None:
|
||||||
|
logger.info("Loading model %s ...", self.model_path)
|
||||||
|
self.model, self.processor = mlx_vlm.load(self.model_path)
|
||||||
|
# Load model config for chat template
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
|
||||||
|
logger.info("Model loaded successfully.")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Image helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decode_image_url(url: str) -> str:
|
||||||
|
"""Convert a data URI or URL to a file path that mlx_vlm can consume."""
|
||||||
|
if url.startswith("data:"):
|
||||||
|
# data:image/png;base64,iVBOR...
|
||||||
|
header, b64data = url.split(",", 1)
|
||||||
|
img_bytes = base64.b64decode(b64data)
|
||||||
|
img = Image.open(io.BytesIO(img_bytes))
|
||||||
|
tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
||||||
|
img.save(tmp, format="PNG")
|
||||||
|
tmp.close()
|
||||||
|
return tmp.name
|
||||||
|
# Assume it's a URL or local path – mlx_vlm handles URLs natively
|
||||||
|
return url
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Prompt building
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
) -> tuple[str, list[str]]:
|
||||||
|
"""Build a prompt string and collect image paths from messages.
|
||||||
|
|
||||||
|
Returns (prompt_str, image_paths).
|
||||||
|
"""
|
||||||
|
image_paths: list[str] = []
|
||||||
|
formatted_messages: list[dict] = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = msg["role"]
|
||||||
|
content = msg.get("content")
|
||||||
|
tool_calls = msg.get("tool_calls")
|
||||||
|
tool_call_id = msg.get("tool_call_id")
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
text = self._get_text_content(content)
|
||||||
|
# Inject tool definitions into system prompt
|
||||||
|
if tools:
|
||||||
|
text = self._inject_tools_into_system(text, tools)
|
||||||
|
formatted_messages.append({"role": "user", "content": text})
|
||||||
|
# Gemma 3 doesn't have a system role; we use the user role
|
||||||
|
# and add a model acknowledgment
|
||||||
|
formatted_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Understood. I will follow these instructions.",
|
||||||
|
})
|
||||||
|
elif role == "user":
|
||||||
|
text, imgs = self._extract_content_parts(content)
|
||||||
|
image_paths.extend(imgs)
|
||||||
|
formatted_messages.append({"role": "user", "content": text})
|
||||||
|
elif role == "assistant":
|
||||||
|
text = self._get_text_content(content) or ""
|
||||||
|
if tool_calls:
|
||||||
|
# Format tool calls in the way Gemma 3 expects
|
||||||
|
tc_text = self._format_tool_calls_for_prompt(tool_calls)
|
||||||
|
text = (text + "\n" + tc_text).strip()
|
||||||
|
formatted_messages.append({"role": "assistant", "content": text})
|
||||||
|
elif role == "tool":
|
||||||
|
# Tool results use Gemma 3's tool_output format
|
||||||
|
tool_text = self._get_text_content(content) or ""
|
||||||
|
result_msg = f"```tool_output\n{tool_text}\n```"
|
||||||
|
formatted_messages.append({"role": "user", "content": result_msg})
|
||||||
|
|
||||||
|
# If the first system prompt had no tools but we have tools, inject at start
|
||||||
|
if tools and not any(m.get("role") == "system" for m in messages):
|
||||||
|
tool_system = self._build_tool_system_prompt(tools)
|
||||||
|
formatted_messages.insert(0, {"role": "user", "content": tool_system})
|
||||||
|
formatted_messages.insert(1, {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Understood. I will follow these instructions and use tools when appropriate.",
|
||||||
|
})
|
||||||
|
|
||||||
|
# Gemma 3 requires strictly alternating user/assistant turns.
|
||||||
|
# Merge consecutive same-role messages and ensure it starts with user.
|
||||||
|
formatted_messages = self._merge_consecutive_roles(formatted_messages)
|
||||||
|
|
||||||
|
# Apply chat template via mlx_vlm
|
||||||
|
prompt = mlx_vlm.apply_chat_template(
|
||||||
|
self.processor,
|
||||||
|
self.config,
|
||||||
|
formatted_messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
num_images=len(image_paths),
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt, image_paths
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_consecutive_roles(messages: list[dict]) -> list[dict]:
|
||||||
|
"""Merge consecutive messages with the same role into one.
|
||||||
|
|
||||||
|
Gemma 3's chat template enforces strict user/assistant alternation.
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
merged = [messages[0].copy()]
|
||||||
|
for msg in messages[1:]:
|
||||||
|
if msg["role"] == merged[-1]["role"]:
|
||||||
|
# Merge content with the previous message
|
||||||
|
merged[-1]["content"] = (
|
||||||
|
merged[-1].get("content", "") + "\n\n" + msg.get("content", "")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
merged.append(msg.copy())
|
||||||
|
|
||||||
|
# Ensure conversation starts with user
|
||||||
|
if merged and merged[0]["role"] != "user":
|
||||||
|
merged.insert(0, {"role": "user", "content": ""})
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
def _get_text_content(self, content) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
# list of content parts
|
||||||
|
parts = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
parts.append(part["text"])
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _extract_content_parts(self, content) -> tuple[str, list[str]]:
|
||||||
|
"""Extract text and image paths from content parts."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content, []
|
||||||
|
if content is None:
|
||||||
|
return "", []
|
||||||
|
|
||||||
|
texts = []
|
||||||
|
images = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict):
|
||||||
|
if part.get("type") == "text":
|
||||||
|
texts.append(part["text"])
|
||||||
|
elif part.get("type") == "image_url":
|
||||||
|
url = part["image_url"]["url"]
|
||||||
|
images.append(self._decode_image_url(url))
|
||||||
|
return "\n".join(texts), images
|
||||||
|
|
||||||
|
def _inject_tools_into_system(self, system_text: str, tools: list[dict]) -> str:
|
||||||
|
tool_block = self._build_tool_system_prompt(tools)
|
||||||
|
return f"{system_text}\n\n{tool_block}"
|
||||||
|
|
||||||
|
def _build_tool_system_prompt(self, tools: list[dict]) -> str:
|
||||||
|
"""Build the tool system prompt using Google's official Gemma 3 format.
|
||||||
|
|
||||||
|
Uses the tool_code/tool_output convention recommended by Google:
|
||||||
|
- Tools defined as Python function signatures with docstrings
|
||||||
|
- Model outputs calls in ```tool_code``` fenced blocks
|
||||||
|
- Results returned in ```tool_output``` fenced blocks
|
||||||
|
"""
|
||||||
|
func_defs = []
|
||||||
|
for tool in tools:
|
||||||
|
func = tool.get("function", tool)
|
||||||
|
func_defs.append(self._tool_to_python_signature(func))
|
||||||
|
|
||||||
|
functions_block = "\n\n".join(func_defs)
|
||||||
|
|
||||||
|
return (
|
||||||
|
"At each turn, if you decide to invoke any of the function(s), "
|
||||||
|
"it should be wrapped with ```tool_code```. "
|
||||||
|
"The python methods described below are imported and available, "
|
||||||
|
"you can only use defined methods. "
|
||||||
|
"The generated code should be readable and efficient. "
|
||||||
|
"The response to a method will be wrapped in ```tool_output``` "
|
||||||
|
"use it to call more tools or generate a helpful, friendly response.\n"
|
||||||
|
"\n"
|
||||||
|
f"{functions_block}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tool_to_python_signature(func: dict) -> str:
|
||||||
|
"""Convert an OpenAI function definition to a Python function signature with docstring."""
|
||||||
|
name = func["name"]
|
||||||
|
desc = func.get("description", "")
|
||||||
|
params = func.get("parameters", {})
|
||||||
|
properties = params.get("properties", {})
|
||||||
|
required = set(params.get("required", []))
|
||||||
|
|
||||||
|
# Build parameter list
|
||||||
|
param_parts = []
|
||||||
|
doc_args = []
|
||||||
|
for pname, pinfo in properties.items():
|
||||||
|
ptype = _json_type_to_python(pinfo.get("type", "str"))
|
||||||
|
pdesc = pinfo.get("description", "")
|
||||||
|
if pname in required:
|
||||||
|
param_parts.append(f"{pname}: {ptype}")
|
||||||
|
else:
|
||||||
|
default = _json_type_default(pinfo.get("type", "str"))
|
||||||
|
param_parts.append(f"{pname}: {ptype} = {default}")
|
||||||
|
doc_args.append(f" {pname}: {pdesc}" if pdesc else f" {pname}")
|
||||||
|
|
||||||
|
sig = f"def {name}({', '.join(param_parts)}):"
|
||||||
|
doc_lines = [f' """{desc}']
|
||||||
|
if doc_args:
|
||||||
|
doc_lines.append("")
|
||||||
|
doc_lines.append(" Args:")
|
||||||
|
doc_lines.extend(doc_args)
|
||||||
|
doc_lines.append(' """')
|
||||||
|
|
||||||
|
return sig + "\n" + "\n".join(doc_lines)
|
||||||
|
|
||||||
|
def _format_tool_calls_for_prompt(self, tool_calls: list[dict]) -> str:
|
||||||
|
"""Format OpenAI-style tool calls back into Gemma 3 tool_code blocks."""
|
||||||
|
parts = []
|
||||||
|
for tc in tool_calls:
|
||||||
|
func = tc.get("function", tc)
|
||||||
|
name = func["name"]
|
||||||
|
args = func.get("arguments", "{}")
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json.loads(args)
|
||||||
|
# Format as Python function call
|
||||||
|
arg_parts = [f"{k}={_python_repr(v)}" for k, v in args.items()]
|
||||||
|
call_str = f"{name}({', '.join(arg_parts)})"
|
||||||
|
parts.append(f"```tool_code\n{call_str}\n```")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Generation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Common kwargs for mlx_vlm generate calls — optimized for Apple Silicon
|
||||||
|
_GENERATE_KWARGS = {
|
||||||
|
"kv_bits": 8, # Quantize KV cache to 8-bit (halves memory bandwidth)
|
||||||
|
"kv_group_size": 64, # Group size for KV quantization
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
images: list[str] | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
repetition_penalty: float = 1.1,
|
||||||
|
) -> tuple[str, int, int]:
|
||||||
|
"""Generate a complete response. Returns (text, prompt_tokens, completion_tokens)."""
|
||||||
|
with self._lock:
|
||||||
|
image_arg = images if images else None
|
||||||
|
result = mlx_vlm.generate(
|
||||||
|
self.model,
|
||||||
|
self.processor,
|
||||||
|
prompt,
|
||||||
|
image=image_arg,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
verbose=False,
|
||||||
|
**self._GENERATE_KWARGS,
|
||||||
|
)
|
||||||
|
text = result.text
|
||||||
|
if stop:
|
||||||
|
text = self._apply_stop(text, stop)
|
||||||
|
return text, result.prompt_tokens, result.generation_tokens
|
||||||
|
|
||||||
|
def stream_generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
images: list[str] | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
repetition_penalty: float = 1.1,
|
||||||
|
) -> Generator[tuple[str, bool, int, int], None, None]:
|
||||||
|
"""Stream tokens. Yields (token_text, is_final, prompt_tokens, gen_tokens)."""
|
||||||
|
with self._lock:
|
||||||
|
image_arg = images if images else None
|
||||||
|
accumulated = ""
|
||||||
|
prompt_tokens = 0
|
||||||
|
gen_tokens = 0
|
||||||
|
for result in mlx_vlm.stream_generate(
|
||||||
|
self.model,
|
||||||
|
self.processor,
|
||||||
|
prompt,
|
||||||
|
image=image_arg,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
**self._GENERATE_KWARGS,
|
||||||
|
):
|
||||||
|
# result.text is the incremental segment (detokenizer.last_segment),
|
||||||
|
# NOT the full accumulated text.
|
||||||
|
token_text = result.text
|
||||||
|
accumulated += token_text
|
||||||
|
prompt_tokens = result.prompt_tokens
|
||||||
|
gen_tokens = result.generation_tokens
|
||||||
|
|
||||||
|
if stop and self._check_stop(accumulated, stop):
|
||||||
|
# Trim the accumulated text and yield what's safe
|
||||||
|
trimmed = self._apply_stop(accumulated, stop)
|
||||||
|
# Only yield the part we haven't yielded yet
|
||||||
|
safe_delta = trimmed[len(accumulated) - len(token_text):]
|
||||||
|
yield safe_delta, True, prompt_tokens, gen_tokens
|
||||||
|
return
|
||||||
|
|
||||||
|
yield token_text, False, prompt_tokens, gen_tokens
|
||||||
|
|
||||||
|
# Final yield to signal completion
|
||||||
|
yield "", True, prompt_tokens, gen_tokens
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _apply_stop(text: str, stop: list[str]) -> str:
|
||||||
|
for s in stop:
|
||||||
|
idx = text.find(s)
|
||||||
|
if idx != -1:
|
||||||
|
text = text[:idx]
|
||||||
|
return text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_stop(text: str, stop: list[str]) -> bool:
|
||||||
|
return any(s in text for s in stop)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Tool call parsing from model output
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_tool_calls(
|
||||||
|
text: str, tools: list[dict] | None = None
|
||||||
|
) -> tuple[str, list[dict]]:
|
||||||
|
"""Parse tool calls from model output using Gemma 3's tool_code format.
|
||||||
|
|
||||||
|
Detects ```tool_code ... ``` blocks containing Python-style or
|
||||||
|
shell-style function calls.
|
||||||
|
|
||||||
|
Returns (clean_text, tool_calls) where tool_calls is a list of
|
||||||
|
{"id": str, "type": "function", "function": {"name": str, "arguments": str}}.
|
||||||
|
"""
|
||||||
|
# Build a lookup of function name -> parameter schema
|
||||||
|
tool_defs: dict[str, dict] = {}
|
||||||
|
if tools:
|
||||||
|
for tool in tools:
|
||||||
|
func = tool.get("function", tool)
|
||||||
|
tool_defs[func["name"]] = func
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
pattern = r"```tool_code\s*(.*?)\s*```"
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
|
clean_text = re.sub(r"```tool_code\s*.*?\s*```", "", text, flags=re.DOTALL).strip()
|
||||||
|
|
||||||
|
for i, match in enumerate(matches):
|
||||||
|
call_str = match.strip()
|
||||||
|
try:
|
||||||
|
name, args = _parse_python_call(call_str, tool_defs)
|
||||||
|
tool_calls.append({
|
||||||
|
"id": f"call_{i}_{hash(call_str) % 10**8:08d}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": name,
|
||||||
|
"arguments": json.dumps(args),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to parse tool_code call %r: %s", call_str, e)
|
||||||
|
|
||||||
|
return clean_text, tool_calls
|
||||||
278
mlx_server/main.py
Normal file
278
mlx_server/main.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
"""OpenAI-compatible API server for Gemma 3 4B via MLX."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
|
from .engine import DEFAULT_MODEL, InferenceEngine
|
||||||
|
from .models import (
|
||||||
|
ChatCompletionChunk,
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
Choice,
|
||||||
|
ChoiceMessage,
|
||||||
|
DeltaMessage,
|
||||||
|
ModelInfo,
|
||||||
|
ModelListResponse,
|
||||||
|
StreamChoice,
|
||||||
|
ToolCall,
|
||||||
|
FunctionCall,
|
||||||
|
UsageInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B")
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
engine: InferenceEngine | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine() -> InferenceEngine:
|
||||||
|
if engine is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
def _make_id() -> str:
|
||||||
|
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/models")
|
||||||
|
async def list_models() -> ModelListResponse:
|
||||||
|
e = get_engine()
|
||||||
|
return ModelListResponse(data=[ModelInfo(id=e.model_path)])
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def chat_completions(request: ChatCompletionRequest):
|
||||||
|
e = get_engine()
|
||||||
|
|
||||||
|
# Convert pydantic messages to dicts
|
||||||
|
messages = [m.model_dump(exclude_none=True) for m in request.messages]
|
||||||
|
tools = None
|
||||||
|
if request.tools:
|
||||||
|
tools = [t.model_dump(exclude_none=True) for t in request.tools]
|
||||||
|
|
||||||
|
prompt, images = e.build_prompt(messages, tools)
|
||||||
|
|
||||||
|
stop = request.stop
|
||||||
|
if isinstance(stop, str):
|
||||||
|
stop = [stop]
|
||||||
|
|
||||||
|
temperature = request.temperature if request.temperature is not None else 0.7
|
||||||
|
top_p = request.top_p if request.top_p is not None else 0.9
|
||||||
|
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
|
||||||
|
|
||||||
|
if request.stream:
|
||||||
|
return EventSourceResponse(
|
||||||
|
_stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-streaming
|
||||||
|
text, prompt_tokens, completion_tokens = e.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
images=images or None,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for tool calls in the response
|
||||||
|
finish_reason = "stop"
|
||||||
|
tool_calls_parsed = None
|
||||||
|
if tools:
|
||||||
|
clean_text, parsed = e.parse_tool_calls(text, tools)
|
||||||
|
if parsed:
|
||||||
|
tool_calls_parsed = [
|
||||||
|
ToolCall(
|
||||||
|
index=i,
|
||||||
|
id=tc["id"],
|
||||||
|
type="function",
|
||||||
|
function=FunctionCall(
|
||||||
|
name=tc["function"]["name"],
|
||||||
|
arguments=tc["function"]["arguments"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for i, tc in enumerate(parsed)
|
||||||
|
]
|
||||||
|
text = clean_text if clean_text else None
|
||||||
|
finish_reason = "tool_calls"
|
||||||
|
|
||||||
|
return ChatCompletionResponse(
|
||||||
|
id=_make_id(),
|
||||||
|
model=request.model,
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
message=ChoiceMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=text if not tool_calls_parsed else (text or None),
|
||||||
|
tool_calls=tool_calls_parsed,
|
||||||
|
),
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_response(
|
||||||
|
e: InferenceEngine,
|
||||||
|
prompt: str,
|
||||||
|
images: list[str] | None,
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
stop: list[str] | None,
|
||||||
|
tools: list[dict] | None,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
request_id = _make_id()
|
||||||
|
created = int(time.time())
|
||||||
|
|
||||||
|
# Send initial chunk with role
|
||||||
|
initial_chunk = ChatCompletionChunk(
|
||||||
|
id=request_id,
|
||||||
|
created=created,
|
||||||
|
model=model_name,
|
||||||
|
choices=[StreamChoice(delta=DeltaMessage(role="assistant"))],
|
||||||
|
)
|
||||||
|
yield {"data": initial_chunk.model_dump_json()}
|
||||||
|
|
||||||
|
full_text = ""
|
||||||
|
prompt_tokens = 0
|
||||||
|
gen_tokens = 0
|
||||||
|
|
||||||
|
for token_text, is_final, pt, gt in e.stream_generate(
|
||||||
|
prompt=prompt,
|
||||||
|
images=images or None,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
stop=stop,
|
||||||
|
):
|
||||||
|
prompt_tokens = pt
|
||||||
|
gen_tokens = gt
|
||||||
|
full_text += token_text
|
||||||
|
|
||||||
|
if not is_final and token_text:
|
||||||
|
chunk = ChatCompletionChunk(
|
||||||
|
id=request_id,
|
||||||
|
created=created,
|
||||||
|
model=model_name,
|
||||||
|
choices=[StreamChoice(delta=DeltaMessage(content=token_text))],
|
||||||
|
)
|
||||||
|
yield {"data": chunk.model_dump_json()}
|
||||||
|
|
||||||
|
# Check for tool calls in complete response
|
||||||
|
finish_reason = "stop"
|
||||||
|
if tools:
|
||||||
|
clean_text, parsed = e.parse_tool_calls(full_text, tools)
|
||||||
|
if parsed:
|
||||||
|
finish_reason = "tool_calls"
|
||||||
|
# Emit tool call chunks
|
||||||
|
for i, tc in enumerate(parsed):
|
||||||
|
tc_chunk = ChatCompletionChunk(
|
||||||
|
id=request_id,
|
||||||
|
created=created,
|
||||||
|
model=model_name,
|
||||||
|
choices=[
|
||||||
|
StreamChoice(
|
||||||
|
delta=DeltaMessage(
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
index=i,
|
||||||
|
id=tc["id"],
|
||||||
|
type="function",
|
||||||
|
function=FunctionCall(
|
||||||
|
name=tc["function"]["name"],
|
||||||
|
arguments=tc["function"]["arguments"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
yield {"data": tc_chunk.model_dump_json()}
|
||||||
|
|
||||||
|
# Final chunk with finish reason and usage
|
||||||
|
final_chunk = ChatCompletionChunk(
|
||||||
|
id=request_id,
|
||||||
|
created=created,
|
||||||
|
model=model_name,
|
||||||
|
choices=[StreamChoice(delta=DeltaMessage(), finish_reason=finish_reason)],
|
||||||
|
usage=UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=gen_tokens,
|
||||||
|
total_tokens=prompt_tokens + gen_tokens,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
yield {"data": final_chunk.model_dump_json()}
|
||||||
|
yield {"data": "[DONE]"}
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Health / utility
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Entrypoint
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="MLX Server – OpenAI-compatible API")
|
||||||
|
parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="HuggingFace model path")
|
||||||
|
parser.add_argument("--host", type=str, default="127.0.0.1")
|
||||||
|
parser.add_argument("--port", type=int, default=1234)
|
||||||
|
parser.add_argument("--log-level", type=str, default="info")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, args.log_level.upper()),
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
global engine
|
||||||
|
engine = InferenceEngine(model_path=args.model)
|
||||||
|
engine.load()
|
||||||
|
|
||||||
|
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
144
mlx_server/models.py
Normal file
144
mlx_server/models.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""OpenAI API compatible request/response models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# --- Request models ---
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionDefinition(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
parameters: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ToolDefinition(BaseModel):
|
||||||
|
type: Literal["function"] = "function"
|
||||||
|
function: FunctionDefinition
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionCall(BaseModel):
|
||||||
|
name: str
|
||||||
|
arguments: str # JSON string
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(BaseModel):
|
||||||
|
index: int = 0
|
||||||
|
id: str
|
||||||
|
type: Literal["function"] = "function"
|
||||||
|
function: FunctionCall
|
||||||
|
|
||||||
|
|
||||||
|
class ContentPartText(BaseModel):
|
||||||
|
type: Literal["text"] = "text"
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageURL(BaseModel):
|
||||||
|
url: str # Can be a URL or base64 data URI
|
||||||
|
detail: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ContentPartImage(BaseModel):
|
||||||
|
type: Literal["image_url"] = "image_url"
|
||||||
|
image_url: ImageURL
|
||||||
|
|
||||||
|
|
||||||
|
ContentPart = ContentPartText | ContentPartImage
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: Literal["system", "user", "assistant", "tool"]
|
||||||
|
content: str | list[ContentPart] | None = None
|
||||||
|
name: str | None = None
|
||||||
|
tool_calls: list[ToolCall] | None = None
|
||||||
|
tool_call_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
model: str = "gemma-3-4b-it"
|
||||||
|
messages: list[ChatMessage]
|
||||||
|
temperature: float | None = 0.7
|
||||||
|
top_p: float | None = 0.9
|
||||||
|
max_tokens: int | None = 4096
|
||||||
|
stream: bool = False
|
||||||
|
stop: str | list[str] | None = None
|
||||||
|
tools: list[ToolDefinition] | None = None
|
||||||
|
tool_choice: str | dict | None = None
|
||||||
|
frequency_penalty: float | None = None
|
||||||
|
presence_penalty: float | None = None
|
||||||
|
n: int | None = 1
|
||||||
|
|
||||||
|
|
||||||
|
# --- Response models ---
|
||||||
|
|
||||||
|
|
||||||
|
class UsageInfo(BaseModel):
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
completion_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class ChoiceMessage(BaseModel):
|
||||||
|
role: str = "assistant"
|
||||||
|
content: str | None = None
|
||||||
|
tool_calls: list[ToolCall] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Choice(BaseModel):
|
||||||
|
index: int = 0
|
||||||
|
message: ChoiceMessage
|
||||||
|
finish_reason: str | None = "stop"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "chat.completion"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: list[Choice]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
# --- Streaming response models ---
|
||||||
|
|
||||||
|
|
||||||
|
class DeltaMessage(BaseModel):
|
||||||
|
role: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
tool_calls: list[ToolCall] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StreamChoice(BaseModel):
|
||||||
|
index: int = 0
|
||||||
|
delta: DeltaMessage
|
||||||
|
finish_reason: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionChunk(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "chat.completion.chunk"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: list[StreamChoice]
|
||||||
|
usage: UsageInfo | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Model listing ---
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInfo(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "model"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
owned_by: str = "local"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelListResponse(BaseModel):
|
||||||
|
object: str = "list"
|
||||||
|
data: list[ModelInfo]
|
||||||
20
pyproject.toml
Normal file
20
pyproject.toml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
[project]
|
||||||
|
name = "mlx-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "OpenAI-compatible API server for Gemma 3 4B via MLX"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.115.0",
|
||||||
|
"uvicorn[standard]>=0.30.0",
|
||||||
|
"mlx>=0.22.0",
|
||||||
|
"mlx-lm>=0.22.0",
|
||||||
|
"mlx-vlm>=0.1.18",
|
||||||
|
"pydantic>=2.0.0",
|
||||||
|
"sse-starlette>=2.0.0",
|
||||||
|
"pillow>=10.0.0",
|
||||||
|
"httpx>=0.27.0",
|
||||||
|
"torchvision>=0.20.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
mlx-server = "mlx_server.main:main"
|
||||||
24
run.sh
Executable file
24
run.sh
Executable file
@@ -0,0 +1,24 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Default model – 4-bit quantized Gemma 3 4B IT (vision-capable)
|
||||||
|
MODEL="${MODEL:-mlx-community/gemma-3-4b-it-4bit}"
|
||||||
|
HOST="${HOST:-127.0.0.1}"
|
||||||
|
PORT="${PORT:-1234}"
|
||||||
|
|
||||||
|
echo "Starting MLX Server..."
|
||||||
|
echo " Model: $MODEL"
|
||||||
|
echo " Endpoint: http://$HOST:$PORT"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
exec python -m mlx_server.main \
|
||||||
|
--model "$MODEL" \
|
||||||
|
--host "$HOST" \
|
||||||
|
--port "$PORT" \
|
||||||
|
"$@"
|
||||||
296
test_server.py
Normal file
296
test_server.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
"""Test script for MLX Server – exercises chat, streaming, vision, and tool use."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
BASE_URL = "http://127.0.0.1:1234/v1"
|
||||||
|
MODEL = "mlx-community/gemma-3-4b-it-4bit"
|
||||||
|
|
||||||
|
|
||||||
|
def test_models():
|
||||||
|
"""Test GET /v1/models."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("TEST: List models")
|
||||||
|
print("=" * 60)
|
||||||
|
r = httpx.get(f"{BASE_URL}/models")
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
print(f"Models: {[m['id'] for m in data['data']]}")
|
||||||
|
print("PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_basic():
|
||||||
|
"""Test basic non-streaming chat."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("TEST: Basic chat (non-streaming)")
|
||||||
|
print("=" * 60)
|
||||||
|
r = httpx.post(
|
||||||
|
f"{BASE_URL}/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": MODEL,
|
||||||
|
"messages": [{"role": "user", "content": "Say exactly: 'The sky is blue.' Nothing else."}],
|
||||||
|
"max_tokens": 50,
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
msg = data["choices"][0]["message"]["content"]
|
||||||
|
usage = data["usage"]
|
||||||
|
print(f"Response: {msg}")
|
||||||
|
print(f"Usage: {usage}")
|
||||||
|
print(f"Finish reason: {data['choices'][0]['finish_reason']}")
|
||||||
|
print("PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_streaming():
|
||||||
|
"""Test streaming chat."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("TEST: Streaming chat")
|
||||||
|
print("=" * 60)
|
||||||
|
collected = ""
|
||||||
|
with httpx.stream(
|
||||||
|
"POST",
|
||||||
|
f"{BASE_URL}/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": MODEL,
|
||||||
|
"messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}],
|
||||||
|
"max_tokens": 100,
|
||||||
|
"temperature": 0.1,
|
||||||
|
"stream": True,
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
payload = line[len("data: "):]
|
||||||
|
if payload == "[DONE]":
|
||||||
|
break
|
||||||
|
chunk = json.loads(payload)
|
||||||
|
delta = chunk["choices"][0]["delta"]
|
||||||
|
if delta.get("content"):
|
||||||
|
collected += delta["content"]
|
||||||
|
print(delta["content"], end="", flush=True)
|
||||||
|
if chunk["choices"][0].get("finish_reason"):
|
||||||
|
print(f"\n[finish_reason: {chunk['choices'][0]['finish_reason']}]")
|
||||||
|
if chunk.get("usage") and chunk["usage"].get("total_tokens", 0) > 0:
|
||||||
|
print(f"[usage: {chunk['usage']}]")
|
||||||
|
print(f"Full collected: {collected!r}")
|
||||||
|
print("PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_test_image() -> str:
|
||||||
|
"""Create a simple test image and return it as a base64 data URI."""
|
||||||
|
img = Image.new("RGB", (200, 200), color=(135, 206, 235))
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
# Draw a red circle
|
||||||
|
draw.ellipse([50, 50, 150, 150], fill=(255, 0, 0), outline=(0, 0, 0), width=2)
|
||||||
|
# Draw a green triangle
|
||||||
|
draw.polygon([(100, 20), (60, 80), (140, 80)], fill=(0, 180, 0), outline=(0, 0, 0))
|
||||||
|
# Draw yellow text area
|
||||||
|
draw.rectangle([10, 160, 190, 190], fill=(255, 255, 0))
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format="PNG")
|
||||||
|
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||||
|
return f"data:image/png;base64,{b64}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_vision():
|
||||||
|
"""Test vision with an image."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("TEST: Vision (image description)")
|
||||||
|
print("=" * 60)
|
||||||
|
image_uri = _make_test_image()
|
||||||
|
print(f"Image: 200x200 PNG with red circle, green triangle, yellow bar")
|
||||||
|
|
||||||
|
r = httpx.post(
|
||||||
|
f"{BASE_URL}/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": MODEL,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Describe what shapes and colors you see in this image. Be brief."},
|
||||||
|
{"type": "image_url", "image_url": {"url": image_uri}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 200,
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
msg = data["choices"][0]["message"]["content"]
|
||||||
|
print(f"Response: {msg}")
|
||||||
|
print("PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_use():
|
||||||
|
"""Test tool calling."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("TEST: Tool use")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather for a given city",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city name, e.g. 'London'",
|
||||||
|
},
|
||||||
|
"units": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Temperature units: 'celsius' or 'fahrenheit'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Step 1: Ask the model to use the tool
|
||||||
|
print("Step 1: Asking model to get weather for Paris...")
|
||||||
|
r = httpx.post(
|
||||||
|
f"{BASE_URL}/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": MODEL,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."},
|
||||||
|
],
|
||||||
|
"tools": tools,
|
||||||
|
"max_tokens": 300,
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
choice = data["choices"][0]
|
||||||
|
print(f"Finish reason: {choice['finish_reason']}")
|
||||||
|
print(f"Content: {choice['message'].get('content')}")
|
||||||
|
print(f"Tool calls: {choice['message'].get('tool_calls')}")
|
||||||
|
|
||||||
|
if choice["message"].get("tool_calls"):
|
||||||
|
tc = choice["message"]["tool_calls"][0]
|
||||||
|
print(f"\nTool call detected:")
|
||||||
|
print(f" ID: {tc['id']}")
|
||||||
|
print(f" Function: {tc['function']['name']}")
|
||||||
|
print(f" Arguments: {tc['function']['arguments']}")
|
||||||
|
|
||||||
|
# Step 2: Send the tool result back
|
||||||
|
print("\nStep 2: Sending mock tool result back...")
|
||||||
|
r2 = httpx.post(
|
||||||
|
f"{BASE_URL}/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": MODEL,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": choice["message"].get("content"),
|
||||||
|
"tool_calls": choice["message"]["tool_calls"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tc["id"],
|
||||||
|
"content": json.dumps({"temperature": 18, "condition": "Partly cloudy", "humidity": 65}),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"tools": tools,
|
||||||
|
"max_tokens": 300,
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
r2.raise_for_status()
|
||||||
|
data2 = r2.json()
|
||||||
|
msg2 = data2["choices"][0]["message"]["content"]
|
||||||
|
print(f"Final response: {msg2}")
|
||||||
|
else:
|
||||||
|
print("WARNING: Model did not produce a tool call. Raw response above.")
|
||||||
|
|
||||||
|
print("PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_turn():
|
||||||
|
"""Test multi-turn conversation."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("TEST: Multi-turn conversation")
|
||||||
|
print("=" * 60)
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "My name is Alice."},
|
||||||
|
]
|
||||||
|
r = httpx.post(
|
||||||
|
f"{BASE_URL}/chat/completions",
|
||||||
|
json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
reply1 = r.json()["choices"][0]["message"]["content"]
|
||||||
|
print(f"Turn 1 reply: {reply1}")
|
||||||
|
|
||||||
|
messages.append({"role": "assistant", "content": reply1})
|
||||||
|
messages.append({"role": "user", "content": "What is my name?"})
|
||||||
|
|
||||||
|
r2 = httpx.post(
|
||||||
|
f"{BASE_URL}/chat/completions",
|
||||||
|
json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
r2.raise_for_status()
|
||||||
|
reply2 = r2.json()["choices"][0]["message"]["content"]
|
||||||
|
print(f"Turn 2 reply: {reply2}")
|
||||||
|
assert "alice" in reply2.lower(), f"Expected 'Alice' in response, got: {reply2}"
|
||||||
|
print("PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tests = [
|
||||||
|
test_models,
|
||||||
|
test_chat_basic,
|
||||||
|
test_chat_streaming,
|
||||||
|
test_vision,
|
||||||
|
test_tool_use,
|
||||||
|
test_multi_turn,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Allow running a single test by name
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
name = sys.argv[1]
|
||||||
|
tests = [t for t in tests if name in t.__name__]
|
||||||
|
if not tests:
|
||||||
|
print(f"No test matching '{name}'. Available: models, chat_basic, chat_streaming, vision, tool_use, multi_turn")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
failed = 0
|
||||||
|
for test in tests:
|
||||||
|
try:
|
||||||
|
test()
|
||||||
|
passed += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"FAIL: {e}\n")
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Results: {passed} passed, {failed} failed")
|
||||||
|
print("=" * 60)
|
||||||
Reference in New Issue
Block a user