initial commit

This commit is contained in:
2026-03-17 09:14:27 +01:00
commit df81afe8d7
10 changed files with 1389 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@@ -0,0 +1,10 @@
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
dist/
build/
.venv/
.env
*.log
.DS_Store

38
CLAUDE.md Normal file
View File

@@ -0,0 +1,38 @@
# MLX Server
OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon via MLX.
## Quick Start
```bash
# Activate virtual environment
source .venv/bin/activate
# Run the server (downloads model on first run)
./run.sh
# Or directly:
python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
```
## Project Structure
- `mlx_server/main.py` — FastAPI server, endpoints, CLI entrypoint
- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm)
- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types
## Key Design Decisions
- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load
- Gemma 3 has no system role — system messages are converted to user/assistant pairs
- Tool use is prompt-engineered: tools are injected into the system prompt with `<tool_call>` XML tags, and parsed from model output
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
- 128k context window supported via the model's native capabilities
## Dependencies
Managed via `uv` and `pyproject.toml`. Virtual environment in `.venv/`.
```bash
uv pip install -e "."
```

0
mlx_server/__init__.py Normal file
View File

3
mlx_server/__main__.py Normal file
View File

@@ -0,0 +1,3 @@
from mlx_server.main import main
main()

576
mlx_server/engine.py Normal file
View File

@@ -0,0 +1,576 @@
"""Model loading and inference engine using mlx_vlm (supports both text and vision)."""
from __future__ import annotations
import base64
import io
import json
import logging
import re
import tempfile
import threading
from collections.abc import Generator
from pathlib import Path
import mlx.core as mx
import mlx_vlm
from PIL import Image
logger = logging.getLogger(__name__)
DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit"
# ------------------------------------------------------------------
# Helpers for Gemma 3 tool_code format
# ------------------------------------------------------------------
_JSON_TO_PYTHON_TYPE = {
"string": "str",
"integer": "int",
"number": "float",
"boolean": "bool",
"array": "list",
"object": "dict",
}
_JSON_TYPE_DEFAULTS = {
"string": '""',
"integer": "0",
"number": "0.0",
"boolean": "False",
"array": "[]",
"object": "{}",
}
def _json_type_to_python(json_type: str) -> str:
return _JSON_TO_PYTHON_TYPE.get(json_type, "str")
def _json_type_default(json_type: str) -> str:
return _JSON_TYPE_DEFAULTS.get(json_type, "None")
def _python_repr(value) -> str:
"""Produce a Python-repr-style string for a value."""
if isinstance(value, str):
return repr(value)
if isinstance(value, bool):
return "True" if value else "False"
if isinstance(value, (int, float)):
return str(value)
return repr(value)
def _parse_python_call(call_str: str, tool_defs: dict[str, dict] | None = None) -> tuple[str, dict]:
"""Parse a function call string into (name, args_dict).
Handles multiple formats:
1. Python-style: func_name(arg1="value1", arg2=42)
2. Shell-style: func_name arg1 arg2 (common with small LLMs)
3. Mixed: func_name("value") (positional args)
tool_defs maps function names to their parameter schemas, used to
infer which parameter a positional/shell-style argument maps to.
"""
import ast
call_str = call_str.strip()
# Try Python-style: function_name(...)
m = re.match(r"(\w+)\s*\((.*)\)\s*$", call_str, re.DOTALL)
if m:
name = m.group(1)
args_str = m.group(2).strip()
if not args_str:
return name, {}
# Try parsing as a Python function call via dict()
try:
tree = ast.parse(f"dict({args_str})", mode="eval")
call_node = tree.body
args = {}
# Handle keyword arguments: func(key="val")
for kw in call_node.keywords:
args[kw.arg] = ast.literal_eval(kw.value)
# Handle positional arguments: func("val1", "val2")
if call_node.args and not args:
param_names = _get_param_names(name, tool_defs)
for i, arg_node in enumerate(call_node.args):
val = ast.literal_eval(arg_node)
if i < len(param_names):
args[param_names[i]] = val
else:
args[f"arg{i}"] = val
if args:
return name, args
except Exception:
pass
# Fallback: regex-based key=value parsing
args = {}
for pair_match in re.finditer(r"(\w+)\s*=\s*(.+?)(?:,\s*(?=\w+\s*=)|$)", args_str, re.DOTALL):
key = pair_match.group(1)
val_str = pair_match.group(2).strip()
try:
args[key] = ast.literal_eval(val_str)
except Exception:
args[key] = val_str
return name, args
# Shell-style: "func_name arg1 arg2" or "func_name some/path"
# Also handles: "func_name -flag arg" (common with shell tools)
parts = call_str.split(None, 1)
if parts and re.match(r"^\w+$", parts[0]):
name = parts[0]
if len(parts) == 1:
return name, {}
rest = parts[1].strip()
param_names = _get_param_names(name, tool_defs)
first_param = param_names[0] if param_names else "input"
return name, {first_param: rest}
# Last resort: treat the entire block as a command for the first
# known tool that looks like a shell/command tool, or just fail
raise ValueError(f"Cannot parse as function call: {call_str!r}")
def _get_param_names(func_name: str, tool_defs: dict[str, dict] | None) -> list[str]:
"""Get ordered parameter names for a function from tool definitions."""
if not tool_defs or func_name not in tool_defs:
return []
params = tool_defs[func_name].get("parameters", {})
properties = params.get("properties", {})
required = params.get("required", [])
# Required params first, then optional
optional = [k for k in properties if k not in required]
return list(required) + optional
class PromptCache:
"""Manages KV cache reuse across requests with shared prompt prefixes."""
def __init__(self):
self._cache = None
self._cached_token_ids: list[int] | None = None
def get_reusable_length(self, new_token_ids: list[int]) -> int:
"""Find how many leading tokens match the cached prefix."""
if self._cached_token_ids is None or self._cache is None:
return 0
max_match = min(len(self._cached_token_ids), len(new_token_ids))
match_len = 0
for i in range(max_match):
if self._cached_token_ids[i] != new_token_ids[i]:
break
match_len = i + 1
return match_len
def update(self, cache, token_ids: list[int]) -> None:
"""Store cache and the token IDs it was built from."""
self._cache = cache
self._cached_token_ids = list(token_ids)
def clear(self) -> None:
self._cache = None
self._cached_token_ids = None
@property
def cache(self):
return self._cache
class InferenceEngine:
"""Manages model loading and text/vision generation."""
def __init__(self, model_path: str = DEFAULT_MODEL):
self.model_path = model_path
self.model = None
self.processor = None
self.config = None
self._lock = threading.Lock()
self._prompt_cache = PromptCache()
def load(self) -> None:
logger.info("Loading model %s ...", self.model_path)
self.model, self.processor = mlx_vlm.load(self.model_path)
# Load model config for chat template
from transformers import AutoConfig
self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
logger.info("Model loaded successfully.")
# ------------------------------------------------------------------
# Image helpers
# ------------------------------------------------------------------
@staticmethod
def _decode_image_url(url: str) -> str:
"""Convert a data URI or URL to a file path that mlx_vlm can consume."""
if url.startswith("data:"):
# data:image/png;base64,iVBOR...
header, b64data = url.split(",", 1)
img_bytes = base64.b64decode(b64data)
img = Image.open(io.BytesIO(img_bytes))
tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(tmp, format="PNG")
tmp.close()
return tmp.name
# Assume it's a URL or local path mlx_vlm handles URLs natively
return url
# ------------------------------------------------------------------
# Prompt building
# ------------------------------------------------------------------
def build_prompt(
self,
messages: list[dict],
tools: list[dict] | None = None,
) -> tuple[str, list[str]]:
"""Build a prompt string and collect image paths from messages.
Returns (prompt_str, image_paths).
"""
image_paths: list[str] = []
formatted_messages: list[dict] = []
for msg in messages:
role = msg["role"]
content = msg.get("content")
tool_calls = msg.get("tool_calls")
tool_call_id = msg.get("tool_call_id")
if role == "system":
text = self._get_text_content(content)
# Inject tool definitions into system prompt
if tools:
text = self._inject_tools_into_system(text, tools)
formatted_messages.append({"role": "user", "content": text})
# Gemma 3 doesn't have a system role; we use the user role
# and add a model acknowledgment
formatted_messages.append({
"role": "assistant",
"content": "Understood. I will follow these instructions.",
})
elif role == "user":
text, imgs = self._extract_content_parts(content)
image_paths.extend(imgs)
formatted_messages.append({"role": "user", "content": text})
elif role == "assistant":
text = self._get_text_content(content) or ""
if tool_calls:
# Format tool calls in the way Gemma 3 expects
tc_text = self._format_tool_calls_for_prompt(tool_calls)
text = (text + "\n" + tc_text).strip()
formatted_messages.append({"role": "assistant", "content": text})
elif role == "tool":
# Tool results use Gemma 3's tool_output format
tool_text = self._get_text_content(content) or ""
result_msg = f"```tool_output\n{tool_text}\n```"
formatted_messages.append({"role": "user", "content": result_msg})
# If the first system prompt had no tools but we have tools, inject at start
if tools and not any(m.get("role") == "system" for m in messages):
tool_system = self._build_tool_system_prompt(tools)
formatted_messages.insert(0, {"role": "user", "content": tool_system})
formatted_messages.insert(1, {
"role": "assistant",
"content": "Understood. I will follow these instructions and use tools when appropriate.",
})
# Gemma 3 requires strictly alternating user/assistant turns.
# Merge consecutive same-role messages and ensure it starts with user.
formatted_messages = self._merge_consecutive_roles(formatted_messages)
# Apply chat template via mlx_vlm
prompt = mlx_vlm.apply_chat_template(
self.processor,
self.config,
formatted_messages,
add_generation_prompt=True,
num_images=len(image_paths),
)
return prompt, image_paths
@staticmethod
def _merge_consecutive_roles(messages: list[dict]) -> list[dict]:
"""Merge consecutive messages with the same role into one.
Gemma 3's chat template enforces strict user/assistant alternation.
"""
if not messages:
return messages
merged = [messages[0].copy()]
for msg in messages[1:]:
if msg["role"] == merged[-1]["role"]:
# Merge content with the previous message
merged[-1]["content"] = (
merged[-1].get("content", "") + "\n\n" + msg.get("content", "")
)
else:
merged.append(msg.copy())
# Ensure conversation starts with user
if merged and merged[0]["role"] != "user":
merged.insert(0, {"role": "user", "content": ""})
return merged
def _get_text_content(self, content) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
# list of content parts
parts = []
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
parts.append(part["text"])
return "\n".join(parts)
def _extract_content_parts(self, content) -> tuple[str, list[str]]:
"""Extract text and image paths from content parts."""
if isinstance(content, str):
return content, []
if content is None:
return "", []
texts = []
images = []
for part in content:
if isinstance(part, dict):
if part.get("type") == "text":
texts.append(part["text"])
elif part.get("type") == "image_url":
url = part["image_url"]["url"]
images.append(self._decode_image_url(url))
return "\n".join(texts), images
def _inject_tools_into_system(self, system_text: str, tools: list[dict]) -> str:
tool_block = self._build_tool_system_prompt(tools)
return f"{system_text}\n\n{tool_block}"
def _build_tool_system_prompt(self, tools: list[dict]) -> str:
"""Build the tool system prompt using Google's official Gemma 3 format.
Uses the tool_code/tool_output convention recommended by Google:
- Tools defined as Python function signatures with docstrings
- Model outputs calls in ```tool_code``` fenced blocks
- Results returned in ```tool_output``` fenced blocks
"""
func_defs = []
for tool in tools:
func = tool.get("function", tool)
func_defs.append(self._tool_to_python_signature(func))
functions_block = "\n\n".join(func_defs)
return (
"At each turn, if you decide to invoke any of the function(s), "
"it should be wrapped with ```tool_code```. "
"The python methods described below are imported and available, "
"you can only use defined methods. "
"The generated code should be readable and efficient. "
"The response to a method will be wrapped in ```tool_output``` "
"use it to call more tools or generate a helpful, friendly response.\n"
"\n"
f"{functions_block}"
)
@staticmethod
def _tool_to_python_signature(func: dict) -> str:
"""Convert an OpenAI function definition to a Python function signature with docstring."""
name = func["name"]
desc = func.get("description", "")
params = func.get("parameters", {})
properties = params.get("properties", {})
required = set(params.get("required", []))
# Build parameter list
param_parts = []
doc_args = []
for pname, pinfo in properties.items():
ptype = _json_type_to_python(pinfo.get("type", "str"))
pdesc = pinfo.get("description", "")
if pname in required:
param_parts.append(f"{pname}: {ptype}")
else:
default = _json_type_default(pinfo.get("type", "str"))
param_parts.append(f"{pname}: {ptype} = {default}")
doc_args.append(f" {pname}: {pdesc}" if pdesc else f" {pname}")
sig = f"def {name}({', '.join(param_parts)}):"
doc_lines = [f' """{desc}']
if doc_args:
doc_lines.append("")
doc_lines.append(" Args:")
doc_lines.extend(doc_args)
doc_lines.append(' """')
return sig + "\n" + "\n".join(doc_lines)
def _format_tool_calls_for_prompt(self, tool_calls: list[dict]) -> str:
"""Format OpenAI-style tool calls back into Gemma 3 tool_code blocks."""
parts = []
for tc in tool_calls:
func = tc.get("function", tc)
name = func["name"]
args = func.get("arguments", "{}")
if isinstance(args, str):
args = json.loads(args)
# Format as Python function call
arg_parts = [f"{k}={_python_repr(v)}" for k, v in args.items()]
call_str = f"{name}({', '.join(arg_parts)})"
parts.append(f"```tool_code\n{call_str}\n```")
return "\n".join(parts)
# ------------------------------------------------------------------
# Generation
# ------------------------------------------------------------------
# Common kwargs for mlx_vlm generate calls — optimized for Apple Silicon
_GENERATE_KWARGS = {
"kv_bits": 8, # Quantize KV cache to 8-bit (halves memory bandwidth)
"kv_group_size": 64, # Group size for KV quantization
}
def generate(
self,
prompt: str,
images: list[str] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
top_p: float = 0.9,
stop: list[str] | None = None,
repetition_penalty: float = 1.1,
) -> tuple[str, int, int]:
"""Generate a complete response. Returns (text, prompt_tokens, completion_tokens)."""
with self._lock:
image_arg = images if images else None
result = mlx_vlm.generate(
self.model,
self.processor,
prompt,
image=image_arg,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
verbose=False,
**self._GENERATE_KWARGS,
)
text = result.text
if stop:
text = self._apply_stop(text, stop)
return text, result.prompt_tokens, result.generation_tokens
def stream_generate(
self,
prompt: str,
images: list[str] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
top_p: float = 0.9,
stop: list[str] | None = None,
repetition_penalty: float = 1.1,
) -> Generator[tuple[str, bool, int, int], None, None]:
"""Stream tokens. Yields (token_text, is_final, prompt_tokens, gen_tokens)."""
with self._lock:
image_arg = images if images else None
accumulated = ""
prompt_tokens = 0
gen_tokens = 0
for result in mlx_vlm.stream_generate(
self.model,
self.processor,
prompt,
image=image_arg,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
**self._GENERATE_KWARGS,
):
# result.text is the incremental segment (detokenizer.last_segment),
# NOT the full accumulated text.
token_text = result.text
accumulated += token_text
prompt_tokens = result.prompt_tokens
gen_tokens = result.generation_tokens
if stop and self._check_stop(accumulated, stop):
# Trim the accumulated text and yield what's safe
trimmed = self._apply_stop(accumulated, stop)
# Only yield the part we haven't yielded yet
safe_delta = trimmed[len(accumulated) - len(token_text):]
yield safe_delta, True, prompt_tokens, gen_tokens
return
yield token_text, False, prompt_tokens, gen_tokens
# Final yield to signal completion
yield "", True, prompt_tokens, gen_tokens
@staticmethod
def _apply_stop(text: str, stop: list[str]) -> str:
for s in stop:
idx = text.find(s)
if idx != -1:
text = text[:idx]
return text
@staticmethod
def _check_stop(text: str, stop: list[str]) -> bool:
return any(s in text for s in stop)
# ------------------------------------------------------------------
# Tool call parsing from model output
# ------------------------------------------------------------------
@staticmethod
def parse_tool_calls(
text: str, tools: list[dict] | None = None
) -> tuple[str, list[dict]]:
"""Parse tool calls from model output using Gemma 3's tool_code format.
Detects ```tool_code ... ``` blocks containing Python-style or
shell-style function calls.
Returns (clean_text, tool_calls) where tool_calls is a list of
{"id": str, "type": "function", "function": {"name": str, "arguments": str}}.
"""
# Build a lookup of function name -> parameter schema
tool_defs: dict[str, dict] = {}
if tools:
for tool in tools:
func = tool.get("function", tool)
tool_defs[func["name"]] = func
tool_calls = []
pattern = r"```tool_code\s*(.*?)\s*```"
matches = re.findall(pattern, text, re.DOTALL)
clean_text = re.sub(r"```tool_code\s*.*?\s*```", "", text, flags=re.DOTALL).strip()
for i, match in enumerate(matches):
call_str = match.strip()
try:
name, args = _parse_python_call(call_str, tool_defs)
tool_calls.append({
"id": f"call_{i}_{hash(call_str) % 10**8:08d}",
"type": "function",
"function": {
"name": name,
"arguments": json.dumps(args),
},
})
except Exception as e:
logger.warning("Failed to parse tool_code call %r: %s", call_str, e)
return clean_text, tool_calls

278
mlx_server/main.py Normal file
View File

@@ -0,0 +1,278 @@
"""OpenAI-compatible API server for Gemma 3 4B via MLX."""
from __future__ import annotations
import argparse
import json
import logging
import time
import uuid
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from .engine import DEFAULT_MODEL, InferenceEngine
from .models import (
ChatCompletionChunk,
ChatCompletionRequest,
ChatCompletionResponse,
Choice,
ChoiceMessage,
DeltaMessage,
ModelInfo,
ModelListResponse,
StreamChoice,
ToolCall,
FunctionCall,
UsageInfo,
)
logger = logging.getLogger(__name__)
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
engine: InferenceEngine | None = None
def get_engine() -> InferenceEngine:
if engine is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return engine
def _make_id() -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
# ------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------
@app.get("/v1/models")
async def list_models() -> ModelListResponse:
e = get_engine()
return ModelListResponse(data=[ModelInfo(id=e.model_path)])
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
e = get_engine()
# Convert pydantic messages to dicts
messages = [m.model_dump(exclude_none=True) for m in request.messages]
tools = None
if request.tools:
tools = [t.model_dump(exclude_none=True) for t in request.tools]
prompt, images = e.build_prompt(messages, tools)
stop = request.stop
if isinstance(stop, str):
stop = [stop]
temperature = request.temperature if request.temperature is not None else 0.7
top_p = request.top_p if request.top_p is not None else 0.9
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
if request.stream:
return EventSourceResponse(
_stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model),
media_type="text/event-stream",
)
# Non-streaming
text, prompt_tokens, completion_tokens = e.generate(
prompt=prompt,
images=images or None,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
)
# Check for tool calls in the response
finish_reason = "stop"
tool_calls_parsed = None
if tools:
clean_text, parsed = e.parse_tool_calls(text, tools)
if parsed:
tool_calls_parsed = [
ToolCall(
index=i,
id=tc["id"],
type="function",
function=FunctionCall(
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
),
)
for i, tc in enumerate(parsed)
]
text = clean_text if clean_text else None
finish_reason = "tool_calls"
return ChatCompletionResponse(
id=_make_id(),
model=request.model,
choices=[
Choice(
message=ChoiceMessage(
role="assistant",
content=text if not tool_calls_parsed else (text or None),
tool_calls=tool_calls_parsed,
),
finish_reason=finish_reason,
)
],
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
async def _stream_response(
e: InferenceEngine,
prompt: str,
images: list[str] | None,
max_tokens: int,
temperature: float,
top_p: float,
stop: list[str] | None,
tools: list[dict] | None,
model_name: str,
):
request_id = _make_id()
created = int(time.time())
# Send initial chunk with role
initial_chunk = ChatCompletionChunk(
id=request_id,
created=created,
model=model_name,
choices=[StreamChoice(delta=DeltaMessage(role="assistant"))],
)
yield {"data": initial_chunk.model_dump_json()}
full_text = ""
prompt_tokens = 0
gen_tokens = 0
for token_text, is_final, pt, gt in e.stream_generate(
prompt=prompt,
images=images or None,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
):
prompt_tokens = pt
gen_tokens = gt
full_text += token_text
if not is_final and token_text:
chunk = ChatCompletionChunk(
id=request_id,
created=created,
model=model_name,
choices=[StreamChoice(delta=DeltaMessage(content=token_text))],
)
yield {"data": chunk.model_dump_json()}
# Check for tool calls in complete response
finish_reason = "stop"
if tools:
clean_text, parsed = e.parse_tool_calls(full_text, tools)
if parsed:
finish_reason = "tool_calls"
# Emit tool call chunks
for i, tc in enumerate(parsed):
tc_chunk = ChatCompletionChunk(
id=request_id,
created=created,
model=model_name,
choices=[
StreamChoice(
delta=DeltaMessage(
tool_calls=[
ToolCall(
index=i,
id=tc["id"],
type="function",
function=FunctionCall(
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
),
)
]
)
)
],
)
yield {"data": tc_chunk.model_dump_json()}
# Final chunk with finish reason and usage
final_chunk = ChatCompletionChunk(
id=request_id,
created=created,
model=model_name,
choices=[StreamChoice(delta=DeltaMessage(), finish_reason=finish_reason)],
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=gen_tokens,
total_tokens=prompt_tokens + gen_tokens,
),
)
yield {"data": final_chunk.model_dump_json()}
yield {"data": "[DONE]"}
# ------------------------------------------------------------------
# Health / utility
# ------------------------------------------------------------------
@app.get("/health")
async def health():
return {"status": "ok"}
# ------------------------------------------------------------------
# Entrypoint
# ------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="MLX Server OpenAI-compatible API")
parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="HuggingFace model path")
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=1234)
parser.add_argument("--log-level", type=str, default="info")
args = parser.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper()),
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
global engine
engine = InferenceEngine(model_path=args.model)
engine.load()
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
if __name__ == "__main__":
main()

144
mlx_server/models.py Normal file
View File

@@ -0,0 +1,144 @@
"""OpenAI API compatible request/response models."""
from __future__ import annotations
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
# --- Request models ---
class FunctionDefinition(BaseModel):
name: str
description: str | None = None
parameters: dict[str, Any] | None = None
class ToolDefinition(BaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition
class FunctionCall(BaseModel):
name: str
arguments: str # JSON string
class ToolCall(BaseModel):
index: int = 0
id: str
type: Literal["function"] = "function"
function: FunctionCall
class ContentPartText(BaseModel):
type: Literal["text"] = "text"
text: str
class ImageURL(BaseModel):
url: str # Can be a URL or base64 data URI
detail: str | None = None
class ContentPartImage(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: ImageURL
ContentPart = ContentPartText | ContentPartImage
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant", "tool"]
content: str | list[ContentPart] | None = None
name: str | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None
class ChatCompletionRequest(BaseModel):
model: str = "gemma-3-4b-it"
messages: list[ChatMessage]
temperature: float | None = 0.7
top_p: float | None = 0.9
max_tokens: int | None = 4096
stream: bool = False
stop: str | list[str] | None = None
tools: list[ToolDefinition] | None = None
tool_choice: str | dict | None = None
frequency_penalty: float | None = None
presence_penalty: float | None = None
n: int | None = 1
# --- Response models ---
class UsageInfo(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class ChoiceMessage(BaseModel):
role: str = "assistant"
content: str | None = None
tool_calls: list[ToolCall] | None = None
class Choice(BaseModel):
index: int = 0
message: ChoiceMessage
finish_reason: str | None = "stop"
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[Choice]
usage: UsageInfo
# --- Streaming response models ---
class DeltaMessage(BaseModel):
role: str | None = None
content: str | None = None
tool_calls: list[ToolCall] | None = None
class StreamChoice(BaseModel):
index: int = 0
delta: DeltaMessage
finish_reason: str | None = None
class ChatCompletionChunk(BaseModel):
id: str
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[StreamChoice]
usage: UsageInfo | None = None
# --- Model listing ---
class ModelInfo(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "local"
class ModelListResponse(BaseModel):
object: str = "list"
data: list[ModelInfo]

20
pyproject.toml Normal file
View File

@@ -0,0 +1,20 @@
[project]
name = "mlx-server"
version = "0.1.0"
description = "OpenAI-compatible API server for Gemma 3 4B via MLX"
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.115.0",
"uvicorn[standard]>=0.30.0",
"mlx>=0.22.0",
"mlx-lm>=0.22.0",
"mlx-vlm>=0.1.18",
"pydantic>=2.0.0",
"sse-starlette>=2.0.0",
"pillow>=10.0.0",
"httpx>=0.27.0",
"torchvision>=0.20.0",
]
[project.scripts]
mlx-server = "mlx_server.main:main"

24
run.sh Executable file
View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"
# Activate virtual environment
source .venv/bin/activate
# Default model 4-bit quantized Gemma 3 4B IT (vision-capable)
MODEL="${MODEL:-mlx-community/gemma-3-4b-it-4bit}"
HOST="${HOST:-127.0.0.1}"
PORT="${PORT:-1234}"
echo "Starting MLX Server..."
echo " Model: $MODEL"
echo " Endpoint: http://$HOST:$PORT"
echo ""
exec python -m mlx_server.main \
--model "$MODEL" \
--host "$HOST" \
--port "$PORT" \
"$@"

296
test_server.py Normal file
View File

@@ -0,0 +1,296 @@
"""Test script for MLX Server exercises chat, streaming, vision, and tool use."""
import base64
import io
import json
import sys
import httpx
from PIL import Image, ImageDraw
BASE_URL = "http://127.0.0.1:1234/v1"
MODEL = "mlx-community/gemma-3-4b-it-4bit"
def test_models():
"""Test GET /v1/models."""
print("=" * 60)
print("TEST: List models")
print("=" * 60)
r = httpx.get(f"{BASE_URL}/models")
r.raise_for_status()
data = r.json()
print(f"Models: {[m['id'] for m in data['data']]}")
print("PASS\n")
def test_chat_basic():
"""Test basic non-streaming chat."""
print("=" * 60)
print("TEST: Basic chat (non-streaming)")
print("=" * 60)
r = httpx.post(
f"{BASE_URL}/chat/completions",
json={
"model": MODEL,
"messages": [{"role": "user", "content": "Say exactly: 'The sky is blue.' Nothing else."}],
"max_tokens": 50,
"temperature": 0.1,
},
timeout=120,
)
r.raise_for_status()
data = r.json()
msg = data["choices"][0]["message"]["content"]
usage = data["usage"]
print(f"Response: {msg}")
print(f"Usage: {usage}")
print(f"Finish reason: {data['choices'][0]['finish_reason']}")
print("PASS\n")
def test_chat_streaming():
"""Test streaming chat."""
print("=" * 60)
print("TEST: Streaming chat")
print("=" * 60)
collected = ""
with httpx.stream(
"POST",
f"{BASE_URL}/chat/completions",
json={
"model": MODEL,
"messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}],
"max_tokens": 100,
"temperature": 0.1,
"stream": True,
},
timeout=120,
) as response:
response.raise_for_status()
for line in response.iter_lines():
if not line.startswith("data: "):
continue
payload = line[len("data: "):]
if payload == "[DONE]":
break
chunk = json.loads(payload)
delta = chunk["choices"][0]["delta"]
if delta.get("content"):
collected += delta["content"]
print(delta["content"], end="", flush=True)
if chunk["choices"][0].get("finish_reason"):
print(f"\n[finish_reason: {chunk['choices'][0]['finish_reason']}]")
if chunk.get("usage") and chunk["usage"].get("total_tokens", 0) > 0:
print(f"[usage: {chunk['usage']}]")
print(f"Full collected: {collected!r}")
print("PASS\n")
def _make_test_image() -> str:
"""Create a simple test image and return it as a base64 data URI."""
img = Image.new("RGB", (200, 200), color=(135, 206, 235))
draw = ImageDraw.Draw(img)
# Draw a red circle
draw.ellipse([50, 50, 150, 150], fill=(255, 0, 0), outline=(0, 0, 0), width=2)
# Draw a green triangle
draw.polygon([(100, 20), (60, 80), (140, 80)], fill=(0, 180, 0), outline=(0, 0, 0))
# Draw yellow text area
draw.rectangle([10, 160, 190, 190], fill=(255, 255, 0))
buf = io.BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"
def test_vision():
"""Test vision with an image."""
print("=" * 60)
print("TEST: Vision (image description)")
print("=" * 60)
image_uri = _make_test_image()
print(f"Image: 200x200 PNG with red circle, green triangle, yellow bar")
r = httpx.post(
f"{BASE_URL}/chat/completions",
json={
"model": MODEL,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Describe what shapes and colors you see in this image. Be brief."},
{"type": "image_url", "image_url": {"url": image_uri}},
],
}
],
"max_tokens": 200,
"temperature": 0.1,
},
timeout=120,
)
r.raise_for_status()
data = r.json()
msg = data["choices"][0]["message"]["content"]
print(f"Response: {msg}")
print("PASS\n")
def test_tool_use():
"""Test tool calling."""
print("=" * 60)
print("TEST: Tool use")
print("=" * 60)
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather for a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name, e.g. 'London'",
},
"units": {
"type": "string",
"description": "Temperature units: 'celsius' or 'fahrenheit'",
},
},
"required": ["city"],
},
},
}
]
# Step 1: Ask the model to use the tool
print("Step 1: Asking model to get weather for Paris...")
r = httpx.post(
f"{BASE_URL}/chat/completions",
json={
"model": MODEL,
"messages": [
{"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."},
],
"tools": tools,
"max_tokens": 300,
"temperature": 0.1,
},
timeout=120,
)
r.raise_for_status()
data = r.json()
choice = data["choices"][0]
print(f"Finish reason: {choice['finish_reason']}")
print(f"Content: {choice['message'].get('content')}")
print(f"Tool calls: {choice['message'].get('tool_calls')}")
if choice["message"].get("tool_calls"):
tc = choice["message"]["tool_calls"][0]
print(f"\nTool call detected:")
print(f" ID: {tc['id']}")
print(f" Function: {tc['function']['name']}")
print(f" Arguments: {tc['function']['arguments']}")
# Step 2: Send the tool result back
print("\nStep 2: Sending mock tool result back...")
r2 = httpx.post(
f"{BASE_URL}/chat/completions",
json={
"model": MODEL,
"messages": [
{"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."},
{
"role": "assistant",
"content": choice["message"].get("content"),
"tool_calls": choice["message"]["tool_calls"],
},
{
"role": "tool",
"tool_call_id": tc["id"],
"content": json.dumps({"temperature": 18, "condition": "Partly cloudy", "humidity": 65}),
},
],
"tools": tools,
"max_tokens": 300,
"temperature": 0.1,
},
timeout=120,
)
r2.raise_for_status()
data2 = r2.json()
msg2 = data2["choices"][0]["message"]["content"]
print(f"Final response: {msg2}")
else:
print("WARNING: Model did not produce a tool call. Raw response above.")
print("PASS\n")
def test_multi_turn():
"""Test multi-turn conversation."""
print("=" * 60)
print("TEST: Multi-turn conversation")
print("=" * 60)
messages = [
{"role": "user", "content": "My name is Alice."},
]
r = httpx.post(
f"{BASE_URL}/chat/completions",
json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1},
timeout=120,
)
r.raise_for_status()
reply1 = r.json()["choices"][0]["message"]["content"]
print(f"Turn 1 reply: {reply1}")
messages.append({"role": "assistant", "content": reply1})
messages.append({"role": "user", "content": "What is my name?"})
r2 = httpx.post(
f"{BASE_URL}/chat/completions",
json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1},
timeout=120,
)
r2.raise_for_status()
reply2 = r2.json()["choices"][0]["message"]["content"]
print(f"Turn 2 reply: {reply2}")
assert "alice" in reply2.lower(), f"Expected 'Alice' in response, got: {reply2}"
print("PASS\n")
if __name__ == "__main__":
tests = [
test_models,
test_chat_basic,
test_chat_streaming,
test_vision,
test_tool_use,
test_multi_turn,
]
# Allow running a single test by name
if len(sys.argv) > 1:
name = sys.argv[1]
tests = [t for t in tests if name in t.__name__]
if not tests:
print(f"No test matching '{name}'. Available: models, chat_basic, chat_streaming, vision, tool_use, multi_turn")
sys.exit(1)
passed = 0
failed = 0
for test in tests:
try:
test()
passed += 1
except Exception as e:
print(f"FAIL: {e}\n")
failed += 1
print("=" * 60)
print(f"Results: {passed} passed, {failed} failed")
print("=" * 60)