feat: qwen now works, too

This commit is contained in:
2026-03-17 11:44:24 +01:00
parent bdfbd14577
commit cc6e761ed4
5 changed files with 351 additions and 50 deletions

1
.gitignore vendored
View File

@@ -8,3 +8,4 @@ build/
.env .env
*.log *.log
.DS_Store .DS_Store
settings.local.json

View File

@@ -1,6 +1,6 @@
# MLX Server # MLX Server
OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon via MLX. OpenAI-compatible API server for local LLMs on Apple Silicon via MLX. Supports Gemma 3 4B and Qwen3 VL 4B (vision + tool use).
## Quick Start ## Quick Start
@@ -8,11 +8,15 @@ OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon
# Activate virtual environment # Activate virtual environment
source .venv/bin/activate source .venv/bin/activate
# Run the server (downloads model on first run) # Run with Gemma 3 (default)
./run.sh ./run.sh
# Run with Qwen3
./run.sh qwen
# Or directly: # Or directly:
python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234 python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port 1234
``` ```
## Project Structure ## Project Structure
@@ -21,11 +25,18 @@ python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm) - `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm)
- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types - `mlx_server/models.py` — Pydantic models for OpenAI API request/response types
## Supported Models
| Alias | HuggingFace ID | Notes |
|-------|---------------|-------|
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks |
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags |
## Key Design Decisions ## Key Design Decisions
- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load - Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load
- Gemma 3 has no system role — system messages are converted to user/assistant pairs - Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags
- Tool use is prompt-engineered: tools are injected into the system prompt with `<tool_call>` XML tags, and parsed from model output - Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.)
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation - Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
- 128k context window supported via the model's native capabilities - 128k context window supported via the model's native capabilities

View File

@@ -20,6 +20,58 @@ logger = logging.getLogger(__name__)
DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit" DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit"
# Known model aliases for quick selection.
# Maps a short alias to the full HuggingFace repo id of the MLX model build.
# NOTE(review): the lookup site is not visible in this hunk — confirm where
# aliases are resolved before relying on them as CLI values.
MODEL_ALIASES: dict[str, str] = {
"gemma": "mlx-community/gemma-3-4b-it-4bit",
"qwen": "mlx-community/Qwen3-VL-4B-Instruct-4bit",
}
def _resolve_local_model_path(repo_id: str) -> Path | None:
"""If a HuggingFace model is already cached locally, return its snapshot path.
This avoids any network requests (HEAD checks) when the model files are
already present on disk — critical for offline use.
"""
# If it's already a local directory, just use it
local = Path(repo_id)
if local.is_dir():
return local
# Check the HF cache: ~/.cache/huggingface/hub/models--org--name/snapshots/<hash>
cache_root = Path.home() / ".cache" / "huggingface" / "hub"
safe_name = "models--" + repo_id.replace("/", "--")
model_cache = cache_root / safe_name
if not model_cache.is_dir():
return None
# Read the ref to find the snapshot hash
refs_dir = model_cache / "refs"
snapshot_dir = model_cache / "snapshots"
if refs_dir.is_dir() and snapshot_dir.is_dir():
main_ref = refs_dir / "main"
if main_ref.is_file():
commit_hash = main_ref.read_text().strip()
snap = snapshot_dir / commit_hash
if snap.is_dir():
logger.info(
"Found locally cached model at %s — skipping online check", snap
)
return snap
# Fallback: use the first (most recent) snapshot if refs/main is missing
if snapshot_dir.is_dir():
snapshots = sorted(snapshot_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True)
if snapshots:
logger.info(
"Found locally cached model at %s — skipping online check",
snapshots[0],
)
return snapshots[0]
return None
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Helpers for Gemma 3 tool_code format # Helpers for Gemma 3 tool_code format
@@ -200,17 +252,35 @@ class InferenceEngine:
self.model = None self.model = None
self.processor = None self.processor = None
self.config = None self.config = None
self._model_type: str = "" # e.g. "gemma3", "qwen3"
self._lock = threading.Lock() self._lock = threading.Lock()
self._prompt_cache = PromptCache() self._prompt_cache = PromptCache()
def load(self) -> None: def load(self) -> None:
logger.info("Loading model %s ...", self.model_path) logger.info("Loading model %s ...", self.model_path)
self.model, self.processor = mlx_vlm.load(self.model_path)
# Prefer the local cache to avoid any network requests
local_path = _resolve_local_model_path(self.model_path)
load_path = str(local_path) if local_path else self.model_path
self.model, self.processor = mlx_vlm.load(load_path)
# Load model config for chat template # Load model config for chat template
from transformers import AutoConfig from transformers import AutoConfig
self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True) self.config = AutoConfig.from_pretrained(load_path, trust_remote_code=True)
logger.info("Model loaded successfully.")
# Detect model family for prompt-format branching
self._model_type = getattr(self.config, "model_type", "").lower()
logger.info("Model loaded successfully (type=%s).", self._model_type)
@property
def is_qwen(self) -> bool:
    """Whether the detected model type string contains ``"qwen"``."""
    model_type = self._model_type
    return "qwen" in model_type
@property
def is_gemma(self) -> bool:
    """Whether the detected model type string contains ``"gemma"``."""
    model_type = self._model_type
    return "gemma" in model_type
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Image helpers # Image helpers
@@ -244,6 +314,16 @@ class InferenceEngine:
Returns (prompt_str, image_paths). Returns (prompt_str, image_paths).
""" """
if self.is_qwen:
return self._build_prompt_qwen(messages, tools)
return self._build_prompt_gemma(messages, tools)
def _build_prompt_gemma(
self,
messages: list[dict],
tools: list[dict] | None = None,
) -> tuple[str, list[str]]:
"""Gemma 3 prompt builder (tool_code format, no system role)."""
image_paths: list[str] = [] image_paths: list[str] = []
formatted_messages: list[dict] = [] formatted_messages: list[dict] = []
@@ -306,6 +386,58 @@ class InferenceEngine:
return prompt, image_paths return prompt, image_paths
def _build_prompt_qwen(
    self,
    messages: list[dict],
    tools: list[dict] | None = None,
) -> tuple[str, list[str]]:
    """Qwen3 prompt builder (native system role, JSON tool calls).

    Converts OpenAI-style chat messages into the message list handed to
    ``mlx_vlm.apply_chat_template``:

    * Tool instructions are injected exactly once — appended to the first
      system message, or emitted as a new leading system message when the
      conversation has none.
    * Assistant ``tool_calls`` are replayed as ``<tool_call>`` tags.
    * ``tool`` results are replayed as user turns wrapped in
      ``<tool_response>`` tags so the model can tell them apart from text the
      user typed.

    Args:
        messages: OpenAI-style message dicts; each must have a ``"role"`` key.
        tools: Optional OpenAI-style tool definitions.

    Returns:
        ``(prompt_str, image_paths)``.

    Raises:
        KeyError: If a message has no ``"role"`` key.
    """
    image_paths: list[str] = []
    formatted_messages: list[dict] = []

    # Render the tool instructions once; ``tools_pending`` tracks whether they
    # still need to be placed into a system message.
    tool_prompt = self._build_qwen_tool_system_prompt(tools) if tools else ""
    tools_pending = bool(tool_prompt)

    # Qwen3 supports the system role natively — inject tools there.
    if tools_pending and not any(m.get("role") == "system" for m in messages):
        formatted_messages.append({"role": "system", "content": tool_prompt})
        tools_pending = False

    for msg in messages:
        role = msg["role"]
        content = msg.get("content")

        if role == "system":
            # ``or ""`` mirrors the assistant branch: missing content must not
            # break the string concatenation below.
            text = self._get_text_content(content) or ""
            if tools_pending:
                text = f"{text}\n\n{tool_prompt}" if text else tool_prompt
                tools_pending = False
            formatted_messages.append({"role": "system", "content": text})

        elif role == "user":
            text, imgs = self._extract_content_parts(content)
            image_paths.extend(imgs)
            formatted_messages.append({"role": "user", "content": text})

        elif role == "assistant":
            text = self._get_text_content(content) or ""
            tool_calls = msg.get("tool_calls")
            if tool_calls:
                tc_text = self._format_qwen_tool_calls_for_prompt(tool_calls)
                text = (text + "\n" + tc_text).strip()
            formatted_messages.append({"role": "assistant", "content": text})

        elif role == "tool":
            tool_text = self._get_text_content(content) or ""
            # Same rendering Qwen's chat template uses for tool results.
            formatted_messages.append({
                "role": "user",
                "content": f"<tool_response>\n{tool_text}\n</tool_response>",
            })

        else:
            # Previously dropped silently; surface it so missing context is
            # diagnosable.
            logger.warning("Ignoring message with unsupported role %r", role)

    # Apply chat template via mlx_vlm
    prompt = mlx_vlm.apply_chat_template(
        self.processor,
        self.config,
        formatted_messages,
        add_generation_prompt=True,
        num_images=len(image_paths),
    )

    return prompt, image_paths
@staticmethod @staticmethod
def _merge_consecutive_roles(messages: list[dict]) -> list[dict]: def _merge_consecutive_roles(messages: list[dict]) -> list[dict]:
"""Merge consecutive messages with the same role into one. """Merge consecutive messages with the same role into one.
@@ -439,6 +571,50 @@ class InferenceEngine:
parts.append(f"```tool_code\n{call_str}\n```") parts.append(f"```tool_code\n{call_str}\n```")
return "\n".join(parts) return "\n".join(parts)
# ------------------------------------------------------------------
# Qwen3 tool helpers
# ------------------------------------------------------------------
@staticmethod
def _build_qwen_tool_system_prompt(tools: list[dict]) -> str:
    """Render the Qwen3 system-prompt section that describes available tools.

    Each tool may be given either as ``{"function": {...}}`` or as the bare
    function dict. Only ``name`` (required), ``description`` and
    ``parameters`` are carried over; the result lists the tools as indented
    JSON followed by the ``<tool_call>`` reply format.
    """
    specs = (tool.get("function", tool) for tool in tools)
    normalized = [
        {
            "type": "function",
            "function": {
                "name": spec["name"],
                "description": spec.get("description", ""),
                "parameters": spec.get("parameters", {}),
            },
        }
        for spec in specs
    ]

    sections = [
        "# Tools",
        "",
        "You are a helpful assistant with access to the following tools. "
        "Use them when appropriate by responding with a JSON tool call.",
        "",
        "## Available Tools",
        "",
        json.dumps(normalized, indent=2),
        "",
        "## Tool Call Format",
        "",
        "When you need to call a tool, respond with:",
        "<tool_call>",
        '{"name": "<function_name>", "arguments": {<args>}}',
        "</tool_call>",
    ]
    return "\n".join(sections)
@staticmethod
def _format_qwen_tool_calls_for_prompt(tool_calls: list[dict]) -> str:
    """Format OpenAI-style tool calls back into Qwen3's XML tag format.

    Args:
        tool_calls: Tool-call dicts, either ``{"function": {...}}`` or the
            bare function dict with ``name`` and optional ``arguments``.
            ``arguments`` may be a JSON string or an already-decoded object.

    Returns:
        One ``<tool_call>`` block per call, joined by newlines (empty string
        for an empty list).

    Raises:
        KeyError: If a call has no ``name``.
    """
    parts = []
    for tc in tool_calls:
        func = tc.get("function", tc)
        name = func["name"]
        args = func.get("arguments")

        if isinstance(args, str):
            if args.strip():
                try:
                    args = json.loads(args)
                except json.JSONDecodeError:
                    # Best effort: a malformed argument string in the history
                    # must not fail the whole request — replay it verbatim.
                    pass
            else:
                # Clients send "" for calls without arguments.
                args = {}
        elif args is None:
            args = {}

        call_obj = {"name": name, "arguments": args}
        # ensure_ascii=False keeps non-ASCII arguments readable in the prompt
        # instead of expanding them into \uXXXX escapes.
        rendered = json.dumps(call_obj, ensure_ascii=False)
        parts.append(f"<tool_call>\n{rendered}\n</tool_call>")
    return "\n".join(parts)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Prefix cache & generation # Prefix cache & generation
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -447,6 +623,32 @@ class InferenceEngine:
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache # Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
_GENERATE_KWARGS: dict = {} _GENERATE_KWARGS: dict = {}
# Keys in the prep dict that are internal bookkeeping, not kwargs for
# mlx_vlm.stream_generate. _extra_generate_kwargs skips these when it
# forwards the remaining (model-specific) prep entries to the generator.
_PREP_INTERNAL_KEYS = frozenset({
"input_ids", "pixel_values", "mask", "prompt_cache",
"_full_token_ids", "_prompt_token_count",
})
def _extra_generate_kwargs(
self, images: list[str] | None, prep: dict | None = None,
) -> dict:
"""Build per-request kwargs for mlx_vlm.stream_generate.
Includes model-specific keys from prepare_inputs (e.g. image_grid_thw
for Qwen3-VL) and works around a chunked-prefill bug where
visual_pos_masks is None for text-only requests.
"""
extra: dict = dict(self._GENERATE_KWARGS)
if self.is_qwen and not images:
extra["prefill_step_size"] = None
# Forward any model-specific keys that prepare_inputs returned
if prep:
for k, v in prep.items():
if k not in self._PREP_INTERNAL_KEYS:
extra[k] = v
return extra
def _get_tokenizer(self): def _get_tokenizer(self):
"""Get the underlying tokenizer from the processor.""" """Get the underlying tokenizer from the processor."""
proc = self.processor proc = self.processor
@@ -484,6 +686,11 @@ class InferenceEngine:
pixel_values = inputs.get("pixel_values") pixel_values = inputs.get("pixel_values")
mask = inputs.get("attention_mask") mask = inputs.get("attention_mask")
# Collect any model-specific extra keys from prepare_inputs
# (e.g. image_grid_thw for Qwen3-VL) so they reach the model.
_KNOWN_KEYS = {"input_ids", "pixel_values", "attention_mask"}
extra_inputs = {k: v for k, v in inputs.items() if k not in _KNOWN_KEYS}
full_token_list = full_input_ids.flatten().tolist() full_token_list = full_input_ids.flatten().tolist()
prefix_len = self._prompt_cache.get_reusable_length(full_token_list) prefix_len = self._prompt_cache.get_reusable_length(full_token_list)
@@ -521,7 +728,9 @@ class InferenceEngine:
} }
# Cache miss — create a fresh KV cache # Cache miss — create a fresh KV cache
cache = cache_module.make_prompt_cache(self.model.language_model) # VLM models expose .language_model; text-only models are the LM directly
lm = getattr(self.model, "language_model", self.model)
cache = cache_module.make_prompt_cache(lm)
logger.info( logger.info(
"Prefix cache miss: processing %d tokens from scratch", "Prefix cache miss: processing %d tokens from scratch",
len(full_token_list), len(full_token_list),
@@ -533,6 +742,7 @@ class InferenceEngine:
"prompt_cache": cache, "prompt_cache": cache,
"_full_token_ids": full_token_list, "_full_token_ids": full_token_list,
"_prompt_token_count": len(full_token_list), "_prompt_token_count": len(full_token_list),
**extra_inputs,
} }
def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None: def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None:
@@ -555,8 +765,9 @@ class InferenceEngine:
prep = self._prepare_generation(prompt, images) prep = self._prepare_generation(prompt, images)
prompt_token_count = prep["_prompt_token_count"] prompt_token_count = prep["_prompt_token_count"]
# Ensure stopping criteria is initialised # Ensure stopping criteria is initialised (Gemma-specific; optional for others)
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
if hasattr(tokenizer, "stopping_criteria"):
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
text = "" text = ""
@@ -575,7 +786,7 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
**self._GENERATE_KWARGS, **self._extra_generate_kwargs(images, prep),
): ):
text += result.text text += result.text
if result.token is not None: if result.token is not None:
@@ -603,8 +814,9 @@ class InferenceEngine:
prep = self._prepare_generation(prompt, images) prep = self._prepare_generation(prompt, images)
prompt_token_count = prep["_prompt_token_count"] prompt_token_count = prep["_prompt_token_count"]
# Ensure stopping criteria is initialised # Ensure stopping criteria is initialised (Gemma-specific; optional for others)
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
if hasattr(tokenizer, "stopping_criteria"):
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
accumulated = "" accumulated = ""
@@ -624,7 +836,7 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
**self._GENERATE_KWARGS, **self._extra_generate_kwargs(images, prep),
): ):
token_text = result.text token_text = result.text
accumulated += token_text accumulated += token_text
@@ -663,19 +875,26 @@ class InferenceEngine:
# Tool call parsing from model output # Tool call parsing from model output
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@staticmethod
def parse_tool_calls( def parse_tool_calls(
text: str, tools: list[dict] | None = None self, text: str, tools: list[dict] | None = None
) -> tuple[str, list[dict]]: ) -> tuple[str, list[dict]]:
"""Parse tool calls from model output using Gemma 3's tool_code format. """Parse tool calls from model output.
Detects ```tool_code ... ``` blocks containing Python-style or Supports both Gemma 3's ```tool_code``` blocks and Qwen3's
shell-style function calls. <tool_call> XML tags.
Returns (clean_text, tool_calls) where tool_calls is a list of Returns (clean_text, tool_calls) where tool_calls is a list of
{"id": str, "type": "function", "function": {"name": str, "arguments": str}}. {"id": str, "type": "function", "function": {"name": str, "arguments": str}}.
""" """
# Build a lookup of function name -> parameter schema if self.is_qwen:
return self._parse_tool_calls_qwen(text)
return self._parse_tool_calls_gemma(text, tools)
@staticmethod
def _parse_tool_calls_gemma(
text: str, tools: list[dict] | None = None
) -> tuple[str, list[dict]]:
"""Parse Gemma 3 tool_code blocks."""
tool_defs: dict[str, dict] = {} tool_defs: dict[str, dict] = {}
if tools: if tools:
for tool in tools: for tool in tools:
@@ -704,3 +923,32 @@ class InferenceEngine:
logger.warning("Failed to parse tool_code call %r: %s", call_str, e) logger.warning("Failed to parse tool_code call %r: %s", call_str, e)
return clean_text, tool_calls return clean_text, tool_calls
@staticmethod
def _parse_tool_calls_qwen(text: str) -> tuple[str, list[dict]]:
    """Parse Qwen3 <tool_call> XML tags.

    Every ``<tool_call>…</tool_call>`` span is removed from the text. Each
    span whose payload is a JSON object with a non-empty string ``name`` is
    converted into an OpenAI-style tool call; malformed spans are logged and
    skipped.

    Args:
        text: Raw model output.

    Returns:
        ``(clean_text, tool_calls)`` where ``tool_calls`` is a list of
        ``{"id": str, "type": "function",
        "function": {"name": str, "arguments": str}}``.
    """
    import uuid  # local import keeps this block self-contained

    tag_re = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)

    payloads = tag_re.findall(text)
    clean_text = tag_re.sub("", text).strip()

    tool_calls = []
    for payload in payloads:
        try:
            call_obj = json.loads(payload.strip())
            if not isinstance(call_obj, dict):
                raise TypeError("tool call payload is not a JSON object")

            name = call_obj.get("name")
            if not isinstance(name, str) or not name:
                raise ValueError("tool call has no function name")

            args = call_obj.get("arguments")
            if isinstance(args, str):
                args = json.loads(args)
            if args is None:
                args = {}
        except (TypeError, ValueError) as e:
            # json.JSONDecodeError is a ValueError subclass.
            logger.warning("Failed to parse tool_call tag %r: %s", payload, e)
            continue

        tool_calls.append({
            # Random ids: the salted built-in hash() differs per process and
            # gives identical ids to repeated identical calls across turns.
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": name,
                "arguments": json.dumps(args),
            },
        })

    return clean_text, tool_calls

View File

@@ -1,4 +1,4 @@
"""OpenAI-compatible API server for Gemma 3 4B via MLX.""" """OpenAI-compatible API server for local LLMs (Gemma 3, Qwen3, …) via MLX."""
from __future__ import annotations from __future__ import annotations
@@ -31,7 +31,7 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B") app = FastAPI(title="MLX Server", description="OpenAI-compatible API for local LLMs on Apple Silicon")
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@@ -170,6 +170,11 @@ async def _stream_response(
prompt_tokens = 0 prompt_tokens = 0
gen_tokens = 0 gen_tokens = 0
# When tools are available we must buffer the full response before
# emitting content — otherwise raw tool-call markup (```tool_code```
# or <tool_call>) leaks into the streamed text.
buffer_for_tools = bool(tools)
for token_text, is_final, pt, gt in e.stream_generate( for token_text, is_final, pt, gt in e.stream_generate(
prompt=prompt, prompt=prompt,
images=images or None, images=images or None,
@@ -182,7 +187,7 @@ async def _stream_response(
gen_tokens = gt gen_tokens = gt
full_text += token_text full_text += token_text
if not is_final and token_text: if not buffer_for_tools and not is_final and token_text:
chunk = ChatCompletionChunk( chunk = ChatCompletionChunk(
id=request_id, id=request_id,
created=created, created=created,
@@ -191,14 +196,30 @@ async def _stream_response(
) )
yield {"data": chunk.model_dump_json()} yield {"data": chunk.model_dump_json()}
# Check for tool calls in complete response # --- Post-generation: parse tool calls and emit clean content ------
finish_reason = "stop" finish_reason = "stop"
tool_calls_parsed = []
if tools: if tools:
clean_text, parsed = e.parse_tool_calls(full_text, tools) clean_text, parsed = e.parse_tool_calls(full_text, tools)
if parsed: if parsed:
finish_reason = "tool_calls" finish_reason = "tool_calls"
tool_calls_parsed = parsed
full_text = clean_text or ""
# Emit buffered content (when tools were present, this is the cleaned
# text with tool-call markup stripped out)
if buffer_for_tools and full_text.strip():
content_chunk = ChatCompletionChunk(
id=request_id,
created=created,
model=model_name,
choices=[StreamChoice(delta=DeltaMessage(content=full_text))],
)
yield {"data": content_chunk.model_dump_json()}
# Emit tool call chunks # Emit tool call chunks
for i, tc in enumerate(parsed): for i, tc in enumerate(tool_calls_parsed):
tc_chunk = ChatCompletionChunk( tc_chunk = ChatCompletionChunk(
id=request_id, id=request_id,
created=created, created=created,

26
run.sh
View File

@@ -7,8 +7,28 @@ cd "$SCRIPT_DIR"
# Activate virtual environment # Activate virtual environment
source .venv/bin/activate source .venv/bin/activate
# Default model 4-bit quantized Gemma 3 4B IT (vision-capable) # --- Model selection ---
MODEL="${MODEL:-mlx-community/gemma-3-4b-it-4bit}" # Usage: ./run.sh [gemma|qwen]
# Or set MODEL env var directly for a custom model.
MODEL_CHOICE="${1:-gemma}"
if [[ -z "${MODEL:-}" ]]; then
case "$MODEL_CHOICE" in
gemma)
MODEL="mlx-community/gemma-3-4b-it-4bit"
;;
qwen)
MODEL="mlx-community/Qwen3-VL-4B-Instruct-4bit"
;;
*)
echo "Unknown model choice: $MODEL_CHOICE"
echo "Usage: $0 [gemma|qwen]"
exit 1
;;
esac
fi
HOST="${HOST:-127.0.0.1}" HOST="${HOST:-127.0.0.1}"
PORT="${PORT:-1234}" PORT="${PORT:-1234}"
@@ -21,4 +41,4 @@ exec python -m mlx_server.main \
--model "$MODEL" \ --model "$MODEL" \
--host "$HOST" \ --host "$HOST" \
--port "$PORT" \ --port "$PORT" \
"$@" "${@:2}"