From cc6e761ed4bbf08e80c34ac9db50baf493623c53 Mon Sep 17 00:00:00 2001 From: Chili Palmer Date: Tue, 17 Mar 2026 11:44:24 +0100 Subject: [PATCH] feat: qwen now works, too --- .gitignore | 1 + CLAUDE.md | 19 ++- mlx_server/engine.py | 280 ++++++++++++++++++++++++++++++++++++++++--- mlx_server/main.py | 75 +++++++----- run.sh | 26 +++- 5 files changed, 351 insertions(+), 50 deletions(-) diff --git a/.gitignore b/.gitignore index 70b6236..dfa92aa 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build/ .env *.log .DS_Store +settings.local.json diff --git a/CLAUDE.md b/CLAUDE.md index 9542102..65391d4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,6 +1,6 @@ # MLX Server -OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon via MLX. +OpenAI-compatible API server for local LLMs on Apple Silicon via MLX. Supports Gemma 3 4B and Qwen3 VL 4B (vision + tool use). ## Quick Start @@ -8,11 +8,15 @@ OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon # Activate virtual environment source .venv/bin/activate -# Run the server (downloads model on first run) +# Run with Gemma 3 (default) ./run.sh +# Run with Qwen3 +./run.sh qwen + # Or directly: python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234 +python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port 1234 ``` ## Project Structure @@ -21,11 +25,18 @@ python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234 - `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm) - `mlx_server/models.py` — Pydantic models for OpenAI API request/response types +## Supported Models + +| Alias | HuggingFace ID | Notes | +|-------|---------------|-------| +| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks | +| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `` tags | + ## Key Design Decisions - Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load -- Gemma 3 has no system role — system messages are converted to user/assistant pairs -- Tool use is prompt-engineered: tools are injected into the system prompt with `` XML tags, and parsed from model output +- Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `` XML tags +- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.) - Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation - 128k context window supported via the model's native capabilities diff --git a/mlx_server/engine.py b/mlx_server/engine.py index 5e1014a..ef9d2f6 100644 --- a/mlx_server/engine.py +++ b/mlx_server/engine.py @@ -20,6 +20,58 @@ logger = logging.getLogger(__name__) DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit" +# Known model aliases for quick selection +MODEL_ALIASES: dict[str, str] = { + "gemma": "mlx-community/gemma-3-4b-it-4bit", + "qwen": "mlx-community/Qwen3-VL-4B-Instruct-4bit", +} + + +def _resolve_local_model_path(repo_id: str) -> Path | None: + """If a HuggingFace model is already cached locally, return its snapshot path. + + This avoids any network requests (HEAD checks) when the model files are + already present on disk — critical for offline use. + """ + # If it's already a local directory, just use it + local = Path(repo_id) + if local.is_dir(): + return local + + # Check the HF cache: ~/.cache/huggingface/hub/models--org--name/snapshots/ + cache_root = Path.home() / ".cache" / "huggingface" / "hub" + safe_name = "models--" + repo_id.replace("/", "--") + model_cache = cache_root / safe_name + + if not model_cache.is_dir(): + return None + + # Read the ref to find the snapshot hash + refs_dir = model_cache / "refs" + snapshot_dir = model_cache / "snapshots" + if refs_dir.is_dir() and snapshot_dir.is_dir(): + main_ref = refs_dir / "main" + if main_ref.is_file(): + commit_hash = main_ref.read_text().strip() + snap = snapshot_dir / commit_hash + if snap.is_dir(): + logger.info( + "Found locally cached model at %s — skipping online check", snap + ) + return snap + + # Fallback: use the first (most recent) snapshot if refs/main is missing + if snapshot_dir.is_dir(): + snapshots = sorted(snapshot_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True) + if snapshots: + logger.info( + "Found locally cached model at %s — skipping online check", + snapshots[0], + ) + return snapshots[0] + + return None + # ------------------------------------------------------------------ # Helpers for Gemma 3 tool_code format @@ -200,17 +252,35 @@ class InferenceEngine: self.model = None self.processor = None self.config = None + self._model_type: str = "" # e.g. "gemma3", "qwen3" self._lock = threading.Lock() self._prompt_cache = PromptCache() def load(self) -> None: logger.info("Loading model %s ...", self.model_path) - self.model, self.processor = mlx_vlm.load(self.model_path) + + # Prefer the local cache to avoid any network requests + local_path = _resolve_local_model_path(self.model_path) + load_path = str(local_path) if local_path else self.model_path + + self.model, self.processor = mlx_vlm.load(load_path) + # Load model config for chat template from transformers import AutoConfig - self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True) - logger.info("Model loaded successfully.") + self.config = AutoConfig.from_pretrained(load_path, trust_remote_code=True) + + # Detect model family for prompt-format branching + self._model_type = getattr(self.config, "model_type", "").lower() + logger.info("Model loaded successfully (type=%s).", self._model_type) + + @property + def is_qwen(self) -> bool: + return "qwen" in self._model_type + + @property + def is_gemma(self) -> bool: + return "gemma" in self._model_type # ------------------------------------------------------------------ # Image helpers @@ -244,6 +314,16 @@ class InferenceEngine: Returns (prompt_str, image_paths). """ + if self.is_qwen: + return self._build_prompt_qwen(messages, tools) + return self._build_prompt_gemma(messages, tools) + + def _build_prompt_gemma( + self, + messages: list[dict], + tools: list[dict] | None = None, + ) -> tuple[str, list[str]]: + """Gemma 3 prompt builder (tool_code format, no system role).""" image_paths: list[str] = [] formatted_messages: list[dict] = [] @@ -306,6 +386,58 @@ class InferenceEngine: return prompt, image_paths + def _build_prompt_qwen( + self, + messages: list[dict], + tools: list[dict] | None = None, + ) -> tuple[str, list[str]]: + """Qwen3 prompt builder (native system role, JSON tool calls).""" + image_paths: list[str] = [] + formatted_messages: list[dict] = [] + + # Qwen3 supports system role natively — inject tools there + has_system = any(m.get("role") == "system" for m in messages) + if tools and not has_system: + formatted_messages.append({ + "role": "system", + "content": self._build_qwen_tool_system_prompt(tools), + }) + + for msg in messages: + role = msg["role"] + content = msg.get("content") + tool_calls = msg.get("tool_calls") + + if role == "system": + text = self._get_text_content(content) + if tools: + text = text + "\n\n" + self._build_qwen_tool_system_prompt(tools) + formatted_messages.append({"role": "system", "content": text}) + elif role == "user": + text, imgs = self._extract_content_parts(content) + image_paths.extend(imgs) + formatted_messages.append({"role": "user", "content": text}) + elif role == "assistant": + text = self._get_text_content(content) or "" + if tool_calls: + tc_text = self._format_qwen_tool_calls_for_prompt(tool_calls) + text = (text + "\n" + tc_text).strip() + formatted_messages.append({"role": "assistant", "content": text}) + elif role == "tool": + tool_text = self._get_text_content(content) or "" + formatted_messages.append({"role": "user", "content": tool_text}) + + # Apply chat template via mlx_vlm + prompt = mlx_vlm.apply_chat_template( + self.processor, + self.config, + formatted_messages, + add_generation_prompt=True, + num_images=len(image_paths), + ) + + return prompt, image_paths + @staticmethod def _merge_consecutive_roles(messages: list[dict]) -> list[dict]: """Merge consecutive messages with the same role into one. @@ -439,6 +571,50 @@ class InferenceEngine: parts.append(f"```tool_code\n{call_str}\n```") return "\n".join(parts) + # ------------------------------------------------------------------ + # Qwen3 tool helpers + # ------------------------------------------------------------------ + + @staticmethod + def _build_qwen_tool_system_prompt(tools: list[dict]) -> str: + """Build the tool system prompt for Qwen3 using its native JSON format.""" + tool_descs = [] + for tool in tools: + func = tool.get("function", tool) + tool_descs.append({ + "type": "function", + "function": { + "name": func["name"], + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + }, + }) + tools_json = json.dumps(tool_descs, indent=2) + return ( + "# Tools\n\n" + "You are a helpful assistant with access to the following tools. " + "Use them when appropriate by responding with a JSON tool call.\n\n" + "## Available Tools\n\n" + f"{tools_json}\n\n" + "## Tool Call Format\n\n" + "When you need to call a tool, respond with:\n" + '\n{"name": "", "arguments": {}}\n' + ) + + @staticmethod + def _format_qwen_tool_calls_for_prompt(tool_calls: list[dict]) -> str: + """Format OpenAI-style tool calls back into Qwen3's XML tag format.""" + parts = [] + for tc in tool_calls: + func = tc.get("function", tc) + name = func["name"] + args = func.get("arguments", "{}") + if isinstance(args, str): + args = json.loads(args) + call_obj = {"name": name, "arguments": args} + parts.append(f"\n{json.dumps(call_obj)}\n") + return "\n".join(parts) + # ------------------------------------------------------------------ # Prefix cache & generation # ------------------------------------------------------------------ @@ -447,6 +623,32 @@ class InferenceEngine: # Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache _GENERATE_KWARGS: dict = {} + # Keys in the prep dict that are internal bookkeeping, not kwargs for + # mlx_vlm.stream_generate. + _PREP_INTERNAL_KEYS = frozenset({ + "input_ids", "pixel_values", "mask", "prompt_cache", + "_full_token_ids", "_prompt_token_count", + }) + + def _extra_generate_kwargs( + self, images: list[str] | None, prep: dict | None = None, + ) -> dict: + """Build per-request kwargs for mlx_vlm.stream_generate. + + Includes model-specific keys from prepare_inputs (e.g. image_grid_thw + for Qwen3-VL) and works around a chunked-prefill bug where + visual_pos_masks is None for text-only requests. + """ + extra: dict = dict(self._GENERATE_KWARGS) + if self.is_qwen and not images: + extra["prefill_step_size"] = None + # Forward any model-specific keys that prepare_inputs returned + if prep: + for k, v in prep.items(): + if k not in self._PREP_INTERNAL_KEYS: + extra[k] = v + return extra + def _get_tokenizer(self): """Get the underlying tokenizer from the processor.""" proc = self.processor @@ -484,6 +686,11 @@ class InferenceEngine: pixel_values = inputs.get("pixel_values") mask = inputs.get("attention_mask") + # Collect any model-specific extra keys from prepare_inputs + # (e.g. image_grid_thw for Qwen3-VL) so they reach the model. + _KNOWN_KEYS = {"input_ids", "pixel_values", "attention_mask"} + extra_inputs = {k: v for k, v in inputs.items() if k not in _KNOWN_KEYS} + full_token_list = full_input_ids.flatten().tolist() prefix_len = self._prompt_cache.get_reusable_length(full_token_list) @@ -521,7 +728,9 @@ class InferenceEngine: } # Cache miss — create a fresh KV cache - cache = cache_module.make_prompt_cache(self.model.language_model) + # VLM models expose .language_model; text-only models are the LM directly + lm = getattr(self.model, "language_model", self.model) + cache = cache_module.make_prompt_cache(lm) logger.info( "Prefix cache miss: processing %d tokens from scratch", len(full_token_list), @@ -533,6 +742,7 @@ class InferenceEngine: "prompt_cache": cache, "_full_token_ids": full_token_list, "_prompt_token_count": len(full_token_list), + **extra_inputs, } def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None: @@ -555,9 +765,10 @@ class InferenceEngine: prep = self._prepare_generation(prompt, images) prompt_token_count = prep["_prompt_token_count"] - # Ensure stopping criteria is initialised + # Ensure stopping criteria is initialised (Gemma-specific; optional for others) tokenizer = self._get_tokenizer() - tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) + if hasattr(tokenizer, "stopping_criteria"): + tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) text = "" generated_tokens: list[int] = [] @@ -575,7 +786,7 @@ class InferenceEngine: temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, - **self._GENERATE_KWARGS, + **self._extra_generate_kwargs(images, prep), ): text += result.text if result.token is not None: @@ -603,9 +814,10 @@ class InferenceEngine: prep = self._prepare_generation(prompt, images) prompt_token_count = prep["_prompt_token_count"] - # Ensure stopping criteria is initialised + # Ensure stopping criteria is initialised (Gemma-specific; optional for others) tokenizer = self._get_tokenizer() - tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) + if hasattr(tokenizer, "stopping_criteria"): + tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) accumulated = "" generated_tokens: list[int] = [] @@ -624,7 +836,7 @@ class InferenceEngine: temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, - **self._GENERATE_KWARGS, + **self._extra_generate_kwargs(images, prep), ): token_text = result.text accumulated += token_text @@ -663,19 +875,26 @@ class InferenceEngine: # Tool call parsing from model output # ------------------------------------------------------------------ - @staticmethod def parse_tool_calls( - text: str, tools: list[dict] | None = None + self, text: str, tools: list[dict] | None = None ) -> tuple[str, list[dict]]: - """Parse tool calls from model output using Gemma 3's tool_code format. + """Parse tool calls from model output. - Detects ```tool_code ... ``` blocks containing Python-style or - shell-style function calls. + Supports both Gemma 3's ```tool_code``` blocks and Qwen3's + XML tags. Returns (clean_text, tool_calls) where tool_calls is a list of {"id": str, "type": "function", "function": {"name": str, "arguments": str}}. """ - # Build a lookup of function name -> parameter schema + if self.is_qwen: + return self._parse_tool_calls_qwen(text) + return self._parse_tool_calls_gemma(text, tools) + + @staticmethod + def _parse_tool_calls_gemma( + text: str, tools: list[dict] | None = None + ) -> tuple[str, list[dict]]: + """Parse Gemma 3 tool_code blocks.""" tool_defs: dict[str, dict] = {} if tools: for tool in tools: @@ -704,3 +923,32 @@ class InferenceEngine: logger.warning("Failed to parse tool_code call %r: %s", call_str, e) return clean_text, tool_calls + + @staticmethod + def _parse_tool_calls_qwen(text: str) -> tuple[str, list[dict]]: + """Parse Qwen3 XML tags.""" + tool_calls = [] + pattern = r"\s*(.*?)\s*" + matches = re.findall(pattern, text, re.DOTALL) + + clean_text = re.sub(r"\s*.*?\s*", "", text, flags=re.DOTALL).strip() + + for i, match in enumerate(matches): + try: + call_obj = json.loads(match.strip()) + name = call_obj.get("name", "") + args = call_obj.get("arguments", {}) + if isinstance(args, str): + args = json.loads(args) + tool_calls.append({ + "id": f"call_{i}_{hash(match) % 10**8:08d}", + "type": "function", + "function": { + "name": name, + "arguments": json.dumps(args), + }, + }) + except Exception as e: + logger.warning("Failed to parse tool_call tag %r: %s", match, e) + + return clean_text, tool_calls diff --git a/mlx_server/main.py b/mlx_server/main.py index 92e35aa..4920bca 100644 --- a/mlx_server/main.py +++ b/mlx_server/main.py @@ -1,4 +1,4 @@ -"""OpenAI-compatible API server for Gemma 3 4B via MLX.""" +"""OpenAI-compatible API server for local LLMs (Gemma 3, Qwen3, …) via MLX.""" from __future__ import annotations @@ -31,7 +31,7 @@ from .models import ( logger = logging.getLogger(__name__) -app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B") +app = FastAPI(title="MLX Server", description="OpenAI-compatible API for local LLMs on Apple Silicon") app.add_middleware( CORSMiddleware, @@ -170,6 +170,11 @@ async def _stream_response( prompt_tokens = 0 gen_tokens = 0 + # When tools are available we must buffer the full response before + # emitting content — otherwise raw tool-call markup (```tool_code``` + # or ) leaks into the streamed text. + buffer_for_tools = bool(tools) + for token_text, is_final, pt, gt in e.stream_generate( prompt=prompt, images=images or None, @@ -182,7 +187,7 @@ async def _stream_response( gen_tokens = gt full_text += token_text - if not is_final and token_text: + if not buffer_for_tools and not is_final and token_text: chunk = ChatCompletionChunk( id=request_id, created=created, @@ -191,37 +196,53 @@ async def _stream_response( ) yield {"data": chunk.model_dump_json()} - # Check for tool calls in complete response + # --- Post-generation: parse tool calls and emit clean content ------ finish_reason = "stop" + tool_calls_parsed = [] + if tools: clean_text, parsed = e.parse_tool_calls(full_text, tools) if parsed: finish_reason = "tool_calls" - # Emit tool call chunks - for i, tc in enumerate(parsed): - tc_chunk = ChatCompletionChunk( - id=request_id, - created=created, - model=model_name, - choices=[ - StreamChoice( - delta=DeltaMessage( - tool_calls=[ - ToolCall( - index=i, - id=tc["id"], - type="function", - function=FunctionCall( - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ), - ) - ] + tool_calls_parsed = parsed + full_text = clean_text or "" + + # Emit buffered content (when tools were present, this is the cleaned + # text with tool-call markup stripped out) + if buffer_for_tools and full_text.strip(): + content_chunk = ChatCompletionChunk( + id=request_id, + created=created, + model=model_name, + choices=[StreamChoice(delta=DeltaMessage(content=full_text))], + ) + yield {"data": content_chunk.model_dump_json()} + + # Emit tool call chunks + for i, tc in enumerate(tool_calls_parsed): + tc_chunk = ChatCompletionChunk( + id=request_id, + created=created, + model=model_name, + choices=[ + StreamChoice( + delta=DeltaMessage( + tool_calls=[ + ToolCall( + index=i, + id=tc["id"], + type="function", + function=FunctionCall( + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ), ) - ) - ], + ] + ) ) - yield {"data": tc_chunk.model_dump_json()} + ], + ) + yield {"data": tc_chunk.model_dump_json()} # Final chunk with finish reason and usage final_chunk = ChatCompletionChunk( diff --git a/run.sh b/run.sh index 2def61f..ab197c1 100755 --- a/run.sh +++ b/run.sh @@ -7,8 +7,28 @@ cd "$SCRIPT_DIR" # Activate virtual environment source .venv/bin/activate -# Default model – 4-bit quantized Gemma 3 4B IT (vision-capable) -MODEL="${MODEL:-mlx-community/gemma-3-4b-it-4bit}" +# --- Model selection --- +# Usage: ./run.sh [gemma|qwen] +# Or set MODEL env var directly for a custom model. + +MODEL_CHOICE="${1:-gemma}" + +if [[ -z "${MODEL:-}" ]]; then + case "$MODEL_CHOICE" in + gemma) + MODEL="mlx-community/gemma-3-4b-it-4bit" + ;; + qwen) + MODEL="mlx-community/Qwen3-VL-4B-Instruct-4bit" + ;; + *) + echo "Unknown model choice: $MODEL_CHOICE" + echo "Usage: $0 [gemma|qwen]" + exit 1 + ;; + esac +fi + HOST="${HOST:-127.0.0.1}" PORT="${PORT:-1234}" @@ -21,4 +41,4 @@ exec python -m mlx_server.main \ --model "$MODEL" \ --host "$HOST" \ --port "$PORT" \ - "$@" + "${@:2}"