From df81afe8d73e4dd747e706e3da3391e6d439f973 Mon Sep 17 00:00:00 2001 From: Chili Palmer Date: Tue, 17 Mar 2026 09:14:27 +0100 Subject: [PATCH] initial commit --- .gitignore | 10 + CLAUDE.md | 38 +++ mlx_server/__init__.py | 0 mlx_server/__main__.py | 3 + mlx_server/engine.py | 576 +++++++++++++++++++++++++++++++++++++++++ mlx_server/main.py | 278 ++++++++++++++++++++ mlx_server/models.py | 144 +++++++++++ pyproject.toml | 20 ++ run.sh | 24 ++ test_server.py | 296 +++++++++++++++++++++ 10 files changed, 1389 insertions(+) create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 mlx_server/__init__.py create mode 100644 mlx_server/__main__.py create mode 100644 mlx_server/engine.py create mode 100644 mlx_server/main.py create mode 100644 mlx_server/models.py create mode 100644 pyproject.toml create mode 100755 run.sh create mode 100644 test_server.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..70b6236 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ +dist/ +build/ +.venv/ +.env +*.log +.DS_Store diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..9542102 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,38 @@ +# MLX Server + +OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon via MLX. + +## Quick Start + +```bash +# Activate virtual environment +source .venv/bin/activate + +# Run the server (downloads model on first run) +./run.sh + +# Or directly: +python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234 +``` + +## Project Structure + +- `mlx_server/main.py` — FastAPI server, endpoints, CLI entrypoint +- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm) +- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types + +## Key Design Decisions + +- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load +- Gemma 3 has no system role — system messages are converted to user/assistant pairs +- Tool use is prompt-engineered: tools are injected into the system prompt with `` XML tags, and parsed from model output +- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation +- 128k context window supported via the model's native capabilities + +## Dependencies + +Managed via `uv` and `pyproject.toml`. Virtual environment in `.venv/`. + +```bash +uv pip install -e "." +``` diff --git a/mlx_server/__init__.py b/mlx_server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlx_server/__main__.py b/mlx_server/__main__.py new file mode 100644 index 0000000..46ce780 --- /dev/null +++ b/mlx_server/__main__.py @@ -0,0 +1,3 @@ +from mlx_server.main import main + +main() diff --git a/mlx_server/engine.py b/mlx_server/engine.py new file mode 100644 index 0000000..7e1cf9a --- /dev/null +++ b/mlx_server/engine.py @@ -0,0 +1,576 @@ +"""Model loading and inference engine using mlx_vlm (supports both text and vision).""" + +from __future__ import annotations + +import base64 +import io +import json +import logging +import re +import tempfile +import threading +from collections.abc import Generator +from pathlib import Path + +import mlx.core as mx +import mlx_vlm +from PIL import Image + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit" + + +# ------------------------------------------------------------------ +# Helpers for Gemma 3 tool_code format +# ------------------------------------------------------------------ + +_JSON_TO_PYTHON_TYPE = { + "string": "str", + "integer": "int", + "number": "float", + "boolean": "bool", + "array": "list", + "object": "dict", +} + +_JSON_TYPE_DEFAULTS = { + "string": '""', + "integer": "0", + "number": "0.0", + "boolean": "False", + "array": "[]", + "object": "{}", +} + + +def _json_type_to_python(json_type: str) -> str: + return _JSON_TO_PYTHON_TYPE.get(json_type, "str") + + +def _json_type_default(json_type: str) -> str: + return _JSON_TYPE_DEFAULTS.get(json_type, "None") + + +def _python_repr(value) -> str: + """Produce a Python-repr-style string for a value.""" + if isinstance(value, str): + return repr(value) + if isinstance(value, bool): + return "True" if value else "False" + if isinstance(value, (int, float)): + return str(value) + return repr(value) + + +def _parse_python_call(call_str: str, tool_defs: dict[str, dict] | None = None) -> tuple[str, dict]: + """Parse a function call string into (name, args_dict). + + Handles multiple formats: + 1. Python-style: func_name(arg1="value1", arg2=42) + 2. Shell-style: func_name arg1 arg2 (common with small LLMs) + 3. Mixed: func_name("value") (positional args) + + tool_defs maps function names to their parameter schemas, used to + infer which parameter a positional/shell-style argument maps to. + """ + import ast + + call_str = call_str.strip() + + # Try Python-style: function_name(...) + m = re.match(r"(\w+)\s*\((.*)\)\s*$", call_str, re.DOTALL) + if m: + name = m.group(1) + args_str = m.group(2).strip() + + if not args_str: + return name, {} + + # Try parsing as a Python function call via dict() + try: + tree = ast.parse(f"dict({args_str})", mode="eval") + call_node = tree.body + args = {} + # Handle keyword arguments: func(key="val") + for kw in call_node.keywords: + args[kw.arg] = ast.literal_eval(kw.value) + # Handle positional arguments: func("val1", "val2") + if call_node.args and not args: + param_names = _get_param_names(name, tool_defs) + for i, arg_node in enumerate(call_node.args): + val = ast.literal_eval(arg_node) + if i < len(param_names): + args[param_names[i]] = val + else: + args[f"arg{i}"] = val + if args: + return name, args + except Exception: + pass + + # Fallback: regex-based key=value parsing + args = {} + for pair_match in re.finditer(r"(\w+)\s*=\s*(.+?)(?:,\s*(?=\w+\s*=)|$)", args_str, re.DOTALL): + key = pair_match.group(1) + val_str = pair_match.group(2).strip() + try: + args[key] = ast.literal_eval(val_str) + except Exception: + args[key] = val_str + return name, args + + # Shell-style: "func_name arg1 arg2" or "func_name some/path" + # Also handles: "func_name -flag arg" (common with shell tools) + parts = call_str.split(None, 1) + if parts and re.match(r"^\w+$", parts[0]): + name = parts[0] + if len(parts) == 1: + return name, {} + + rest = parts[1].strip() + param_names = _get_param_names(name, tool_defs) + first_param = param_names[0] if param_names else "input" + return name, {first_param: rest} + + # Last resort: treat the entire block as a command for the first + # known tool that looks like a shell/command tool, or just fail + raise ValueError(f"Cannot parse as function call: {call_str!r}") + + +def _get_param_names(func_name: str, tool_defs: dict[str, dict] | None) -> list[str]: + """Get ordered parameter names for a function from tool definitions.""" + if not tool_defs or func_name not in tool_defs: + return [] + params = tool_defs[func_name].get("parameters", {}) + properties = params.get("properties", {}) + required = params.get("required", []) + # Required params first, then optional + optional = [k for k in properties if k not in required] + return list(required) + optional + + +class PromptCache: + """Manages KV cache reuse across requests with shared prompt prefixes.""" + + def __init__(self): + self._cache = None + self._cached_token_ids: list[int] | None = None + + def get_reusable_length(self, new_token_ids: list[int]) -> int: + """Find how many leading tokens match the cached prefix.""" + if self._cached_token_ids is None or self._cache is None: + return 0 + max_match = min(len(self._cached_token_ids), len(new_token_ids)) + match_len = 0 + for i in range(max_match): + if self._cached_token_ids[i] != new_token_ids[i]: + break + match_len = i + 1 + return match_len + + def update(self, cache, token_ids: list[int]) -> None: + """Store cache and the token IDs it was built from.""" + self._cache = cache + self._cached_token_ids = list(token_ids) + + def clear(self) -> None: + self._cache = None + self._cached_token_ids = None + + @property + def cache(self): + return self._cache + + +class InferenceEngine: + """Manages model loading and text/vision generation.""" + + def __init__(self, model_path: str = DEFAULT_MODEL): + self.model_path = model_path + self.model = None + self.processor = None + self.config = None + self._lock = threading.Lock() + self._prompt_cache = PromptCache() + + def load(self) -> None: + logger.info("Loading model %s ...", self.model_path) + self.model, self.processor = mlx_vlm.load(self.model_path) + # Load model config for chat template + from transformers import AutoConfig + + self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True) + logger.info("Model loaded successfully.") + + # ------------------------------------------------------------------ + # Image helpers + # ------------------------------------------------------------------ + + @staticmethod + def _decode_image_url(url: str) -> str: + """Convert a data URI or URL to a file path that mlx_vlm can consume.""" + if url.startswith("data:"): + # data:image/png;base64,iVBOR... + header, b64data = url.split(",", 1) + img_bytes = base64.b64decode(b64data) + img = Image.open(io.BytesIO(img_bytes)) + tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + img.save(tmp, format="PNG") + tmp.close() + return tmp.name + # Assume it's a URL or local path – mlx_vlm handles URLs natively + return url + + # ------------------------------------------------------------------ + # Prompt building + # ------------------------------------------------------------------ + + def build_prompt( + self, + messages: list[dict], + tools: list[dict] | None = None, + ) -> tuple[str, list[str]]: + """Build a prompt string and collect image paths from messages. + + Returns (prompt_str, image_paths). + """ + image_paths: list[str] = [] + formatted_messages: list[dict] = [] + + for msg in messages: + role = msg["role"] + content = msg.get("content") + tool_calls = msg.get("tool_calls") + tool_call_id = msg.get("tool_call_id") + + if role == "system": + text = self._get_text_content(content) + # Inject tool definitions into system prompt + if tools: + text = self._inject_tools_into_system(text, tools) + formatted_messages.append({"role": "user", "content": text}) + # Gemma 3 doesn't have a system role; we use the user role + # and add a model acknowledgment + formatted_messages.append({ + "role": "assistant", + "content": "Understood. I will follow these instructions.", + }) + elif role == "user": + text, imgs = self._extract_content_parts(content) + image_paths.extend(imgs) + formatted_messages.append({"role": "user", "content": text}) + elif role == "assistant": + text = self._get_text_content(content) or "" + if tool_calls: + # Format tool calls in the way Gemma 3 expects + tc_text = self._format_tool_calls_for_prompt(tool_calls) + text = (text + "\n" + tc_text).strip() + formatted_messages.append({"role": "assistant", "content": text}) + elif role == "tool": + # Tool results use Gemma 3's tool_output format + tool_text = self._get_text_content(content) or "" + result_msg = f"```tool_output\n{tool_text}\n```" + formatted_messages.append({"role": "user", "content": result_msg}) + + # If the first system prompt had no tools but we have tools, inject at start + if tools and not any(m.get("role") == "system" for m in messages): + tool_system = self._build_tool_system_prompt(tools) + formatted_messages.insert(0, {"role": "user", "content": tool_system}) + formatted_messages.insert(1, { + "role": "assistant", + "content": "Understood. I will follow these instructions and use tools when appropriate.", + }) + + # Gemma 3 requires strictly alternating user/assistant turns. + # Merge consecutive same-role messages and ensure it starts with user. + formatted_messages = self._merge_consecutive_roles(formatted_messages) + + # Apply chat template via mlx_vlm + prompt = mlx_vlm.apply_chat_template( + self.processor, + self.config, + formatted_messages, + add_generation_prompt=True, + num_images=len(image_paths), + ) + + return prompt, image_paths + + @staticmethod + def _merge_consecutive_roles(messages: list[dict]) -> list[dict]: + """Merge consecutive messages with the same role into one. + + Gemma 3's chat template enforces strict user/assistant alternation. + """ + if not messages: + return messages + + merged = [messages[0].copy()] + for msg in messages[1:]: + if msg["role"] == merged[-1]["role"]: + # Merge content with the previous message + merged[-1]["content"] = ( + merged[-1].get("content", "") + "\n\n" + msg.get("content", "") + ) + else: + merged.append(msg.copy()) + + # Ensure conversation starts with user + if merged and merged[0]["role"] != "user": + merged.insert(0, {"role": "user", "content": ""}) + + return merged + + def _get_text_content(self, content) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + # list of content parts + parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(part["text"]) + return "\n".join(parts) + + def _extract_content_parts(self, content) -> tuple[str, list[str]]: + """Extract text and image paths from content parts.""" + if isinstance(content, str): + return content, [] + if content is None: + return "", [] + + texts = [] + images = [] + for part in content: + if isinstance(part, dict): + if part.get("type") == "text": + texts.append(part["text"]) + elif part.get("type") == "image_url": + url = part["image_url"]["url"] + images.append(self._decode_image_url(url)) + return "\n".join(texts), images + + def _inject_tools_into_system(self, system_text: str, tools: list[dict]) -> str: + tool_block = self._build_tool_system_prompt(tools) + return f"{system_text}\n\n{tool_block}" + + def _build_tool_system_prompt(self, tools: list[dict]) -> str: + """Build the tool system prompt using Google's official Gemma 3 format. + + Uses the tool_code/tool_output convention recommended by Google: + - Tools defined as Python function signatures with docstrings + - Model outputs calls in ```tool_code``` fenced blocks + - Results returned in ```tool_output``` fenced blocks + """ + func_defs = [] + for tool in tools: + func = tool.get("function", tool) + func_defs.append(self._tool_to_python_signature(func)) + + functions_block = "\n\n".join(func_defs) + + return ( + "At each turn, if you decide to invoke any of the function(s), " + "it should be wrapped with ```tool_code```. " + "The python methods described below are imported and available, " + "you can only use defined methods. " + "The generated code should be readable and efficient. " + "The response to a method will be wrapped in ```tool_output``` " + "use it to call more tools or generate a helpful, friendly response.\n" + "\n" + f"{functions_block}" + ) + + @staticmethod + def _tool_to_python_signature(func: dict) -> str: + """Convert an OpenAI function definition to a Python function signature with docstring.""" + name = func["name"] + desc = func.get("description", "") + params = func.get("parameters", {}) + properties = params.get("properties", {}) + required = set(params.get("required", [])) + + # Build parameter list + param_parts = [] + doc_args = [] + for pname, pinfo in properties.items(): + ptype = _json_type_to_python(pinfo.get("type", "str")) + pdesc = pinfo.get("description", "") + if pname in required: + param_parts.append(f"{pname}: {ptype}") + else: + default = _json_type_default(pinfo.get("type", "str")) + param_parts.append(f"{pname}: {ptype} = {default}") + doc_args.append(f" {pname}: {pdesc}" if pdesc else f" {pname}") + + sig = f"def {name}({', '.join(param_parts)}):" + doc_lines = [f' """{desc}'] + if doc_args: + doc_lines.append("") + doc_lines.append(" Args:") + doc_lines.extend(doc_args) + doc_lines.append(' """') + + return sig + "\n" + "\n".join(doc_lines) + + def _format_tool_calls_for_prompt(self, tool_calls: list[dict]) -> str: + """Format OpenAI-style tool calls back into Gemma 3 tool_code blocks.""" + parts = [] + for tc in tool_calls: + func = tc.get("function", tc) + name = func["name"] + args = func.get("arguments", "{}") + if isinstance(args, str): + args = json.loads(args) + # Format as Python function call + arg_parts = [f"{k}={_python_repr(v)}" for k, v in args.items()] + call_str = f"{name}({', '.join(arg_parts)})" + parts.append(f"```tool_code\n{call_str}\n```") + return "\n".join(parts) + + # ------------------------------------------------------------------ + # Generation + # ------------------------------------------------------------------ + + # Common kwargs for mlx_vlm generate calls — optimized for Apple Silicon + _GENERATE_KWARGS = { + "kv_bits": 8, # Quantize KV cache to 8-bit (halves memory bandwidth) + "kv_group_size": 64, # Group size for KV quantization + } + + def generate( + self, + prompt: str, + images: list[str] | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + top_p: float = 0.9, + stop: list[str] | None = None, + repetition_penalty: float = 1.1, + ) -> tuple[str, int, int]: + """Generate a complete response. Returns (text, prompt_tokens, completion_tokens).""" + with self._lock: + image_arg = images if images else None + result = mlx_vlm.generate( + self.model, + self.processor, + prompt, + image=image_arg, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + verbose=False, + **self._GENERATE_KWARGS, + ) + text = result.text + if stop: + text = self._apply_stop(text, stop) + return text, result.prompt_tokens, result.generation_tokens + + def stream_generate( + self, + prompt: str, + images: list[str] | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + top_p: float = 0.9, + stop: list[str] | None = None, + repetition_penalty: float = 1.1, + ) -> Generator[tuple[str, bool, int, int], None, None]: + """Stream tokens. Yields (token_text, is_final, prompt_tokens, gen_tokens).""" + with self._lock: + image_arg = images if images else None + accumulated = "" + prompt_tokens = 0 + gen_tokens = 0 + for result in mlx_vlm.stream_generate( + self.model, + self.processor, + prompt, + image=image_arg, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + **self._GENERATE_KWARGS, + ): + # result.text is the incremental segment (detokenizer.last_segment), + # NOT the full accumulated text. + token_text = result.text + accumulated += token_text + prompt_tokens = result.prompt_tokens + gen_tokens = result.generation_tokens + + if stop and self._check_stop(accumulated, stop): + # Trim the accumulated text and yield what's safe + trimmed = self._apply_stop(accumulated, stop) + # Only yield the part we haven't yielded yet + safe_delta = trimmed[len(accumulated) - len(token_text):] + yield safe_delta, True, prompt_tokens, gen_tokens + return + + yield token_text, False, prompt_tokens, gen_tokens + + # Final yield to signal completion + yield "", True, prompt_tokens, gen_tokens + + @staticmethod + def _apply_stop(text: str, stop: list[str]) -> str: + for s in stop: + idx = text.find(s) + if idx != -1: + text = text[:idx] + return text + + @staticmethod + def _check_stop(text: str, stop: list[str]) -> bool: + return any(s in text for s in stop) + + # ------------------------------------------------------------------ + # Tool call parsing from model output + # ------------------------------------------------------------------ + + @staticmethod + def parse_tool_calls( + text: str, tools: list[dict] | None = None + ) -> tuple[str, list[dict]]: + """Parse tool calls from model output using Gemma 3's tool_code format. + + Detects ```tool_code ... ``` blocks containing Python-style or + shell-style function calls. + + Returns (clean_text, tool_calls) where tool_calls is a list of + {"id": str, "type": "function", "function": {"name": str, "arguments": str}}. + """ + # Build a lookup of function name -> parameter schema + tool_defs: dict[str, dict] = {} + if tools: + for tool in tools: + func = tool.get("function", tool) + tool_defs[func["name"]] = func + + tool_calls = [] + pattern = r"```tool_code\s*(.*?)\s*```" + matches = re.findall(pattern, text, re.DOTALL) + + clean_text = re.sub(r"```tool_code\s*.*?\s*```", "", text, flags=re.DOTALL).strip() + + for i, match in enumerate(matches): + call_str = match.strip() + try: + name, args = _parse_python_call(call_str, tool_defs) + tool_calls.append({ + "id": f"call_{i}_{hash(call_str) % 10**8:08d}", + "type": "function", + "function": { + "name": name, + "arguments": json.dumps(args), + }, + }) + except Exception as e: + logger.warning("Failed to parse tool_code call %r: %s", call_str, e) + + return clean_text, tool_calls diff --git a/mlx_server/main.py b/mlx_server/main.py new file mode 100644 index 0000000..92e35aa --- /dev/null +++ b/mlx_server/main.py @@ -0,0 +1,278 @@ +"""OpenAI-compatible API server for Gemma 3 4B via MLX.""" + +from __future__ import annotations + +import argparse +import json +import logging +import time +import uuid + +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from sse_starlette.sse import EventSourceResponse + +from .engine import DEFAULT_MODEL, InferenceEngine +from .models import ( + ChatCompletionChunk, + ChatCompletionRequest, + ChatCompletionResponse, + Choice, + ChoiceMessage, + DeltaMessage, + ModelInfo, + ModelListResponse, + StreamChoice, + ToolCall, + FunctionCall, + UsageInfo, +) + +logger = logging.getLogger(__name__) + +app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B") + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +engine: InferenceEngine | None = None + + +def get_engine() -> InferenceEngine: + if engine is None: + raise HTTPException(status_code=503, detail="Model not loaded") + return engine + + +def _make_id() -> str: + return f"chatcmpl-{uuid.uuid4().hex[:12]}" + + +# ------------------------------------------------------------------ +# Endpoints +# ------------------------------------------------------------------ + + +@app.get("/v1/models") +async def list_models() -> ModelListResponse: + e = get_engine() + return ModelListResponse(data=[ModelInfo(id=e.model_path)]) + + +@app.post("/v1/chat/completions") +async def chat_completions(request: ChatCompletionRequest): + e = get_engine() + + # Convert pydantic messages to dicts + messages = [m.model_dump(exclude_none=True) for m in request.messages] + tools = None + if request.tools: + tools = [t.model_dump(exclude_none=True) for t in request.tools] + + prompt, images = e.build_prompt(messages, tools) + + stop = request.stop + if isinstance(stop, str): + stop = [stop] + + temperature = request.temperature if request.temperature is not None else 0.7 + top_p = request.top_p if request.top_p is not None else 0.9 + max_tokens = request.max_tokens if request.max_tokens is not None else 4096 + + if request.stream: + return EventSourceResponse( + _stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model), + media_type="text/event-stream", + ) + + # Non-streaming + text, prompt_tokens, completion_tokens = e.generate( + prompt=prompt, + images=images or None, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + ) + + # Check for tool calls in the response + finish_reason = "stop" + tool_calls_parsed = None + if tools: + clean_text, parsed = e.parse_tool_calls(text, tools) + if parsed: + tool_calls_parsed = [ + ToolCall( + index=i, + id=tc["id"], + type="function", + function=FunctionCall( + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ), + ) + for i, tc in enumerate(parsed) + ] + text = clean_text if clean_text else None + finish_reason = "tool_calls" + + return ChatCompletionResponse( + id=_make_id(), + model=request.model, + choices=[ + Choice( + message=ChoiceMessage( + role="assistant", + content=text if not tool_calls_parsed else (text or None), + tool_calls=tool_calls_parsed, + ), + finish_reason=finish_reason, + ) + ], + usage=UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + +async def _stream_response( + e: InferenceEngine, + prompt: str, + images: list[str] | None, + max_tokens: int, + temperature: float, + top_p: float, + stop: list[str] | None, + tools: list[dict] | None, + model_name: str, +): + request_id = _make_id() + created = int(time.time()) + + # Send initial chunk with role + initial_chunk = ChatCompletionChunk( + id=request_id, + created=created, + model=model_name, + choices=[StreamChoice(delta=DeltaMessage(role="assistant"))], + ) + yield {"data": initial_chunk.model_dump_json()} + + full_text = "" + prompt_tokens = 0 + gen_tokens = 0 + + for token_text, is_final, pt, gt in e.stream_generate( + prompt=prompt, + images=images or None, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + ): + prompt_tokens = pt + gen_tokens = gt + full_text += token_text + + if not is_final and token_text: + chunk = ChatCompletionChunk( + id=request_id, + created=created, + model=model_name, + choices=[StreamChoice(delta=DeltaMessage(content=token_text))], + ) + yield {"data": chunk.model_dump_json()} + + # Check for tool calls in complete response + finish_reason = "stop" + if tools: + clean_text, parsed = e.parse_tool_calls(full_text, tools) + if parsed: + finish_reason = "tool_calls" + # Emit tool call chunks + for i, tc in enumerate(parsed): + tc_chunk = ChatCompletionChunk( + id=request_id, + created=created, + model=model_name, + choices=[ + StreamChoice( + delta=DeltaMessage( + tool_calls=[ + ToolCall( + index=i, + id=tc["id"], + type="function", + function=FunctionCall( + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ), + ) + ] + ) + ) + ], + ) + yield {"data": tc_chunk.model_dump_json()} + + # Final chunk with finish reason and usage + final_chunk = ChatCompletionChunk( + id=request_id, + created=created, + model=model_name, + choices=[StreamChoice(delta=DeltaMessage(), finish_reason=finish_reason)], + usage=UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=gen_tokens, + total_tokens=prompt_tokens + gen_tokens, + ), + ) + yield {"data": final_chunk.model_dump_json()} + yield {"data": "[DONE]"} + + +# ------------------------------------------------------------------ +# Health / utility +# ------------------------------------------------------------------ + + +@app.get("/health") +async def health(): + return {"status": "ok"} + + +# ------------------------------------------------------------------ +# Entrypoint +# ------------------------------------------------------------------ + + +def main(): + parser = argparse.ArgumentParser(description="MLX Server – OpenAI-compatible API") + parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="HuggingFace model path") + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=1234) + parser.add_argument("--log-level", type=str, default="info") + args = parser.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level.upper()), + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + global engine + engine = InferenceEngine(model_path=args.model) + engine.load() + + uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level) + + +if __name__ == "__main__": + main() diff --git a/mlx_server/models.py b/mlx_server/models.py new file mode 100644 index 0000000..c5c2ea1 --- /dev/null +++ b/mlx_server/models.py @@ -0,0 +1,144 @@ +"""OpenAI API compatible request/response models.""" + +from __future__ import annotations + +import time +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +# --- Request models --- + + +class FunctionDefinition(BaseModel): + name: str + description: str | None = None + parameters: dict[str, Any] | None = None + + +class ToolDefinition(BaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class FunctionCall(BaseModel): + name: str + arguments: str # JSON string + + +class ToolCall(BaseModel): + index: int = 0 + id: str + type: Literal["function"] = "function" + function: FunctionCall + + +class ContentPartText(BaseModel): + type: Literal["text"] = "text" + text: str + + +class ImageURL(BaseModel): + url: str # Can be a URL or base64 data URI + detail: str | None = None + + +class ContentPartImage(BaseModel): + type: Literal["image_url"] = "image_url" + image_url: ImageURL + + +ContentPart = ContentPartText | ContentPartImage + + +class ChatMessage(BaseModel): + role: Literal["system", "user", "assistant", "tool"] + content: str | list[ContentPart] | None = None + name: str | None = None + tool_calls: list[ToolCall] | None = None + tool_call_id: str | None = None + + +class ChatCompletionRequest(BaseModel): + model: str = "gemma-3-4b-it" + messages: list[ChatMessage] + temperature: float | None = 0.7 + top_p: float | None = 0.9 + max_tokens: int | None = 4096 + stream: bool = False + stop: str | list[str] | None = None + tools: list[ToolDefinition] | None = None + tool_choice: str | dict | None = None + frequency_penalty: float | None = None + presence_penalty: float | None = None + n: int | None = 1 + + +# --- Response models --- + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class ChoiceMessage(BaseModel): + role: str = "assistant" + content: str | None = None + tool_calls: list[ToolCall] | None = None + + +class Choice(BaseModel): + index: int = 0 + message: ChoiceMessage + finish_reason: str | None = "stop" + + +class ChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[Choice] + usage: UsageInfo + + +# --- Streaming response models --- + + +class DeltaMessage(BaseModel): + role: str | None = None + content: str | None = None + tool_calls: list[ToolCall] | None = None + + +class StreamChoice(BaseModel): + index: int = 0 + delta: DeltaMessage + finish_reason: str | None = None + + +class ChatCompletionChunk(BaseModel): + id: str + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[StreamChoice] + usage: UsageInfo | None = None + + +# --- Model listing --- + + +class ModelInfo(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "local" + + +class ModelListResponse(BaseModel): + object: str = "list" + data: list[ModelInfo] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d6b0e67 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "mlx-server" +version = "0.1.0" +description = "OpenAI-compatible API server for Gemma 3 4B via MLX" +requires-python = ">=3.11" +dependencies = [ + "fastapi>=0.115.0", + "uvicorn[standard]>=0.30.0", + "mlx>=0.22.0", + "mlx-lm>=0.22.0", + "mlx-vlm>=0.1.18", + "pydantic>=2.0.0", + "sse-starlette>=2.0.0", + "pillow>=10.0.0", + "httpx>=0.27.0", + "torchvision>=0.20.0", +] + +[project.scripts] +mlx-server = "mlx_server.main:main" diff --git a/run.sh b/run.sh new file mode 100755 index 0000000..2def61f --- /dev/null +++ b/run.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +cd "$SCRIPT_DIR" + +# Activate virtual environment +source .venv/bin/activate + +# Default model – 4-bit quantized Gemma 3 4B IT (vision-capable) +MODEL="${MODEL:-mlx-community/gemma-3-4b-it-4bit}" +HOST="${HOST:-127.0.0.1}" +PORT="${PORT:-1234}" + +echo "Starting MLX Server..." +echo " Model: $MODEL" +echo " Endpoint: http://$HOST:$PORT" +echo "" + +exec python -m mlx_server.main \ + --model "$MODEL" \ + --host "$HOST" \ + --port "$PORT" \ + "$@" diff --git a/test_server.py b/test_server.py new file mode 100644 index 0000000..88e1ae6 --- /dev/null +++ b/test_server.py @@ -0,0 +1,296 @@ +"""Test script for MLX Server – exercises chat, streaming, vision, and tool use.""" + +import base64 +import io +import json +import sys + +import httpx +from PIL import Image, ImageDraw + +BASE_URL = "http://127.0.0.1:1234/v1" +MODEL = "mlx-community/gemma-3-4b-it-4bit" + + +def test_models(): + """Test GET /v1/models.""" + print("=" * 60) + print("TEST: List models") + print("=" * 60) + r = httpx.get(f"{BASE_URL}/models") + r.raise_for_status() + data = r.json() + print(f"Models: {[m['id'] for m in data['data']]}") + print("PASS\n") + + +def test_chat_basic(): + """Test basic non-streaming chat.""" + print("=" * 60) + print("TEST: Basic chat (non-streaming)") + print("=" * 60) + r = httpx.post( + f"{BASE_URL}/chat/completions", + json={ + "model": MODEL, + "messages": [{"role": "user", "content": "Say exactly: 'The sky is blue.' Nothing else."}], + "max_tokens": 50, + "temperature": 0.1, + }, + timeout=120, + ) + r.raise_for_status() + data = r.json() + msg = data["choices"][0]["message"]["content"] + usage = data["usage"] + print(f"Response: {msg}") + print(f"Usage: {usage}") + print(f"Finish reason: {data['choices'][0]['finish_reason']}") + print("PASS\n") + + +def test_chat_streaming(): + """Test streaming chat.""" + print("=" * 60) + print("TEST: Streaming chat") + print("=" * 60) + collected = "" + with httpx.stream( + "POST", + f"{BASE_URL}/chat/completions", + json={ + "model": MODEL, + "messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}], + "max_tokens": 100, + "temperature": 0.1, + "stream": True, + }, + timeout=120, + ) as response: + response.raise_for_status() + for line in response.iter_lines(): + if not line.startswith("data: "): + continue + payload = line[len("data: "):] + if payload == "[DONE]": + break + chunk = json.loads(payload) + delta = chunk["choices"][0]["delta"] + if delta.get("content"): + collected += delta["content"] + print(delta["content"], end="", flush=True) + if chunk["choices"][0].get("finish_reason"): + print(f"\n[finish_reason: {chunk['choices'][0]['finish_reason']}]") + if chunk.get("usage") and chunk["usage"].get("total_tokens", 0) > 0: + print(f"[usage: {chunk['usage']}]") + print(f"Full collected: {collected!r}") + print("PASS\n") + + +def _make_test_image() -> str: + """Create a simple test image and return it as a base64 data URI.""" + img = Image.new("RGB", (200, 200), color=(135, 206, 235)) + draw = ImageDraw.Draw(img) + # Draw a red circle + draw.ellipse([50, 50, 150, 150], fill=(255, 0, 0), outline=(0, 0, 0), width=2) + # Draw a green triangle + draw.polygon([(100, 20), (60, 80), (140, 80)], fill=(0, 180, 0), outline=(0, 0, 0)) + # Draw yellow text area + draw.rectangle([10, 160, 190, 190], fill=(255, 255, 0)) + + buf = io.BytesIO() + img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode() + return f"data:image/png;base64,{b64}" + + +def test_vision(): + """Test vision with an image.""" + print("=" * 60) + print("TEST: Vision (image description)") + print("=" * 60) + image_uri = _make_test_image() + print(f"Image: 200x200 PNG with red circle, green triangle, yellow bar") + + r = httpx.post( + f"{BASE_URL}/chat/completions", + json={ + "model": MODEL, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe what shapes and colors you see in this image. Be brief."}, + {"type": "image_url", "image_url": {"url": image_uri}}, + ], + } + ], + "max_tokens": 200, + "temperature": 0.1, + }, + timeout=120, + ) + r.raise_for_status() + data = r.json() + msg = data["choices"][0]["message"]["content"] + print(f"Response: {msg}") + print("PASS\n") + + +def test_tool_use(): + """Test tool calling.""" + print("=" * 60) + print("TEST: Tool use") + print("=" * 60) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a given city", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name, e.g. 'London'", + }, + "units": { + "type": "string", + "description": "Temperature units: 'celsius' or 'fahrenheit'", + }, + }, + "required": ["city"], + }, + }, + } + ] + + # Step 1: Ask the model to use the tool + print("Step 1: Asking model to get weather for Paris...") + r = httpx.post( + f"{BASE_URL}/chat/completions", + json={ + "model": MODEL, + "messages": [ + {"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."}, + ], + "tools": tools, + "max_tokens": 300, + "temperature": 0.1, + }, + timeout=120, + ) + r.raise_for_status() + data = r.json() + choice = data["choices"][0] + print(f"Finish reason: {choice['finish_reason']}") + print(f"Content: {choice['message'].get('content')}") + print(f"Tool calls: {choice['message'].get('tool_calls')}") + + if choice["message"].get("tool_calls"): + tc = choice["message"]["tool_calls"][0] + print(f"\nTool call detected:") + print(f" ID: {tc['id']}") + print(f" Function: {tc['function']['name']}") + print(f" Arguments: {tc['function']['arguments']}") + + # Step 2: Send the tool result back + print("\nStep 2: Sending mock tool result back...") + r2 = httpx.post( + f"{BASE_URL}/chat/completions", + json={ + "model": MODEL, + "messages": [ + {"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."}, + { + "role": "assistant", + "content": choice["message"].get("content"), + "tool_calls": choice["message"]["tool_calls"], + }, + { + "role": "tool", + "tool_call_id": tc["id"], + "content": json.dumps({"temperature": 18, "condition": "Partly cloudy", "humidity": 65}), + }, + ], + "tools": tools, + "max_tokens": 300, + "temperature": 0.1, + }, + timeout=120, + ) + r2.raise_for_status() + data2 = r2.json() + msg2 = data2["choices"][0]["message"]["content"] + print(f"Final response: {msg2}") + else: + print("WARNING: Model did not produce a tool call. Raw response above.") + + print("PASS\n") + + +def test_multi_turn(): + """Test multi-turn conversation.""" + print("=" * 60) + print("TEST: Multi-turn conversation") + print("=" * 60) + messages = [ + {"role": "user", "content": "My name is Alice."}, + ] + r = httpx.post( + f"{BASE_URL}/chat/completions", + json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1}, + timeout=120, + ) + r.raise_for_status() + reply1 = r.json()["choices"][0]["message"]["content"] + print(f"Turn 1 reply: {reply1}") + + messages.append({"role": "assistant", "content": reply1}) + messages.append({"role": "user", "content": "What is my name?"}) + + r2 = httpx.post( + f"{BASE_URL}/chat/completions", + json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1}, + timeout=120, + ) + r2.raise_for_status() + reply2 = r2.json()["choices"][0]["message"]["content"] + print(f"Turn 2 reply: {reply2}") + assert "alice" in reply2.lower(), f"Expected 'Alice' in response, got: {reply2}" + print("PASS\n") + + +if __name__ == "__main__": + tests = [ + test_models, + test_chat_basic, + test_chat_streaming, + test_vision, + test_tool_use, + test_multi_turn, + ] + + # Allow running a single test by name + if len(sys.argv) > 1: + name = sys.argv[1] + tests = [t for t in tests if name in t.__name__] + if not tests: + print(f"No test matching '{name}'. Available: models, chat_basic, chat_streaming, vision, tool_use, multi_turn") + sys.exit(1) + + passed = 0 + failed = 0 + for test in tests: + try: + test() + passed += 1 + except Exception as e: + print(f"FAIL: {e}\n") + failed += 1 + + print("=" * 60) + print(f"Results: {passed} passed, {failed} failed") + print("=" * 60)