feat: qwen now works, too

This commit is contained in:
2026-03-17 11:44:24 +01:00
parent bdfbd14577
commit cc6e761ed4
5 changed files with 351 additions and 50 deletions

View File

@@ -20,6 +20,58 @@ logger = logging.getLogger(__name__)
# Default model repo id (presumably used when no model is specified — the
# consumer of this constant is outside this view).
DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit"
# Known model aliases for quick selection: short alias -> full model repo id.
MODEL_ALIASES: dict[str, str] = {
"gemma": "mlx-community/gemma-3-4b-it-4bit",
"qwen": "mlx-community/Qwen3-VL-4B-Instruct-4bit",
}
def _resolve_local_model_path(repo_id: str) -> Path | None:
"""If a HuggingFace model is already cached locally, return its snapshot path.
This avoids any network requests (HEAD checks) when the model files are
already present on disk — critical for offline use.
"""
# If it's already a local directory, just use it
local = Path(repo_id)
if local.is_dir():
return local
# Check the HF cache: ~/.cache/huggingface/hub/models--org--name/snapshots/<hash>
cache_root = Path.home() / ".cache" / "huggingface" / "hub"
safe_name = "models--" + repo_id.replace("/", "--")
model_cache = cache_root / safe_name
if not model_cache.is_dir():
return None
# Read the ref to find the snapshot hash
refs_dir = model_cache / "refs"
snapshot_dir = model_cache / "snapshots"
if refs_dir.is_dir() and snapshot_dir.is_dir():
main_ref = refs_dir / "main"
if main_ref.is_file():
commit_hash = main_ref.read_text().strip()
snap = snapshot_dir / commit_hash
if snap.is_dir():
logger.info(
"Found locally cached model at %s — skipping online check", snap
)
return snap
# Fallback: use the first (most recent) snapshot if refs/main is missing
if snapshot_dir.is_dir():
snapshots = sorted(snapshot_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True)
if snapshots:
logger.info(
"Found locally cached model at %s — skipping online check",
snapshots[0],
)
return snapshots[0]
return None
# ------------------------------------------------------------------
# Helpers for Gemma 3 tool_code format
@@ -200,17 +252,35 @@ class InferenceEngine:
self.model = None
self.processor = None
self.config = None
self._model_type: str = "" # e.g. "gemma3", "qwen3"
self._lock = threading.Lock()
self._prompt_cache = PromptCache()
def load(self) -> None:
    """Load the model, its processor and its config, preferring local files.

    The model path is first resolved through ``_resolve_local_model_path``;
    when a local directory is found, both the weights and the config are
    loaded from it so that a single, consistent location is used and the
    raw repo id is never loaded a second time.

    Side effects: sets ``self.model``, ``self.processor``, ``self.config``
    and ``self._model_type`` (the lower-cased ``model_type`` of the config,
    or ``""`` when the config does not provide one).
    """
    logger.info("Loading model %s ...", self.model_path)

    # Prefer the local cache to avoid any network requests
    local_path = _resolve_local_model_path(self.model_path)
    load_path = str(local_path) if local_path else self.model_path
    self.model, self.processor = mlx_vlm.load(load_path)

    # Load model config for chat template
    from transformers import AutoConfig

    self.config = AutoConfig.from_pretrained(load_path, trust_remote_code=True)

    # Detect model family for prompt-format branching; tolerate a config
    # whose model_type is missing or None.
    self._model_type = (getattr(self.config, "model_type", "") or "").lower()
    logger.info("Model loaded successfully (type=%s).", self._model_type)
@property
def is_qwen(self) -> bool:
    """Whether the stored model type string contains ``"qwen"``."""
    model_family = self._model_type
    return "qwen" in model_family
@property
def is_gemma(self) -> bool:
    """Whether the stored model type string contains ``"gemma"``."""
    model_family = self._model_type
    return "gemma" in model_family
# ------------------------------------------------------------------
# Image helpers
@@ -244,6 +314,16 @@ class InferenceEngine:
Returns (prompt_str, image_paths).
"""
if self.is_qwen:
return self._build_prompt_qwen(messages, tools)
return self._build_prompt_gemma(messages, tools)
def _build_prompt_gemma(
self,
messages: list[dict],
tools: list[dict] | None = None,
) -> tuple[str, list[str]]:
"""Gemma 3 prompt builder (tool_code format, no system role)."""
image_paths: list[str] = []
formatted_messages: list[dict] = []
@@ -306,6 +386,58 @@ class InferenceEngine:
return prompt, image_paths
def _build_prompt_qwen(
    self,
    messages: list[dict],
    tools: list[dict] | None = None,
) -> tuple[str, list[str]]:
    """Qwen3 prompt builder (native system role, JSON tool calls).

    Converts OpenAI-style chat messages into the message list handed to
    ``mlx_vlm.apply_chat_template``:

    * system   -> kept as a system message; the tool description is
                  appended when ``tools`` is given.
    * user     -> text plus any images found in the content parts.
    * assistant-> text followed by its tool calls rendered as
                  ``<tool_call>`` blocks.
    * tool     -> forwarded as a user message carrying the tool output.

    Messages with any other role are ignored.  When ``tools`` is given and
    no system message exists, a system message holding only the tool
    description is inserted first.

    Args:
        messages: Chat messages; each must have a ``"role"`` key.
        tools: Optional OpenAI-style tool definitions.

    Returns:
        ``(prompt_str, image_paths)``.
    """
    image_paths: list[str] = []
    formatted_messages: list[dict] = []

    # Render the tool description once instead of once per system message.
    tool_prompt = self._build_qwen_tool_system_prompt(tools) if tools else ""

    # Qwen3 supports system role natively — inject tools there
    has_system = any(m.get("role") == "system" for m in messages)
    if tools and not has_system:
        formatted_messages.append({"role": "system", "content": tool_prompt})

    for msg in messages:
        role = msg["role"]
        content = msg.get("content")
        tool_calls = msg.get("tool_calls")

        if role == "system":
            # Same `or ""` guard as the assistant/tool branches, so an empty
            # system content cannot break the concatenation below.
            text = self._get_text_content(content) or ""
            if tools:
                text = f"{text}\n\n{tool_prompt}" if text else tool_prompt
            formatted_messages.append({"role": "system", "content": text})
        elif role == "user":
            text, imgs = self._extract_content_parts(content)
            image_paths.extend(imgs)
            formatted_messages.append({"role": "user", "content": text})
        elif role == "assistant":
            text = self._get_text_content(content) or ""
            if tool_calls:
                tc_text = self._format_qwen_tool_calls_for_prompt(tool_calls)
                text = (text + "\n" + tc_text).strip()
            formatted_messages.append({"role": "assistant", "content": text})
        elif role == "tool":
            # Tool results are passed back to the model as a user turn.
            tool_text = self._get_text_content(content) or ""
            formatted_messages.append({"role": "user", "content": tool_text})

    # Apply chat template via mlx_vlm
    prompt = mlx_vlm.apply_chat_template(
        self.processor,
        self.config,
        formatted_messages,
        add_generation_prompt=True,
        num_images=len(image_paths),
    )
    return prompt, image_paths
@staticmethod
def _merge_consecutive_roles(messages: list[dict]) -> list[dict]:
"""Merge consecutive messages with the same role into one.
@@ -439,6 +571,50 @@ class InferenceEngine:
parts.append(f"```tool_code\n{call_str}\n```")
return "\n".join(parts)
# ------------------------------------------------------------------
# Qwen3 tool helpers
# ------------------------------------------------------------------
@staticmethod
def _build_qwen_tool_system_prompt(tools: list[dict]) -> str:
"""Build the tool system prompt for Qwen3 using its native JSON format."""
tool_descs = []
for tool in tools:
func = tool.get("function", tool)
tool_descs.append({
"type": "function",
"function": {
"name": func["name"],
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
},
})
tools_json = json.dumps(tool_descs, indent=2)
return (
"# Tools\n\n"
"You are a helpful assistant with access to the following tools. "
"Use them when appropriate by responding with a JSON tool call.\n\n"
"## Available Tools\n\n"
f"{tools_json}\n\n"
"## Tool Call Format\n\n"
"When you need to call a tool, respond with:\n"
'<tool_call>\n{"name": "<function_name>", "arguments": {<args>}}\n</tool_call>'
)
@staticmethod
def _format_qwen_tool_calls_for_prompt(tool_calls: list[dict]) -> str:
"""Format OpenAI-style tool calls back into Qwen3's XML tag format."""
parts = []
for tc in tool_calls:
func = tc.get("function", tc)
name = func["name"]
args = func.get("arguments", "{}")
if isinstance(args, str):
args = json.loads(args)
call_obj = {"name": name, "arguments": args}
parts.append(f"<tool_call>\n{json.dumps(call_obj)}\n</tool_call>")
return "\n".join(parts)
# ------------------------------------------------------------------
# Prefix cache & generation
# ------------------------------------------------------------------
@@ -447,6 +623,32 @@ class InferenceEngine:
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
_GENERATE_KWARGS: dict = {}
# Keys in the prep dict that are internal bookkeeping, not kwargs for
# mlx_vlm.stream_generate. _extra_generate_kwargs skips these keys when it
# forwards the remaining prep entries as extra generation kwargs.
_PREP_INTERNAL_KEYS = frozenset({
"input_ids", "pixel_values", "mask", "prompt_cache",
"_full_token_ids", "_prompt_token_count",
})
def _extra_generate_kwargs(
self, images: list[str] | None, prep: dict | None = None,
) -> dict:
"""Build per-request kwargs for mlx_vlm.stream_generate.
Includes model-specific keys from prepare_inputs (e.g. image_grid_thw
for Qwen3-VL) and works around a chunked-prefill bug where
visual_pos_masks is None for text-only requests.
"""
extra: dict = dict(self._GENERATE_KWARGS)
if self.is_qwen and not images:
extra["prefill_step_size"] = None
# Forward any model-specific keys that prepare_inputs returned
if prep:
for k, v in prep.items():
if k not in self._PREP_INTERNAL_KEYS:
extra[k] = v
return extra
def _get_tokenizer(self):
"""Get the underlying tokenizer from the processor."""
proc = self.processor
@@ -484,6 +686,11 @@ class InferenceEngine:
pixel_values = inputs.get("pixel_values")
mask = inputs.get("attention_mask")
# Collect any model-specific extra keys from prepare_inputs
# (e.g. image_grid_thw for Qwen3-VL) so they reach the model.
_KNOWN_KEYS = {"input_ids", "pixel_values", "attention_mask"}
extra_inputs = {k: v for k, v in inputs.items() if k not in _KNOWN_KEYS}
full_token_list = full_input_ids.flatten().tolist()
prefix_len = self._prompt_cache.get_reusable_length(full_token_list)
@@ -521,7 +728,9 @@ class InferenceEngine:
}
# Cache miss — create a fresh KV cache
cache = cache_module.make_prompt_cache(self.model.language_model)
# VLM models expose .language_model; text-only models are the LM directly
lm = getattr(self.model, "language_model", self.model)
cache = cache_module.make_prompt_cache(lm)
logger.info(
"Prefix cache miss: processing %d tokens from scratch",
len(full_token_list),
@@ -533,6 +742,7 @@ class InferenceEngine:
"prompt_cache": cache,
"_full_token_ids": full_token_list,
"_prompt_token_count": len(full_token_list),
**extra_inputs,
}
def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None:
@@ -555,9 +765,10 @@ class InferenceEngine:
prep = self._prepare_generation(prompt, images)
prompt_token_count = prep["_prompt_token_count"]
# Ensure stopping criteria is initialised
# Ensure stopping criteria is initialised (Gemma-specific; optional for others)
tokenizer = self._get_tokenizer()
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
if hasattr(tokenizer, "stopping_criteria"):
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
text = ""
generated_tokens: list[int] = []
@@ -575,7 +786,7 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
**self._GENERATE_KWARGS,
**self._extra_generate_kwargs(images, prep),
):
text += result.text
if result.token is not None:
@@ -603,9 +814,10 @@ class InferenceEngine:
prep = self._prepare_generation(prompt, images)
prompt_token_count = prep["_prompt_token_count"]
# Ensure stopping criteria is initialised
# Ensure stopping criteria is initialised (Gemma-specific; optional for others)
tokenizer = self._get_tokenizer()
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
if hasattr(tokenizer, "stopping_criteria"):
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
accumulated = ""
generated_tokens: list[int] = []
@@ -624,7 +836,7 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
**self._GENERATE_KWARGS,
**self._extra_generate_kwargs(images, prep),
):
token_text = result.text
accumulated += token_text
@@ -663,19 +875,26 @@ class InferenceEngine:
# Tool call parsing from model output
# ------------------------------------------------------------------
@staticmethod
def parse_tool_calls(
text: str, tools: list[dict] | None = None
self, text: str, tools: list[dict] | None = None
) -> tuple[str, list[dict]]:
"""Parse tool calls from model output using Gemma 3's tool_code format.
"""Parse tool calls from model output.
Detects ```tool_code ... ``` blocks containing Python-style or
shell-style function calls.
Supports both Gemma 3's ```tool_code``` blocks and Qwen3's
<tool_call> XML tags.
Returns (clean_text, tool_calls) where tool_calls is a list of
{"id": str, "type": "function", "function": {"name": str, "arguments": str}}.
"""
# Build a lookup of function name -> parameter schema
if self.is_qwen:
return self._parse_tool_calls_qwen(text)
return self._parse_tool_calls_gemma(text, tools)
@staticmethod
def _parse_tool_calls_gemma(
text: str, tools: list[dict] | None = None
) -> tuple[str, list[dict]]:
"""Parse Gemma 3 tool_code blocks."""
tool_defs: dict[str, dict] = {}
if tools:
for tool in tools:
@@ -704,3 +923,32 @@ class InferenceEngine:
logger.warning("Failed to parse tool_code call %r: %s", call_str, e)
return clean_text, tool_calls
@staticmethod
def _parse_tool_calls_qwen(text: str) -> tuple[str, list[dict]]:
"""Parse Qwen3 <tool_call> XML tags."""
tool_calls = []
pattern = r"<tool_call>\s*(.*?)\s*</tool_call>"
matches = re.findall(pattern, text, re.DOTALL)
clean_text = re.sub(r"<tool_call>\s*.*?\s*</tool_call>", "", text, flags=re.DOTALL).strip()
for i, match in enumerate(matches):
try:
call_obj = json.loads(match.strip())
name = call_obj.get("name", "")
args = call_obj.get("arguments", {})
if isinstance(args, str):
args = json.loads(args)
tool_calls.append({
"id": f"call_{i}_{hash(match) % 10**8:08d}",
"type": "function",
"function": {
"name": name,
"arguments": json.dumps(args),
},
})
except Exception as e:
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
return clean_text, tool_calls

View File

@@ -1,4 +1,4 @@
"""OpenAI-compatible API server for Gemma 3 4B via MLX."""
"""OpenAI-compatible API server for local LLMs (Gemma 3, Qwen3, …) via MLX."""
from __future__ import annotations
@@ -31,7 +31,7 @@ from .models import (
logger = logging.getLogger(__name__)
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B")
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for local LLMs on Apple Silicon")
app.add_middleware(
CORSMiddleware,
@@ -170,6 +170,11 @@ async def _stream_response(
prompt_tokens = 0
gen_tokens = 0
# When tools are available we must buffer the full response before
# emitting content — otherwise raw tool-call markup (```tool_code```
# or <tool_call>) leaks into the streamed text.
buffer_for_tools = bool(tools)
for token_text, is_final, pt, gt in e.stream_generate(
prompt=prompt,
images=images or None,
@@ -182,7 +187,7 @@ async def _stream_response(
gen_tokens = gt
full_text += token_text
if not is_final and token_text:
if not buffer_for_tools and not is_final and token_text:
chunk = ChatCompletionChunk(
id=request_id,
created=created,
@@ -191,37 +196,53 @@ async def _stream_response(
)
yield {"data": chunk.model_dump_json()}
# Check for tool calls in complete response
# --- Post-generation: parse tool calls and emit clean content ------
finish_reason = "stop"
tool_calls_parsed = []
if tools:
clean_text, parsed = e.parse_tool_calls(full_text, tools)
if parsed:
finish_reason = "tool_calls"
# Emit tool call chunks
for i, tc in enumerate(parsed):
tc_chunk = ChatCompletionChunk(
id=request_id,
created=created,
model=model_name,
choices=[
StreamChoice(
delta=DeltaMessage(
tool_calls=[
ToolCall(
index=i,
id=tc["id"],
type="function",
function=FunctionCall(
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
),
)
]
tool_calls_parsed = parsed
full_text = clean_text or ""
# Emit buffered content (when tools were present, this is the cleaned
# text with tool-call markup stripped out)
if buffer_for_tools and full_text.strip():
content_chunk = ChatCompletionChunk(
id=request_id,
created=created,
model=model_name,
choices=[StreamChoice(delta=DeltaMessage(content=full_text))],
)
yield {"data": content_chunk.model_dump_json()}
# Emit tool call chunks
for i, tc in enumerate(tool_calls_parsed):
tc_chunk = ChatCompletionChunk(
id=request_id,
created=created,
model=model_name,
choices=[
StreamChoice(
delta=DeltaMessage(
tool_calls=[
ToolCall(
index=i,
id=tc["id"],
type="function",
function=FunctionCall(
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
),
)
)
],
]
)
)
yield {"data": tc_chunk.model_dump_json()}
],
)
yield {"data": tc_chunk.model_dump_json()}
# Final chunk with finish reason and usage
final_chunk = ChatCompletionChunk(