feat: proper support for context size
This commit is contained in:
@@ -38,7 +38,7 @@ python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port
|
|||||||
- Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags
|
- Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags
|
||||||
- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.)
|
- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.)
|
||||||
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
|
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
|
||||||
- 128k context window supported via the model's native capabilities
|
- Context window size is read from each model's config at load time (Gemma 3 4B: 128k, Qwen3-VL 4B: 256k)
|
||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
|
|
||||||
|
|||||||
22
README.md
22
README.md
@@ -4,10 +4,10 @@ OpenAI-compatible API server for running local LLMs on Apple Silicon via [MLX](h
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
| Alias | Model | Capabilities |
|
| Alias | Model | Context | Capabilities |
|
||||||
|-------|-------|-------------|
|
|-------|-------|---------|-------------|
|
||||||
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision, tool use (`tool_code` blocks) |
|
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | 128k | Vision, tool use (`tool_code` blocks) |
|
||||||
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision, tool use (`<tool_call>` tags) |
|
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | 256k | Vision, tool use (`<tool_call>` tags) |
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ The server starts at `http://127.0.0.1:1234`.
|
|||||||
|
|
||||||
Standard OpenAI-compatible endpoints:
|
Standard OpenAI-compatible endpoints:
|
||||||
|
|
||||||
- `GET /v1/models` — lists all available models
|
- `GET /v1/models` — lists all available models with `context_window` sizes
|
||||||
- `POST /v1/chat/completions` — chat completions (streaming and non-streaming)
|
- `POST /v1/chat/completions` — chat completions (streaming and non-streaming)
|
||||||
- `GET /health` — health check
|
- `GET /health` — health check
|
||||||
|
|
||||||
@@ -67,6 +67,16 @@ Pass images as base64 data URIs or URLs in the `image_url` content part:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Context Window Management
|
||||||
|
|
||||||
|
Each model's context window is read from its HuggingFace config (`max_position_embeddings`) and reported in `/v1/models` via the `context_window` field. Clients can use this to manage conversation length proactively.
|
||||||
|
|
||||||
|
If a request exceeds the context window, the server:
|
||||||
|
|
||||||
|
1. Automatically summarizes older messages (keeping system messages and the last 6 messages intact)
|
||||||
|
2. Retries with the compressed conversation
|
||||||
|
3. Returns an OpenAI-compatible `context_length_exceeded` error if it still doesn't fit
|
||||||
|
|
||||||
### Tool Use
|
### Tool Use
|
||||||
|
|
||||||
Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting and parses tool calls from the output automatically.
|
Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting and parses tool calls from the output automatically.
|
||||||
@@ -94,4 +104,4 @@ mlx_server/
|
|||||||
- Offline-first: if the model is cached locally (`~/.cache/huggingface/hub/`), no network requests are made
|
- Offline-first: if the model is cached locally (`~/.cache/huggingface/hub/`), no network requests are made
|
||||||
- Thread lock on generation — MLX models aren't safe for concurrent generation
|
- Thread lock on generation — MLX models aren't safe for concurrent generation
|
||||||
- KV prefix caching for multi-turn conversations
|
- KV prefix caching for multi-turn conversations
|
||||||
- 128k context window via native model capabilities
|
- Context window read from each model's config (Gemma 3 4B: 128k, Qwen3-VL 4B: 256k) with automatic summarization fallback
|
||||||
|
|||||||
@@ -295,6 +295,62 @@ class InferenceEngine:
|
|||||||
def is_gemma(self) -> bool:
|
def is_gemma(self) -> bool:
|
||||||
return "gemma" in self._model_type
|
return "gemma" in self._model_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def context_length(self) -> int:
|
||||||
|
"""Max context length from the model config."""
|
||||||
|
if self.config is None:
|
||||||
|
return 0
|
||||||
|
# VLMs nest the LLM config under text_config
|
||||||
|
text_cfg = getattr(self.config, "text_config", self.config)
|
||||||
|
return getattr(text_cfg, "max_position_embeddings", 0)
|
||||||
|
|
||||||
|
def count_tokens(self, text: str) -> int:
|
||||||
|
"""Count tokens in a text string. Thread-safe, no lock needed."""
|
||||||
|
tokenizer = self._get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
|
|
||||||
|
def summarize_messages(self, messages: list[dict]) -> str:
|
||||||
|
"""Summarize a list of conversation messages into a concise text.
|
||||||
|
|
||||||
|
Calls generate() internally (acquires and releases the lock).
|
||||||
|
"""
|
||||||
|
# Build a readable transcript from the messages
|
||||||
|
transcript_lines = []
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "unknown")
|
||||||
|
content = self._get_text_content(msg.get("content"))
|
||||||
|
# Include tool call info if present
|
||||||
|
if msg.get("tool_calls"):
|
||||||
|
tool_names = [
|
||||||
|
tc.get("function", tc).get("name", "?")
|
||||||
|
for tc in msg["tool_calls"]
|
||||||
|
]
|
||||||
|
content += f" [called tools: {', '.join(tool_names)}]"
|
||||||
|
if content.strip():
|
||||||
|
transcript_lines.append(f"{role}: {content.strip()}")
|
||||||
|
|
||||||
|
transcript = "\n".join(transcript_lines)
|
||||||
|
|
||||||
|
summary_instruction = [{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"Summarize the following conversation concisely. "
|
||||||
|
"Preserve key facts, decisions, tool results, and context "
|
||||||
|
"needed to continue the conversation naturally. "
|
||||||
|
"Be brief but complete.\n\n"
|
||||||
|
f"<conversation>\n{transcript}\n</conversation>"
|
||||||
|
),
|
||||||
|
}]
|
||||||
|
|
||||||
|
prompt, _ = self.build_prompt(summary_instruction, tools=None)
|
||||||
|
summary_text, _, _ = self.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
images=None,
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.2,
|
||||||
|
)
|
||||||
|
return summary_text.strip()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Image helpers
|
# Image helpers
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -1029,6 +1085,22 @@ class ModelManager:
|
|||||||
self._current_model = target
|
self._current_model = target
|
||||||
return self._engine
|
return self._engine
|
||||||
|
|
||||||
|
def get_context_length(self, model_id: str) -> int | None:
|
||||||
|
"""Get context length for a model from its cached config, without loading it."""
|
||||||
|
local_path = _resolve_local_model_path(model_id)
|
||||||
|
if local_path is None:
|
||||||
|
return None
|
||||||
|
config_file = local_path / "config.json"
|
||||||
|
if not config_file.is_file():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
config = json.loads(config_file.read_text())
|
||||||
|
# VLMs nest under text_config
|
||||||
|
text_cfg = config.get("text_config", config)
|
||||||
|
return text_cfg.get("max_position_embeddings")
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
def preload(self, model: str | None = None) -> None:
|
def preload(self, model: str | None = None) -> None:
|
||||||
"""Eagerly load a model at startup."""
|
"""Eagerly load a model at startup."""
|
||||||
self.get_engine(model)
|
self.get_engine(model)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
from .engine import DEFAULT_MODEL, ModelManager
|
from .engine import DEFAULT_MODEL, InferenceEngine, ModelManager
|
||||||
from .models import (
|
from .models import (
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
@@ -43,6 +43,9 @@ app.add_middleware(
|
|||||||
|
|
||||||
manager: ModelManager | None = None
|
manager: ModelManager | None = None
|
||||||
|
|
||||||
|
# Number of recent messages to always preserve when summarizing
|
||||||
|
_KEEP_RECENT = 6
|
||||||
|
|
||||||
|
|
||||||
def get_engine(requested_model: str | None = None):
|
def get_engine(requested_model: str | None = None):
|
||||||
if manager is None:
|
if manager is None:
|
||||||
@@ -54,6 +57,101 @@ def _make_id() -> str:
|
|||||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Context window management
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _manage_context(
|
||||||
|
e: InferenceEngine,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Check if messages fit in the context window; summarize if needed.
|
||||||
|
|
||||||
|
Returns the (possibly summarized) message list. Raises HTTPException
|
||||||
|
with an OpenAI-compatible error if the conversation cannot fit.
|
||||||
|
"""
|
||||||
|
context_length = e.context_length
|
||||||
|
if context_length <= 0:
|
||||||
|
return messages # unknown context size, skip check
|
||||||
|
|
||||||
|
prompt, _ = e.build_prompt(messages, tools)
|
||||||
|
prompt_tokens = e.count_tokens(prompt)
|
||||||
|
available = context_length - max_tokens
|
||||||
|
|
||||||
|
if prompt_tokens <= available:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# --- Need to summarize ---
|
||||||
|
logger.info(
|
||||||
|
"Context window pressure: %d prompt tokens + %d max_tokens = %d "
|
||||||
|
"(limit %d). Attempting summarization.",
|
||||||
|
prompt_tokens, max_tokens, prompt_tokens + max_tokens, context_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split messages: system | middle (summarizable) | recent (kept)
|
||||||
|
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||||
|
non_system = [m for m in messages if m.get("role") != "system"]
|
||||||
|
|
||||||
|
if len(non_system) <= _KEEP_RECENT:
|
||||||
|
_raise_context_exceeded(prompt_tokens, max_tokens, context_length)
|
||||||
|
|
||||||
|
recent = non_system[-_KEEP_RECENT:]
|
||||||
|
middle = non_system[:-_KEEP_RECENT]
|
||||||
|
|
||||||
|
# Generate summary of the middle messages
|
||||||
|
summary_text = e.summarize_messages(middle)
|
||||||
|
|
||||||
|
summary_msg = {
|
||||||
|
"role": "user",
|
||||||
|
"content": f"[Summary of earlier conversation]\n{summary_text}",
|
||||||
|
}
|
||||||
|
ack_msg = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Understood, I have the context from our earlier conversation.",
|
||||||
|
}
|
||||||
|
new_messages = system_msgs + [summary_msg, ack_msg] + recent
|
||||||
|
|
||||||
|
# Re-check fit
|
||||||
|
new_prompt, _ = e.build_prompt(new_messages, tools)
|
||||||
|
new_prompt_tokens = e.count_tokens(new_prompt)
|
||||||
|
|
||||||
|
if new_prompt_tokens + max_tokens > context_length:
|
||||||
|
logger.warning(
|
||||||
|
"Still over context limit after summarization: %d + %d = %d (limit %d)",
|
||||||
|
new_prompt_tokens, max_tokens, new_prompt_tokens + max_tokens, context_length,
|
||||||
|
)
|
||||||
|
_raise_context_exceeded(new_prompt_tokens, max_tokens, context_length)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Summarization reduced prompt from %d to %d tokens (saved %d).",
|
||||||
|
prompt_tokens, new_prompt_tokens, prompt_tokens - new_prompt_tokens,
|
||||||
|
)
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_context_exceeded(prompt_tokens: int, max_tokens: int, context_length: int):
|
||||||
|
"""Raise an OpenAI-compatible context_length_exceeded error."""
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": {
|
||||||
|
"message": (
|
||||||
|
f"This model's maximum context length is {context_length} tokens. "
|
||||||
|
f"However, your messages resulted in {prompt_tokens} tokens and "
|
||||||
|
f"{max_tokens} tokens were requested for the completion "
|
||||||
|
f"({prompt_tokens + max_tokens} total). "
|
||||||
|
f"Please reduce the length of the messages or completion."
|
||||||
|
),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"code": "context_length_exceeded",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Endpoints
|
# Endpoints
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -64,7 +162,13 @@ async def list_models() -> ModelListResponse:
|
|||||||
if manager is None:
|
if manager is None:
|
||||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||||
return ModelListResponse(
|
return ModelListResponse(
|
||||||
data=[ModelInfo(id=model_id) for model_id in manager.available_models]
|
data=[
|
||||||
|
ModelInfo(
|
||||||
|
id=model_id,
|
||||||
|
context_window=manager.get_context_length(model_id),
|
||||||
|
)
|
||||||
|
for model_id in manager.available_models
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -78,8 +182,6 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||||||
if request.tools:
|
if request.tools:
|
||||||
tools = [t.model_dump(exclude_none=True) for t in request.tools]
|
tools = [t.model_dump(exclude_none=True) for t in request.tools]
|
||||||
|
|
||||||
prompt, images = e.build_prompt(messages, tools)
|
|
||||||
|
|
||||||
stop = request.stop
|
stop = request.stop
|
||||||
if isinstance(stop, str):
|
if isinstance(stop, str):
|
||||||
stop = [stop]
|
stop = [stop]
|
||||||
@@ -88,6 +190,11 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||||||
top_p = request.top_p if request.top_p is not None else 0.9
|
top_p = request.top_p if request.top_p is not None else 0.9
|
||||||
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
|
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
|
||||||
|
|
||||||
|
# Context window management: summarize if needed, error if impossible
|
||||||
|
messages = _manage_context(e, messages, tools, max_tokens)
|
||||||
|
|
||||||
|
prompt, images = e.build_prompt(messages, tools)
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
return EventSourceResponse(
|
return EventSourceResponse(
|
||||||
_stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model),
|
_stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model),
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ class ModelInfo(BaseModel):
|
|||||||
object: str = "model"
|
object: str = "model"
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
owned_by: str = "local"
|
owned_by: str = "local"
|
||||||
|
context_window: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class ModelListResponse(BaseModel):
|
class ModelListResponse(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user