feat: qwen now works, too
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,3 +8,4 @@ build/
|
||||
.env
|
||||
*.log
|
||||
.DS_Store
|
||||
settings.local.json
|
||||
|
||||
19
CLAUDE.md
19
CLAUDE.md
@@ -1,6 +1,6 @@
|
||||
# MLX Server
|
||||
|
||||
OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon via MLX.
|
||||
OpenAI-compatible API server for local LLMs on Apple Silicon via MLX. Supports Gemma 3 4B and Qwen3 VL 4B (vision + tool use).
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -8,11 +8,15 @@ OpenAI-compatible API server for Gemma 3 4B (vision + tool use) on Apple Silicon
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
|
||||
# Run the server (downloads model on first run)
|
||||
# Run with Gemma 3 (default)
|
||||
./run.sh
|
||||
|
||||
# Run with Qwen3
|
||||
./run.sh qwen
|
||||
|
||||
# Or directly:
|
||||
python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
|
||||
python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port 1234
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
@@ -21,11 +25,18 @@ python -m mlx_server.main --model mlx-community/gemma-3-4b-it-4bit --port 1234
|
||||
- `mlx_server/engine.py` — Model loading, prompt building, generation (mlx_vlm)
|
||||
- `mlx_server/models.py` — Pydantic models for OpenAI API request/response types
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Alias | HuggingFace ID | Notes |
|
||||
|-------|---------------|-------|
|
||||
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision + tool use via `tool_code` blocks |
|
||||
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision + tool use via `<tool_call>` tags |
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- Uses `mlx_vlm` (not `mlx_lm`) as the inference backend — this supports both text and vision in a single model load
|
||||
- Gemma 3 has no system role — system messages are converted to user/assistant pairs
|
||||
- Tool use is prompt-engineered: tools are injected into the system prompt with `<tool_call>` XML tags, and parsed from model output
|
||||
- Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags
|
||||
- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.)
|
||||
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
|
||||
- 128k context window supported via the model's native capabilities
|
||||
|
||||
|
||||
@@ -20,6 +20,58 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MODEL = "mlx-community/gemma-3-4b-it-4bit"
|
||||
|
||||
# Known model aliases for quick selection
|
||||
MODEL_ALIASES: dict[str, str] = {
|
||||
"gemma": "mlx-community/gemma-3-4b-it-4bit",
|
||||
"qwen": "mlx-community/Qwen3-VL-4B-Instruct-4bit",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_local_model_path(repo_id: str) -> Path | None:
|
||||
"""If a HuggingFace model is already cached locally, return its snapshot path.
|
||||
|
||||
This avoids any network requests (HEAD checks) when the model files are
|
||||
already present on disk — critical for offline use.
|
||||
"""
|
||||
# If it's already a local directory, just use it
|
||||
local = Path(repo_id)
|
||||
if local.is_dir():
|
||||
return local
|
||||
|
||||
# Check the HF cache: ~/.cache/huggingface/hub/models--org--name/snapshots/<hash>
|
||||
cache_root = Path.home() / ".cache" / "huggingface" / "hub"
|
||||
safe_name = "models--" + repo_id.replace("/", "--")
|
||||
model_cache = cache_root / safe_name
|
||||
|
||||
if not model_cache.is_dir():
|
||||
return None
|
||||
|
||||
# Read the ref to find the snapshot hash
|
||||
refs_dir = model_cache / "refs"
|
||||
snapshot_dir = model_cache / "snapshots"
|
||||
if refs_dir.is_dir() and snapshot_dir.is_dir():
|
||||
main_ref = refs_dir / "main"
|
||||
if main_ref.is_file():
|
||||
commit_hash = main_ref.read_text().strip()
|
||||
snap = snapshot_dir / commit_hash
|
||||
if snap.is_dir():
|
||||
logger.info(
|
||||
"Found locally cached model at %s — skipping online check", snap
|
||||
)
|
||||
return snap
|
||||
|
||||
# Fallback: use the first (most recent) snapshot if refs/main is missing
|
||||
if snapshot_dir.is_dir():
|
||||
snapshots = sorted(snapshot_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
if snapshots:
|
||||
logger.info(
|
||||
"Found locally cached model at %s — skipping online check",
|
||||
snapshots[0],
|
||||
)
|
||||
return snapshots[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers for Gemma 3 tool_code format
|
||||
@@ -200,17 +252,35 @@ class InferenceEngine:
|
||||
self.model = None
|
||||
self.processor = None
|
||||
self.config = None
|
||||
self._model_type: str = "" # e.g. "gemma3", "qwen3"
|
||||
self._lock = threading.Lock()
|
||||
self._prompt_cache = PromptCache()
|
||||
|
||||
def load(self) -> None:
|
||||
logger.info("Loading model %s ...", self.model_path)
|
||||
self.model, self.processor = mlx_vlm.load(self.model_path)
|
||||
|
||||
# Prefer the local cache to avoid any network requests
|
||||
local_path = _resolve_local_model_path(self.model_path)
|
||||
load_path = str(local_path) if local_path else self.model_path
|
||||
|
||||
self.model, self.processor = mlx_vlm.load(load_path)
|
||||
|
||||
# Load model config for chat template
|
||||
from transformers import AutoConfig
|
||||
|
||||
self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
|
||||
logger.info("Model loaded successfully.")
|
||||
self.config = AutoConfig.from_pretrained(load_path, trust_remote_code=True)
|
||||
|
||||
# Detect model family for prompt-format branching
|
||||
self._model_type = getattr(self.config, "model_type", "").lower()
|
||||
logger.info("Model loaded successfully (type=%s).", self._model_type)
|
||||
|
||||
@property
|
||||
def is_qwen(self) -> bool:
|
||||
return "qwen" in self._model_type
|
||||
|
||||
@property
|
||||
def is_gemma(self) -> bool:
|
||||
return "gemma" in self._model_type
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Image helpers
|
||||
@@ -244,6 +314,16 @@ class InferenceEngine:
|
||||
|
||||
Returns (prompt_str, image_paths).
|
||||
"""
|
||||
if self.is_qwen:
|
||||
return self._build_prompt_qwen(messages, tools)
|
||||
return self._build_prompt_gemma(messages, tools)
|
||||
|
||||
def _build_prompt_gemma(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
) -> tuple[str, list[str]]:
|
||||
"""Gemma 3 prompt builder (tool_code format, no system role)."""
|
||||
image_paths: list[str] = []
|
||||
formatted_messages: list[dict] = []
|
||||
|
||||
@@ -306,6 +386,58 @@ class InferenceEngine:
|
||||
|
||||
return prompt, image_paths
|
||||
|
||||
def _build_prompt_qwen(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
) -> tuple[str, list[str]]:
|
||||
"""Qwen3 prompt builder (native system role, JSON tool calls)."""
|
||||
image_paths: list[str] = []
|
||||
formatted_messages: list[dict] = []
|
||||
|
||||
# Qwen3 supports system role natively — inject tools there
|
||||
has_system = any(m.get("role") == "system" for m in messages)
|
||||
if tools and not has_system:
|
||||
formatted_messages.append({
|
||||
"role": "system",
|
||||
"content": self._build_qwen_tool_system_prompt(tools),
|
||||
})
|
||||
|
||||
for msg in messages:
|
||||
role = msg["role"]
|
||||
content = msg.get("content")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
|
||||
if role == "system":
|
||||
text = self._get_text_content(content)
|
||||
if tools:
|
||||
text = text + "\n\n" + self._build_qwen_tool_system_prompt(tools)
|
||||
formatted_messages.append({"role": "system", "content": text})
|
||||
elif role == "user":
|
||||
text, imgs = self._extract_content_parts(content)
|
||||
image_paths.extend(imgs)
|
||||
formatted_messages.append({"role": "user", "content": text})
|
||||
elif role == "assistant":
|
||||
text = self._get_text_content(content) or ""
|
||||
if tool_calls:
|
||||
tc_text = self._format_qwen_tool_calls_for_prompt(tool_calls)
|
||||
text = (text + "\n" + tc_text).strip()
|
||||
formatted_messages.append({"role": "assistant", "content": text})
|
||||
elif role == "tool":
|
||||
tool_text = self._get_text_content(content) or ""
|
||||
formatted_messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
# Apply chat template via mlx_vlm
|
||||
prompt = mlx_vlm.apply_chat_template(
|
||||
self.processor,
|
||||
self.config,
|
||||
formatted_messages,
|
||||
add_generation_prompt=True,
|
||||
num_images=len(image_paths),
|
||||
)
|
||||
|
||||
return prompt, image_paths
|
||||
|
||||
@staticmethod
|
||||
def _merge_consecutive_roles(messages: list[dict]) -> list[dict]:
|
||||
"""Merge consecutive messages with the same role into one.
|
||||
@@ -439,6 +571,50 @@ class InferenceEngine:
|
||||
parts.append(f"```tool_code\n{call_str}\n```")
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Qwen3 tool helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _build_qwen_tool_system_prompt(tools: list[dict]) -> str:
|
||||
"""Build the tool system prompt for Qwen3 using its native JSON format."""
|
||||
tool_descs = []
|
||||
for tool in tools:
|
||||
func = tool.get("function", tool)
|
||||
tool_descs.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
},
|
||||
})
|
||||
tools_json = json.dumps(tool_descs, indent=2)
|
||||
return (
|
||||
"# Tools\n\n"
|
||||
"You are a helpful assistant with access to the following tools. "
|
||||
"Use them when appropriate by responding with a JSON tool call.\n\n"
|
||||
"## Available Tools\n\n"
|
||||
f"{tools_json}\n\n"
|
||||
"## Tool Call Format\n\n"
|
||||
"When you need to call a tool, respond with:\n"
|
||||
'<tool_call>\n{"name": "<function_name>", "arguments": {<args>}}\n</tool_call>'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_qwen_tool_calls_for_prompt(tool_calls: list[dict]) -> str:
|
||||
"""Format OpenAI-style tool calls back into Qwen3's XML tag format."""
|
||||
parts = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", tc)
|
||||
name = func["name"]
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
call_obj = {"name": name, "arguments": args}
|
||||
parts.append(f"<tool_call>\n{json.dumps(call_obj)}\n</tool_call>")
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prefix cache & generation
|
||||
# ------------------------------------------------------------------
|
||||
@@ -447,6 +623,32 @@ class InferenceEngine:
|
||||
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
|
||||
_GENERATE_KWARGS: dict = {}
|
||||
|
||||
# Keys in the prep dict that are internal bookkeeping, not kwargs for
|
||||
# mlx_vlm.stream_generate.
|
||||
_PREP_INTERNAL_KEYS = frozenset({
|
||||
"input_ids", "pixel_values", "mask", "prompt_cache",
|
||||
"_full_token_ids", "_prompt_token_count",
|
||||
})
|
||||
|
||||
def _extra_generate_kwargs(
|
||||
self, images: list[str] | None, prep: dict | None = None,
|
||||
) -> dict:
|
||||
"""Build per-request kwargs for mlx_vlm.stream_generate.
|
||||
|
||||
Includes model-specific keys from prepare_inputs (e.g. image_grid_thw
|
||||
for Qwen3-VL) and works around a chunked-prefill bug where
|
||||
visual_pos_masks is None for text-only requests.
|
||||
"""
|
||||
extra: dict = dict(self._GENERATE_KWARGS)
|
||||
if self.is_qwen and not images:
|
||||
extra["prefill_step_size"] = None
|
||||
# Forward any model-specific keys that prepare_inputs returned
|
||||
if prep:
|
||||
for k, v in prep.items():
|
||||
if k not in self._PREP_INTERNAL_KEYS:
|
||||
extra[k] = v
|
||||
return extra
|
||||
|
||||
def _get_tokenizer(self):
|
||||
"""Get the underlying tokenizer from the processor."""
|
||||
proc = self.processor
|
||||
@@ -484,6 +686,11 @@ class InferenceEngine:
|
||||
pixel_values = inputs.get("pixel_values")
|
||||
mask = inputs.get("attention_mask")
|
||||
|
||||
# Collect any model-specific extra keys from prepare_inputs
|
||||
# (e.g. image_grid_thw for Qwen3-VL) so they reach the model.
|
||||
_KNOWN_KEYS = {"input_ids", "pixel_values", "attention_mask"}
|
||||
extra_inputs = {k: v for k, v in inputs.items() if k not in _KNOWN_KEYS}
|
||||
|
||||
full_token_list = full_input_ids.flatten().tolist()
|
||||
prefix_len = self._prompt_cache.get_reusable_length(full_token_list)
|
||||
|
||||
@@ -521,7 +728,9 @@ class InferenceEngine:
|
||||
}
|
||||
|
||||
# Cache miss — create a fresh KV cache
|
||||
cache = cache_module.make_prompt_cache(self.model.language_model)
|
||||
# VLM models expose .language_model; text-only models are the LM directly
|
||||
lm = getattr(self.model, "language_model", self.model)
|
||||
cache = cache_module.make_prompt_cache(lm)
|
||||
logger.info(
|
||||
"Prefix cache miss: processing %d tokens from scratch",
|
||||
len(full_token_list),
|
||||
@@ -533,6 +742,7 @@ class InferenceEngine:
|
||||
"prompt_cache": cache,
|
||||
"_full_token_ids": full_token_list,
|
||||
"_prompt_token_count": len(full_token_list),
|
||||
**extra_inputs,
|
||||
}
|
||||
|
||||
def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None:
|
||||
@@ -555,9 +765,10 @@ class InferenceEngine:
|
||||
prep = self._prepare_generation(prompt, images)
|
||||
prompt_token_count = prep["_prompt_token_count"]
|
||||
|
||||
# Ensure stopping criteria is initialised
|
||||
# Ensure stopping criteria is initialised (Gemma-specific; optional for others)
|
||||
tokenizer = self._get_tokenizer()
|
||||
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||
if hasattr(tokenizer, "stopping_criteria"):
|
||||
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||
|
||||
text = ""
|
||||
generated_tokens: list[int] = []
|
||||
@@ -575,7 +786,7 @@ class InferenceEngine:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
**self._GENERATE_KWARGS,
|
||||
**self._extra_generate_kwargs(images, prep),
|
||||
):
|
||||
text += result.text
|
||||
if result.token is not None:
|
||||
@@ -603,9 +814,10 @@ class InferenceEngine:
|
||||
prep = self._prepare_generation(prompt, images)
|
||||
prompt_token_count = prep["_prompt_token_count"]
|
||||
|
||||
# Ensure stopping criteria is initialised
|
||||
# Ensure stopping criteria is initialised (Gemma-specific; optional for others)
|
||||
tokenizer = self._get_tokenizer()
|
||||
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||
if hasattr(tokenizer, "stopping_criteria"):
|
||||
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||
|
||||
accumulated = ""
|
||||
generated_tokens: list[int] = []
|
||||
@@ -624,7 +836,7 @@ class InferenceEngine:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
**self._GENERATE_KWARGS,
|
||||
**self._extra_generate_kwargs(images, prep),
|
||||
):
|
||||
token_text = result.text
|
||||
accumulated += token_text
|
||||
@@ -663,19 +875,26 @@ class InferenceEngine:
|
||||
# Tool call parsing from model output
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def parse_tool_calls(
|
||||
text: str, tools: list[dict] | None = None
|
||||
self, text: str, tools: list[dict] | None = None
|
||||
) -> tuple[str, list[dict]]:
|
||||
"""Parse tool calls from model output using Gemma 3's tool_code format.
|
||||
"""Parse tool calls from model output.
|
||||
|
||||
Detects ```tool_code ... ``` blocks containing Python-style or
|
||||
shell-style function calls.
|
||||
Supports both Gemma 3's ```tool_code``` blocks and Qwen3's
|
||||
<tool_call> XML tags.
|
||||
|
||||
Returns (clean_text, tool_calls) where tool_calls is a list of
|
||||
{"id": str, "type": "function", "function": {"name": str, "arguments": str}}.
|
||||
"""
|
||||
# Build a lookup of function name -> parameter schema
|
||||
if self.is_qwen:
|
||||
return self._parse_tool_calls_qwen(text)
|
||||
return self._parse_tool_calls_gemma(text, tools)
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_calls_gemma(
|
||||
text: str, tools: list[dict] | None = None
|
||||
) -> tuple[str, list[dict]]:
|
||||
"""Parse Gemma 3 tool_code blocks."""
|
||||
tool_defs: dict[str, dict] = {}
|
||||
if tools:
|
||||
for tool in tools:
|
||||
@@ -704,3 +923,32 @@ class InferenceEngine:
|
||||
logger.warning("Failed to parse tool_code call %r: %s", call_str, e)
|
||||
|
||||
return clean_text, tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_calls_qwen(text: str) -> tuple[str, list[dict]]:
|
||||
"""Parse Qwen3 <tool_call> XML tags."""
|
||||
tool_calls = []
|
||||
pattern = r"<tool_call>\s*(.*?)\s*</tool_call>"
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
clean_text = re.sub(r"<tool_call>\s*.*?\s*</tool_call>", "", text, flags=re.DOTALL).strip()
|
||||
|
||||
for i, match in enumerate(matches):
|
||||
try:
|
||||
call_obj = json.loads(match.strip())
|
||||
name = call_obj.get("name", "")
|
||||
args = call_obj.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
tool_calls.append({
|
||||
"id": f"call_{i}_{hash(match) % 10**8:08d}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": json.dumps(args),
|
||||
},
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
|
||||
|
||||
return clean_text, tool_calls
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""OpenAI-compatible API server for Gemma 3 4B via MLX."""
|
||||
"""OpenAI-compatible API server for local LLMs (Gemma 3, Qwen3, …) via MLX."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -31,7 +31,7 @@ from .models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B")
|
||||
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for local LLMs on Apple Silicon")
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -170,6 +170,11 @@ async def _stream_response(
|
||||
prompt_tokens = 0
|
||||
gen_tokens = 0
|
||||
|
||||
# When tools are available we must buffer the full response before
|
||||
# emitting content — otherwise raw tool-call markup (```tool_code```
|
||||
# or <tool_call>) leaks into the streamed text.
|
||||
buffer_for_tools = bool(tools)
|
||||
|
||||
for token_text, is_final, pt, gt in e.stream_generate(
|
||||
prompt=prompt,
|
||||
images=images or None,
|
||||
@@ -182,7 +187,7 @@ async def _stream_response(
|
||||
gen_tokens = gt
|
||||
full_text += token_text
|
||||
|
||||
if not is_final and token_text:
|
||||
if not buffer_for_tools and not is_final and token_text:
|
||||
chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
@@ -191,37 +196,53 @@ async def _stream_response(
|
||||
)
|
||||
yield {"data": chunk.model_dump_json()}
|
||||
|
||||
# Check for tool calls in complete response
|
||||
# --- Post-generation: parse tool calls and emit clean content ------
|
||||
finish_reason = "stop"
|
||||
tool_calls_parsed = []
|
||||
|
||||
if tools:
|
||||
clean_text, parsed = e.parse_tool_calls(full_text, tools)
|
||||
if parsed:
|
||||
finish_reason = "tool_calls"
|
||||
# Emit tool call chunks
|
||||
for i, tc in enumerate(parsed):
|
||||
tc_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[
|
||||
StreamChoice(
|
||||
delta=DeltaMessage(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
index=i,
|
||||
id=tc["id"],
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
]
|
||||
tool_calls_parsed = parsed
|
||||
full_text = clean_text or ""
|
||||
|
||||
# Emit buffered content (when tools were present, this is the cleaned
|
||||
# text with tool-call markup stripped out)
|
||||
if buffer_for_tools and full_text.strip():
|
||||
content_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[StreamChoice(delta=DeltaMessage(content=full_text))],
|
||||
)
|
||||
yield {"data": content_chunk.model_dump_json()}
|
||||
|
||||
# Emit tool call chunks
|
||||
for i, tc in enumerate(tool_calls_parsed):
|
||||
tc_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[
|
||||
StreamChoice(
|
||||
delta=DeltaMessage(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
index=i,
|
||||
id=tc["id"],
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
]
|
||||
)
|
||||
)
|
||||
yield {"data": tc_chunk.model_dump_json()}
|
||||
],
|
||||
)
|
||||
yield {"data": tc_chunk.model_dump_json()}
|
||||
|
||||
# Final chunk with finish reason and usage
|
||||
final_chunk = ChatCompletionChunk(
|
||||
|
||||
26
run.sh
26
run.sh
@@ -7,8 +7,28 @@ cd "$SCRIPT_DIR"
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
|
||||
# Default model – 4-bit quantized Gemma 3 4B IT (vision-capable)
|
||||
MODEL="${MODEL:-mlx-community/gemma-3-4b-it-4bit}"
|
||||
# --- Model selection ---
|
||||
# Usage: ./run.sh [gemma|qwen]
|
||||
# Or set MODEL env var directly for a custom model.
|
||||
|
||||
MODEL_CHOICE="${1:-gemma}"
|
||||
|
||||
if [[ -z "${MODEL:-}" ]]; then
|
||||
case "$MODEL_CHOICE" in
|
||||
gemma)
|
||||
MODEL="mlx-community/gemma-3-4b-it-4bit"
|
||||
;;
|
||||
qwen)
|
||||
MODEL="mlx-community/Qwen3-VL-4B-Instruct-4bit"
|
||||
;;
|
||||
*)
|
||||
echo "Unknown model choice: $MODEL_CHOICE"
|
||||
echo "Usage: $0 [gemma|qwen]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
HOST="${HOST:-127.0.0.1}"
|
||||
PORT="${PORT:-1234}"
|
||||
|
||||
@@ -21,4 +41,4 @@ exec python -m mlx_server.main \
|
||||
--model "$MODEL" \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
"$@"
|
||||
"${@:2}"
|
||||
|
||||
Reference in New Issue
Block a user