feat: proper support for context size

This commit is contained in:
2026-03-17 12:34:11 +01:00
parent 540b187593
commit cc4f937d9a
5 changed files with 201 additions and 11 deletions

View File

@@ -38,7 +38,7 @@ python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port
- Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags - Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags
- Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.) - Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.)
- Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation - Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation
- 128k context window supported via the model's native capabilities - Context window size is read from each model's config at load time (Gemma 3 4B: 128k, Qwen3-VL 4B: 256k)
## Dependencies ## Dependencies

View File

@@ -4,10 +4,10 @@ OpenAI-compatible API server for running local LLMs on Apple Silicon via [MLX](h
## Supported Models ## Supported Models
| Alias | Model | Capabilities | | Alias | Model | Context | Capabilities |
|-------|-------|-------------| |-------|-------|---------|-------------|
| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision, tool use (`tool_code` blocks) | | `gemma` | `mlx-community/gemma-3-4b-it-4bit` | 128k | Vision, tool use (`tool_code` blocks) |
| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision, tool use (`<tool_call>` tags) | | `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | 256k | Vision, tool use (`<tool_call>` tags) |
## Quick Start ## Quick Start
@@ -30,7 +30,7 @@ The server starts at `http://127.0.0.1:1234`.
Standard OpenAI-compatible endpoints: Standard OpenAI-compatible endpoints:
- `GET /v1/models` — lists all available models - `GET /v1/models` — lists all available models with `context_window` sizes
- `POST /v1/chat/completions` — chat completions (streaming and non-streaming) - `POST /v1/chat/completions` — chat completions (streaming and non-streaming)
- `GET /health` — health check - `GET /health` — health check
@@ -67,6 +67,16 @@ Pass images as base64 data URIs or URLs in the `image_url` content part:
} }
``` ```
### Context Window Management
Each model's context window is read from its HuggingFace config (`max_position_embeddings`) and reported in `/v1/models` via the `context_window` field. Clients can use this to manage conversation length proactively.
If a request exceeds the context window, the server:
1. Automatically summarizes older messages (keeping system messages and the last 6 messages intact)
2. Retries with the compressed conversation
3. Returns an OpenAI-compatible `context_length_exceeded` error if it still doesn't fit
### Tool Use ### Tool Use
Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting and parses tool calls from the output automatically. Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting and parses tool calls from the output automatically.
@@ -94,4 +104,4 @@ mlx_server/
- Offline-first: if the model is cached locally (`~/.cache/huggingface/hub/`), no network requests are made - Offline-first: if the model is cached locally (`~/.cache/huggingface/hub/`), no network requests are made
- Thread lock on generation — MLX models aren't safe for concurrent generation - Thread lock on generation — MLX models aren't safe for concurrent generation
- KV prefix caching for multi-turn conversations - KV prefix caching for multi-turn conversations
- 128k context window via native model capabilities - Context window read from each model's config (Gemma 3 4B: 128k, Qwen3-VL 4B: 256k) with automatic summarization fallback

View File

@@ -295,6 +295,62 @@ class InferenceEngine:
def is_gemma(self) -> bool: def is_gemma(self) -> bool:
return "gemma" in self._model_type return "gemma" in self._model_type
@property
def context_length(self) -> int:
"""Max context length from the model config."""
if self.config is None:
return 0
# VLMs nest the LLM config under text_config
text_cfg = getattr(self.config, "text_config", self.config)
return getattr(text_cfg, "max_position_embeddings", 0)
def count_tokens(self, text: str) -> int:
"""Count tokens in a text string. Thread-safe, no lock needed."""
tokenizer = self._get_tokenizer()
return len(tokenizer.encode(text))
def summarize_messages(self, messages: list[dict]) -> str:
"""Summarize a list of conversation messages into a concise text.
Calls generate() internally (acquires and releases the lock).
"""
# Build a readable transcript from the messages
transcript_lines = []
for msg in messages:
role = msg.get("role", "unknown")
content = self._get_text_content(msg.get("content"))
# Include tool call info if present
if msg.get("tool_calls"):
tool_names = [
tc.get("function", tc).get("name", "?")
for tc in msg["tool_calls"]
]
content += f" [called tools: {', '.join(tool_names)}]"
if content.strip():
transcript_lines.append(f"{role}: {content.strip()}")
transcript = "\n".join(transcript_lines)
summary_instruction = [{
"role": "user",
"content": (
"Summarize the following conversation concisely. "
"Preserve key facts, decisions, tool results, and context "
"needed to continue the conversation naturally. "
"Be brief but complete.\n\n"
f"<conversation>\n{transcript}\n</conversation>"
),
}]
prompt, _ = self.build_prompt(summary_instruction, tools=None)
summary_text, _, _ = self.generate(
prompt=prompt,
images=None,
max_tokens=1024,
temperature=0.2,
)
return summary_text.strip()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Image helpers # Image helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -1029,6 +1085,22 @@ class ModelManager:
self._current_model = target self._current_model = target
return self._engine return self._engine
def get_context_length(self, model_id: str) -> int | None:
"""Get context length for a model from its cached config, without loading it."""
local_path = _resolve_local_model_path(model_id)
if local_path is None:
return None
config_file = local_path / "config.json"
if not config_file.is_file():
return None
try:
config = json.loads(config_file.read_text())
# VLMs nest under text_config
text_cfg = config.get("text_config", config)
return text_cfg.get("max_position_embeddings")
except Exception:
return None
def preload(self, model: str | None = None) -> None: def preload(self, model: str | None = None) -> None:
"""Eagerly load a model at startup.""" """Eagerly load a model at startup."""
self.get_engine(model) self.get_engine(model)

View File

@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from .engine import DEFAULT_MODEL, ModelManager from .engine import DEFAULT_MODEL, InferenceEngine, ModelManager
from .models import ( from .models import (
ChatCompletionChunk, ChatCompletionChunk,
ChatCompletionRequest, ChatCompletionRequest,
@@ -43,6 +43,9 @@ app.add_middleware(
manager: ModelManager | None = None manager: ModelManager | None = None
# Number of recent messages to always preserve when summarizing
_KEEP_RECENT = 6
def get_engine(requested_model: str | None = None): def get_engine(requested_model: str | None = None):
if manager is None: if manager is None:
@@ -54,6 +57,101 @@ def _make_id() -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}" return f"chatcmpl-{uuid.uuid4().hex[:12]}"
# ------------------------------------------------------------------
# Context window management
# ------------------------------------------------------------------
def _manage_context(
e: InferenceEngine,
messages: list[dict],
tools: list[dict] | None,
max_tokens: int,
) -> list[dict]:
"""Check if messages fit in the context window; summarize if needed.
Returns the (possibly summarized) message list. Raises HTTPException
with an OpenAI-compatible error if the conversation cannot fit.
"""
context_length = e.context_length
if context_length <= 0:
return messages # unknown context size, skip check
prompt, _ = e.build_prompt(messages, tools)
prompt_tokens = e.count_tokens(prompt)
available = context_length - max_tokens
if prompt_tokens <= available:
return messages
# --- Need to summarize ---
logger.info(
"Context window pressure: %d prompt tokens + %d max_tokens = %d "
"(limit %d). Attempting summarization.",
prompt_tokens, max_tokens, prompt_tokens + max_tokens, context_length,
)
# Split messages: system | middle (summarizable) | recent (kept)
system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"]
if len(non_system) <= _KEEP_RECENT:
_raise_context_exceeded(prompt_tokens, max_tokens, context_length)
recent = non_system[-_KEEP_RECENT:]
middle = non_system[:-_KEEP_RECENT]
# Generate summary of the middle messages
summary_text = e.summarize_messages(middle)
summary_msg = {
"role": "user",
"content": f"[Summary of earlier conversation]\n{summary_text}",
}
ack_msg = {
"role": "assistant",
"content": "Understood, I have the context from our earlier conversation.",
}
new_messages = system_msgs + [summary_msg, ack_msg] + recent
# Re-check fit
new_prompt, _ = e.build_prompt(new_messages, tools)
new_prompt_tokens = e.count_tokens(new_prompt)
if new_prompt_tokens + max_tokens > context_length:
logger.warning(
"Still over context limit after summarization: %d + %d = %d (limit %d)",
new_prompt_tokens, max_tokens, new_prompt_tokens + max_tokens, context_length,
)
_raise_context_exceeded(new_prompt_tokens, max_tokens, context_length)
logger.info(
"Summarization reduced prompt from %d to %d tokens (saved %d).",
prompt_tokens, new_prompt_tokens, prompt_tokens - new_prompt_tokens,
)
return new_messages
def _raise_context_exceeded(prompt_tokens: int, max_tokens: int, context_length: int):
"""Raise an OpenAI-compatible context_length_exceeded error."""
raise HTTPException(
status_code=400,
detail={
"error": {
"message": (
f"This model's maximum context length is {context_length} tokens. "
f"However, your messages resulted in {prompt_tokens} tokens and "
f"{max_tokens} tokens were requested for the completion "
f"({prompt_tokens + max_tokens} total). "
f"Please reduce the length of the messages or completion."
),
"type": "invalid_request_error",
"code": "context_length_exceeded",
}
},
)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Endpoints # Endpoints
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -64,7 +162,13 @@ async def list_models() -> ModelListResponse:
if manager is None: if manager is None:
raise HTTPException(status_code=503, detail="Server not initialized") raise HTTPException(status_code=503, detail="Server not initialized")
return ModelListResponse( return ModelListResponse(
data=[ModelInfo(id=model_id) for model_id in manager.available_models] data=[
ModelInfo(
id=model_id,
context_window=manager.get_context_length(model_id),
)
for model_id in manager.available_models
]
) )
@@ -78,8 +182,6 @@ async def chat_completions(request: ChatCompletionRequest):
if request.tools: if request.tools:
tools = [t.model_dump(exclude_none=True) for t in request.tools] tools = [t.model_dump(exclude_none=True) for t in request.tools]
prompt, images = e.build_prompt(messages, tools)
stop = request.stop stop = request.stop
if isinstance(stop, str): if isinstance(stop, str):
stop = [stop] stop = [stop]
@@ -88,6 +190,11 @@ async def chat_completions(request: ChatCompletionRequest):
top_p = request.top_p if request.top_p is not None else 0.9 top_p = request.top_p if request.top_p is not None else 0.9
max_tokens = request.max_tokens if request.max_tokens is not None else 4096 max_tokens = request.max_tokens if request.max_tokens is not None else 4096
# Context window management: summarize if needed, error if impossible
messages = _manage_context(e, messages, tools, max_tokens)
prompt, images = e.build_prompt(messages, tools)
if request.stream: if request.stream:
return EventSourceResponse( return EventSourceResponse(
_stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model), _stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model),

View File

@@ -137,6 +137,7 @@ class ModelInfo(BaseModel):
object: str = "model" object: str = "model"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "local" owned_by: str = "local"
context_window: int | None = None
class ModelListResponse(BaseModel): class ModelListResponse(BaseModel):