feat: qwen now works, too
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,3 +8,4 @@ build/
|
|||||||
.env
|
.env
|
||||||
*.log
|
*.log
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
settings.local.json
|
||||||
|
|||||||
19
CLAUDE.md
19
CLAUDE.md
@@ -1,6 +1,6 @@
|
|||||||
# MLX Server
|
# MLX Server
|
||||||
|
|
||||||
OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon via MLX.
|
OpenAI-compatible API server for local LLMs on Apple Silicon via MLX. Supports Gemma 3 4B and Qwen3 VL 4B (vision + tool use).
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -8,11 +8,15 @@ OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon
|
|||||||
# Activate virtual environment
|
# Activate virtual environment
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
|
|
||||||
# Run the server (downloads model on first run)
|
# Run with Gemma 3 (default)
|
||||||
./run.sh
|
./run.sh
|
||||||
|
|
||||||
|
# Run with Qwen3
|
||||||
|
./run.sh qwen
|
||||||
|
|
||||||
# Or directly:
|
# Or directly:
|
||||||
python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
|
python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
|
||||||
|
python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port 1234
|
||||||
```
|
```
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
@@ -21,11 +25,18 @@ python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
|
|||||||
- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm)
|
- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm)
|
||||||
- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types
|
- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
| Alias | HuggingFace ID | Notes |
|
||||||
|
|-------|---------------|-------|
|
||||||
|
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks |
|
||||||
|
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags |
|
||||||
|
|
||||||
## Key Design Decisions
|
## Key Design Decisions
|
||||||
|
|
||||||
- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load
|
- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load
|
||||||
- Gemma 3 has no system role — system messages are converted to user/assistant pairs
|
- Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags
|
||||||
- Tool use is prompt-engineered: tools are injected into the system prompt with `<tool_call>` XML tags, and parsed from model output
|
- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.)
|
||||||
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
|
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
|
||||||
- 128k context window supported via the model's native capabilities
|
- 128k context window supported via the model's native capabilities
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,58 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit"
|
DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit"
|
||||||
|
|
||||||
|
# Known model aliases for quick selection
|
||||||
|
MODEL_ALIASES: dict[str, str] = {
|
||||||
|
"gemma": "mlx-community/gemma-3-4b-it-4bit",
|
||||||
|
"qwen": "mlx-community/Qwen3-VL-4B-Instruct-4bit",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_local_model_path(repo_id: str) -> Path | None:
|
||||||
|
"""If a HuggingFace model is already cached locally, return its snapshot path.
|
||||||
|
|
||||||
|
This avoids any network requests (HEAD checks) when the model files are
|
||||||
|
already present on disk — critical for offline use.
|
||||||
|
"""
|
||||||
|
# If it's already a local directory, just use it
|
||||||
|
local = Path(repo_id)
|
||||||
|
if local.is_dir():
|
||||||
|
return local
|
||||||
|
|
||||||
|
# Check the HF cache: ~/.cache/huggingface/hub/models--org--name/snapshots/<hash>
|
||||||
|
cache_root = Path.home() / ".cache" / "huggingface" / "hub"
|
||||||
|
safe_name = "models--" + repo_id.replace("/", "--")
|
||||||
|
model_cache = cache_root / safe_name
|
||||||
|
|
||||||
|
if not model_cache.is_dir():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Read the ref to find the snapshot hash
|
||||||
|
refs_dir = model_cache / "refs"
|
||||||
|
snapshot_dir = model_cache / "snapshots"
|
||||||
|
if refs_dir.is_dir() and snapshot_dir.is_dir():
|
||||||
|
main_ref = refs_dir / "main"
|
||||||
|
if main_ref.is_file():
|
||||||
|
commit_hash = main_ref.read_text().strip()
|
||||||
|
snap = snapshot_dir / commit_hash
|
||||||
|
if snap.is_dir():
|
||||||
|
logger.info(
|
||||||
|
"Found locally cached model at %s — skipping online check", snap
|
||||||
|
)
|
||||||
|
return snap
|
||||||
|
|
||||||
|
# Fallback: use the first (most recent) snapshot if refs/main is missing
|
||||||
|
if snapshot_dir.is_dir():
|
||||||
|
snapshots = sorted(snapshot_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
if snapshots:
|
||||||
|
logger.info(
|
||||||
|
"Found locally cached model at %s — skipping online check",
|
||||||
|
snapshots[0],
|
||||||
|
)
|
||||||
|
return snapshots[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Helpers for Gemma 3 tool_code format
|
# Helpers for Gemma 3 tool_code format
|
||||||
@@ -200,17 +252,35 @@ class InferenceEngine:
|
|||||||
self.model = None
|
self.model = None
|
||||||
self.processor = None
|
self.processor = None
|
||||||
self.config = None
|
self.config = None
|
||||||
|
self._model_type: str = "" # e.g. "gemma3", "qwen3"
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._prompt_cache = PromptCache()
|
self._prompt_cache = PromptCache()
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
logger.info("Loading model %s ...", self.model_path)
|
logger.info("Loading model %s ...", self.model_path)
|
||||||
self.model, self.processor = mlx_vlm.load(self.model_path)
|
|
||||||
|
# Prefer the local cache to avoid any network requests
|
||||||
|
local_path = _resolve_local_model_path(self.model_path)
|
||||||
|
load_path = str(local_path) if local_path else self.model_path
|
||||||
|
|
||||||
|
self.model, self.processor = mlx_vlm.load(load_path)
|
||||||
|
|
||||||
# Load model config for chat template
|
# Load model config for chat template
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
|
self.config = AutoConfig.from_pretrained(load_path, trust_remote_code=True)
|
||||||
logger.info("Model loaded successfully.")
|
|
||||||
|
# Detect model family for prompt-format branching
|
||||||
|
self._model_type = getattr(self.config, "model_type", "").lower()
|
||||||
|
logger.info("Model loaded successfully (type=%s).", self._model_type)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_qwen(self) -> bool:
|
||||||
|
return "qwen" in self._model_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_gemma(self) -> bool:
|
||||||
|
return "gemma" in self._model_type
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Image helpers
|
# Image helpers
|
||||||
@@ -244,6 +314,16 @@ class InferenceEngine:
|
|||||||
|
|
||||||
Returns (prompt_str, image_paths).
|
Returns (prompt_str, image_paths).
|
||||||
"""
|
"""
|
||||||
|
if self.is_qwen:
|
||||||
|
return self._build_prompt_qwen(messages, tools)
|
||||||
|
return self._build_prompt_gemma(messages, tools)
|
||||||
|
|
||||||
|
def _build_prompt_gemma(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
) -> tuple[str, list[str]]:
|
||||||
|
"""Gemma 3 prompt builder (tool_code format, no system role)."""
|
||||||
image_paths: list[str] = []
|
image_paths: list[str] = []
|
||||||
formatted_messages: list[dict] = []
|
formatted_messages: list[dict] = []
|
||||||
|
|
||||||
@@ -306,6 +386,58 @@ class InferenceEngine:
|
|||||||
|
|
||||||
return prompt, image_paths
|
return prompt, image_paths
|
||||||
|
|
||||||
|
def _build_prompt_qwen(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
) -> tuple[str, list[str]]:
|
||||||
|
"""Qwen3 prompt builder (native system role, JSON tool calls)."""
|
||||||
|
image_paths: list[str] = []
|
||||||
|
formatted_messages: list[dict] = []
|
||||||
|
|
||||||
|
# Qwen3 supports system role natively — inject tools there
|
||||||
|
has_system = any(m.get("role") == "system" for m in messages)
|
||||||
|
if tools and not has_system:
|
||||||
|
formatted_messages.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": self._build_qwen_tool_system_prompt(tools),
|
||||||
|
})
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = msg["role"]
|
||||||
|
content = msg.get("content")
|
||||||
|
tool_calls = msg.get("tool_calls")
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
text = self._get_text_content(content)
|
||||||
|
if tools:
|
||||||
|
text = text + "\n\n" + self._build_qwen_tool_system_prompt(tools)
|
||||||
|
formatted_messages.append({"role": "system", "content": text})
|
||||||
|
elif role == "user":
|
||||||
|
text, imgs = self._extract_content_parts(content)
|
||||||
|
image_paths.extend(imgs)
|
||||||
|
formatted_messages.append({"role": "user", "content": text})
|
||||||
|
elif role == "assistant":
|
||||||
|
text = self._get_text_content(content) or ""
|
||||||
|
if tool_calls:
|
||||||
|
tc_text = self._format_qwen_tool_calls_for_prompt(tool_calls)
|
||||||
|
text = (text + "\n" + tc_text).strip()
|
||||||
|
formatted_messages.append({"role": "assistant", "content": text})
|
||||||
|
elif role == "tool":
|
||||||
|
tool_text = self._get_text_content(content) or ""
|
||||||
|
formatted_messages.append({"role": "user", "content": tool_text})
|
||||||
|
|
||||||
|
# Apply chat template via mlx_vlm
|
||||||
|
prompt = mlx_vlm.apply_chat_template(
|
||||||
|
self.processor,
|
||||||
|
self.config,
|
||||||
|
formatted_messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
num_images=len(image_paths),
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt, image_paths
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _merge_consecutive_roles(messages: list[dict]) -> list[dict]:
|
def _merge_consecutive_roles(messages: list[dict]) -> list[dict]:
|
||||||
"""Merge consecutive messages with the same role into one.
|
"""Merge consecutive messages with the same role into one.
|
||||||
@@ -439,6 +571,50 @@ class InferenceEngine:
|
|||||||
parts.append(f"```tool_code\n{call_str}\n```")
|
parts.append(f"```tool_code\n{call_str}\n```")
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Qwen3 tool helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_qwen_tool_system_prompt(tools: list[dict]) -> str:
|
||||||
|
"""Build the tool system prompt for Qwen3 using its native JSON format."""
|
||||||
|
tool_descs = []
|
||||||
|
for tool in tools:
|
||||||
|
func = tool.get("function", tool)
|
||||||
|
tool_descs.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func["name"],
|
||||||
|
"description": func.get("description", ""),
|
||||||
|
"parameters": func.get("parameters", {}),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
tools_json = json.dumps(tool_descs, indent=2)
|
||||||
|
return (
|
||||||
|
"# Tools\n\n"
|
||||||
|
"You are a helpful assistant with access to the following tools. "
|
||||||
|
"Use them when appropriate by responding with a JSON tool call.\n\n"
|
||||||
|
"## Available Tools\n\n"
|
||||||
|
f"{tools_json}\n\n"
|
||||||
|
"## Tool Call Format\n\n"
|
||||||
|
"When you need to call a tool, respond with:\n"
|
||||||
|
'<tool_call>\n{"name": "<function_name>", "arguments": {<args>}}\n</tool_call>'
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_qwen_tool_calls_for_prompt(tool_calls: list[dict]) -> str:
|
||||||
|
"""Format OpenAI-style tool calls back into Qwen3's XML tag format."""
|
||||||
|
parts = []
|
||||||
|
for tc in tool_calls:
|
||||||
|
func = tc.get("function", tc)
|
||||||
|
name = func["name"]
|
||||||
|
args = func.get("arguments", "{}")
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json.loads(args)
|
||||||
|
call_obj = {"name": name, "arguments": args}
|
||||||
|
parts.append(f"<tool_call>\n{json.dumps(call_obj)}\n</tool_call>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Prefix cache & generation
|
# Prefix cache & generation
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -447,6 +623,32 @@ class InferenceEngine:
|
|||||||
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
|
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
|
||||||
_GENERATE_KWARGS: dict = {}
|
_GENERATE_KWARGS: dict = {}
|
||||||
|
|
||||||
|
# Keys in the prep dict that are internal bookkeeping, not kwargs for
|
||||||
|
# mlx_vlm.stream_generate.
|
||||||
|
_PREP_INTERNAL_KEYS = frozenset({
|
||||||
|
"input_ids", "pixel_values", "mask", "prompt_cache",
|
||||||
|
"_full_token_ids", "_prompt_token_count",
|
||||||
|
})
|
||||||
|
|
||||||
|
def _extra_generate_kwargs(
|
||||||
|
self, images: list[str] | None, prep: dict | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Build per-request kwargs for mlx_vlm.stream_generate.
|
||||||
|
|
||||||
|
Includes model-specific keys from prepare_inputs (e.g. image_grid_thw
|
||||||
|
for Qwen3-VL) and works around a chunked-prefill bug where
|
||||||
|
visual_pos_masks is None for text-only requests.
|
||||||
|
"""
|
||||||
|
extra: dict = dict(self._GENERATE_KWARGS)
|
||||||
|
if self.is_qwen and not images:
|
||||||
|
extra["prefill_step_size"] = None
|
||||||
|
# Forward any model-specific keys that prepare_inputs returned
|
||||||
|
if prep:
|
||||||
|
for k, v in prep.items():
|
||||||
|
if k not in self._PREP_INTERNAL_KEYS:
|
||||||
|
extra[k] = v
|
||||||
|
return extra
|
||||||
|
|
||||||
def _get_tokenizer(self):
|
def _get_tokenizer(self):
|
||||||
"""Get the underlying tokenizer from the processor."""
|
"""Get the underlying tokenizer from the processor."""
|
||||||
proc = self.processor
|
proc = self.processor
|
||||||
@@ -484,6 +686,11 @@ class InferenceEngine:
|
|||||||
pixel_values = inputs.get("pixel_values")
|
pixel_values = inputs.get("pixel_values")
|
||||||
mask = inputs.get("attention_mask")
|
mask = inputs.get("attention_mask")
|
||||||
|
|
||||||
|
# Collect any model-specific extra keys from prepare_inputs
|
||||||
|
# (e.g. image_grid_thw for Qwen3-VL) so they reach the model.
|
||||||
|
_KNOWN_KEYS = {"input_ids", "pixel_values", "attention_mask"}
|
||||||
|
extra_inputs = {k: v for k, v in inputs.items() if k not in _KNOWN_KEYS}
|
||||||
|
|
||||||
full_token_list = full_input_ids.flatten().tolist()
|
full_token_list = full_input_ids.flatten().tolist()
|
||||||
prefix_len = self._prompt_cache.get_reusable_length(full_token_list)
|
prefix_len = self._prompt_cache.get_reusable_length(full_token_list)
|
||||||
|
|
||||||
@@ -521,7 +728,9 @@ class InferenceEngine:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Cache miss — create a fresh KV cache
|
# Cache miss — create a fresh KV cache
|
||||||
cache = cache_module.make_prompt_cache(self.model.language_model)
|
# VLM models expose .language_model; text-only models are the LM directly
|
||||||
|
lm = getattr(self.model, "language_model", self.model)
|
||||||
|
cache = cache_module.make_prompt_cache(lm)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Prefix cache miss: processing %d tokens from scratch",
|
"Prefix cache miss: processing %d tokens from scratch",
|
||||||
len(full_token_list),
|
len(full_token_list),
|
||||||
@@ -533,6 +742,7 @@ class InferenceEngine:
|
|||||||
"prompt_cache": cache,
|
"prompt_cache": cache,
|
||||||
"_full_token_ids": full_token_list,
|
"_full_token_ids": full_token_list,
|
||||||
"_prompt_token_count": len(full_token_list),
|
"_prompt_token_count": len(full_token_list),
|
||||||
|
**extra_inputs,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None:
|
def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None:
|
||||||
@@ -555,9 +765,10 @@ class InferenceEngine:
|
|||||||
prep = self._prepare_generation(prompt, images)
|
prep = self._prepare_generation(prompt, images)
|
||||||
prompt_token_count = prep["_prompt_token_count"]
|
prompt_token_count = prep["_prompt_token_count"]
|
||||||
|
|
||||||
# Ensure stopping criteria is initialised
|
# Ensure stopping criteria is initialised (Gemma-specific; optional for others)
|
||||||
tokenizer = self._get_tokenizer()
|
tokenizer = self._get_tokenizer()
|
||||||
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
if hasattr(tokenizer, "stopping_criteria"):
|
||||||
|
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
generated_tokens: list[int] = []
|
generated_tokens: list[int] = []
|
||||||
@@ -575,7 +786,7 @@ class InferenceEngine:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
**self._GENERATE_KWARGS,
|
**self._extra_generate_kwargs(images, prep),
|
||||||
):
|
):
|
||||||
text += result.text
|
text += result.text
|
||||||
if result.token is not None:
|
if result.token is not None:
|
||||||
@@ -603,9 +814,10 @@ class InferenceEngine:
|
|||||||
prep = self._prepare_generation(prompt, images)
|
prep = self._prepare_generation(prompt, images)
|
||||||
prompt_token_count = prep["_prompt_token_count"]
|
prompt_token_count = prep["_prompt_token_count"]
|
||||||
|
|
||||||
# Ensure stopping criteria is initialised
|
# Ensure stopping criteria is initialised (Gemma-specific; optional for others)
|
||||||
tokenizer = self._get_tokenizer()
|
tokenizer = self._get_tokenizer()
|
||||||
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
if hasattr(tokenizer, "stopping_criteria"):
|
||||||
|
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||||
|
|
||||||
accumulated = ""
|
accumulated = ""
|
||||||
generated_tokens: list[int] = []
|
generated_tokens: list[int] = []
|
||||||
@@ -624,7 +836,7 @@ class InferenceEngine:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
**self._GENERATE_KWARGS,
|
**self._extra_generate_kwargs(images, prep),
|
||||||
):
|
):
|
||||||
token_text = result.text
|
token_text = result.text
|
||||||
accumulated += token_text
|
accumulated += token_text
|
||||||
@@ -663,19 +875,26 @@ class InferenceEngine:
|
|||||||
# Tool call parsing from model output
|
# Tool call parsing from model output
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_tool_calls(
|
def parse_tool_calls(
|
||||||
text: str, tools: list[dict] | None = None
|
self, text: str, tools: list[dict] | None = None
|
||||||
) -> tuple[str, list[dict]]:
|
) -> tuple[str, list[dict]]:
|
||||||
"""Parse tool calls from model output using Gemma 3's tool_code format.
|
"""Parse tool calls from model output.
|
||||||
|
|
||||||
Detects ```tool_code ... ``` blocks containing Python-style or
|
Supports both Gemma 3's ```tool_code``` blocks and Qwen3's
|
||||||
shell-style function calls.
|
<tool_call> XML tags.
|
||||||
|
|
||||||
Returns (clean_text, tool_calls) where tool_calls is a list of
|
Returns (clean_text, tool_calls) where tool_calls is a list of
|
||||||
{"id": str, "type": "function", "function": {"name": str, "arguments": str}}.
|
{"id": str, "type": "function", "function": {"name": str, "arguments": str}}.
|
||||||
"""
|
"""
|
||||||
# Build a lookup of function name -> parameter schema
|
if self.is_qwen:
|
||||||
|
return self._parse_tool_calls_qwen(text)
|
||||||
|
return self._parse_tool_calls_gemma(text, tools)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_tool_calls_gemma(
|
||||||
|
text: str, tools: list[dict] | None = None
|
||||||
|
) -> tuple[str, list[dict]]:
|
||||||
|
"""Parse Gemma 3 tool_code blocks."""
|
||||||
tool_defs: dict[str, dict] = {}
|
tool_defs: dict[str, dict] = {}
|
||||||
if tools:
|
if tools:
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
@@ -704,3 +923,32 @@ class InferenceEngine:
|
|||||||
logger.warning("Failed to parse tool_code call %r: %s", call_str, e)
|
logger.warning("Failed to parse tool_code call %r: %s", call_str, e)
|
||||||
|
|
||||||
return clean_text, tool_calls
|
return clean_text, tool_calls
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_tool_calls_qwen(text: str) -> tuple[str, list[dict]]:
|
||||||
|
"""Parse Qwen3 <tool_call> XML tags."""
|
||||||
|
tool_calls = []
|
||||||
|
pattern = r"<tool_call>\s*(.*?)\s*</tool_call>"
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
|
clean_text = re.sub(r"<tool_call>\s*.*?\s*</tool_call>", "", text, flags=re.DOTALL).strip()
|
||||||
|
|
||||||
|
for i, match in enumerate(matches):
|
||||||
|
try:
|
||||||
|
call_obj = json.loads(match.strip())
|
||||||
|
name = call_obj.get("name", "")
|
||||||
|
args = call_obj.get("arguments", {})
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json.loads(args)
|
||||||
|
tool_calls.append({
|
||||||
|
"id": f"call_{i}_{hash(match) % 10**8:08d}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": name,
|
||||||
|
"arguments": json.dumps(args),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
|
||||||
|
|
||||||
|
return clean_text, tool_calls
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""OpenAI-compatible API server for Gemma 3 4B via MLX."""
|
"""OpenAI-compatible API server for local LLMs (Gemma 3, Qwen3, …) via MLX."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ from .models import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B")
|
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for local LLMs on Apple Silicon")
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
@@ -170,6 +170,11 @@ async def _stream_response(
|
|||||||
prompt_tokens = 0
|
prompt_tokens = 0
|
||||||
gen_tokens = 0
|
gen_tokens = 0
|
||||||
|
|
||||||
|
# When tools are available we must buffer the full response before
|
||||||
|
# emitting content — otherwise raw tool-call markup (```tool_code```
|
||||||
|
# or <tool_call>) leaks into the streamed text.
|
||||||
|
buffer_for_tools = bool(tools)
|
||||||
|
|
||||||
for token_text, is_final, pt, gt in e.stream_generate(
|
for token_text, is_final, pt, gt in e.stream_generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
images=images or None,
|
images=images or None,
|
||||||
@@ -182,7 +187,7 @@ async def _stream_response(
|
|||||||
gen_tokens = gt
|
gen_tokens = gt
|
||||||
full_text += token_text
|
full_text += token_text
|
||||||
|
|
||||||
if not is_final and token_text:
|
if not buffer_for_tools and not is_final and token_text:
|
||||||
chunk = ChatCompletionChunk(
|
chunk = ChatCompletionChunk(
|
||||||
id=request_id,
|
id=request_id,
|
||||||
created=created,
|
created=created,
|
||||||
@@ -191,37 +196,53 @@ async def _stream_response(
|
|||||||
)
|
)
|
||||||
yield {"data": chunk.model_dump_json()}
|
yield {"data": chunk.model_dump_json()}
|
||||||
|
|
||||||
# Check for tool calls in complete response
|
# --- Post-generation: parse tool calls and emit clean content ------
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
|
tool_calls_parsed = []
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
clean_text, parsed = e.parse_tool_calls(full_text, tools)
|
clean_text, parsed = e.parse_tool_calls(full_text, tools)
|
||||||
if parsed:
|
if parsed:
|
||||||
finish_reason = "tool_calls"
|
finish_reason = "tool_calls"
|
||||||
# Emit tool call chunks
|
tool_calls_parsed = parsed
|
||||||
for i, tc in enumerate(parsed):
|
full_text = clean_text or ""
|
||||||
tc_chunk = ChatCompletionChunk(
|
|
||||||
id=request_id,
|
# Emit buffered content (when tools were present, this is the cleaned
|
||||||
created=created,
|
# text with tool-call markup stripped out)
|
||||||
model=model_name,
|
if buffer_for_tools and full_text.strip():
|
||||||
choices=[
|
content_chunk = ChatCompletionChunk(
|
||||||
StreamChoice(
|
id=request_id,
|
||||||
delta=DeltaMessage(
|
created=created,
|
||||||
tool_calls=[
|
model=model_name,
|
||||||
ToolCall(
|
choices=[StreamChoice(delta=DeltaMessage(content=full_text))],
|
||||||
index=i,
|
)
|
||||||
id=tc["id"],
|
yield {"data": content_chunk.model_dump_json()}
|
||||||
type="function",
|
|
||||||
function=FunctionCall(
|
# Emit tool call chunks
|
||||||
name=tc["function"]["name"],
|
for i, tc in enumerate(tool_calls_parsed):
|
||||||
arguments=tc["function"]["arguments"],
|
tc_chunk = ChatCompletionChunk(
|
||||||
),
|
id=request_id,
|
||||||
)
|
created=created,
|
||||||
]
|
model=model_name,
|
||||||
|
choices=[
|
||||||
|
StreamChoice(
|
||||||
|
delta=DeltaMessage(
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
index=i,
|
||||||
|
id=tc["id"],
|
||||||
|
type="function",
|
||||||
|
function=FunctionCall(
|
||||||
|
name=tc["function"]["name"],
|
||||||
|
arguments=tc["function"]["arguments"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
]
|
||||||
],
|
)
|
||||||
)
|
)
|
||||||
yield {"data": tc_chunk.model_dump_json()}
|
],
|
||||||
|
)
|
||||||
|
yield {"data": tc_chunk.model_dump_json()}
|
||||||
|
|
||||||
# Final chunk with finish reason and usage
|
# Final chunk with finish reason and usage
|
||||||
final_chunk = ChatCompletionChunk(
|
final_chunk = ChatCompletionChunk(
|
||||||
|
|||||||
26
run.sh
26
run.sh
@@ -7,8 +7,28 @@ cd "$SCRIPT_DIR"
|
|||||||
# Activate virtual environment
|
# Activate virtual environment
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
|
|
||||||
# Default model – 4-bit quantized Gemma 3 4B IT (vision-capable)
|
# --- Model selection ---
|
||||||
MODEL="${MODEL:-mlx-community/gemma-3-4b-it-4bit}"
|
# Usage: ./run.sh [gemma|qwen]
|
||||||
|
# Or set MODEL env var directly for a custom model.
|
||||||
|
|
||||||
|
MODEL_CHOICE="${1:-gemma}"
|
||||||
|
|
||||||
|
if [[ -z "${MODEL:-}" ]]; then
|
||||||
|
case "$MODEL_CHOICE" in
|
||||||
|
gemma)
|
||||||
|
MODEL="mlx-community/gemma-3-4b-it-4bit"
|
||||||
|
;;
|
||||||
|
qwen)
|
||||||
|
MODEL="mlx-community/Qwen3-VL-4B-Instruct-4bit"
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown model choice: $MODEL_CHOICE"
|
||||||
|
echo "Usage: $0 [gemma|qwen]"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
fi
|
||||||
|
|
||||||
HOST="${HOST:-127.0.0.1}"
|
HOST="${HOST:-127.0.0.1}"
|
||||||
PORT="${PORT:-1234}"
|
PORT="${PORT:-1234}"
|
||||||
|
|
||||||
@@ -21,4 +41,4 @@ exec python -m mlx_server.main \
|
|||||||
--model "$MODEL" \
|
--model "$MODEL" \
|
||||||
--host "$HOST" \
|
--host "$HOST" \
|
||||||
--port "$PORT" \
|
--port "$PORT" \
|
||||||
"$@"
|
"${@:2}"
|
||||||
|
|||||||
Reference in New Issue
Block a user