diff --git a/CLAUDE.md b/CLAUDE.md index 65391d4..245d90a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -38,7 +38,7 @@ python -m mlx_server.main --model mlx-community/Qwen3-VL-4B-Instruct-4bit --port - Model-specific prompt formatting: Gemma converts system→user/assistant pairs and uses `tool_code` blocks; Qwen3 uses native system role and `<tool_call>` XML tags - Offline-first: if the model is already cached locally (~/.cache/huggingface/hub/), the server resolves the local snapshot path directly — no network requests are made (HEAD checks, update checks, etc.) - Thread lock on generation (single-request-at-a-time) — MLX models aren't safe for concurrent generation -- 128k context window supported via the model's native capabilities +- Context window size is read from each model's config at load time (Gemma 3 4B: 128k, Qwen3-VL 4B: 256k) ## Dependencies diff --git a/README.md b/README.md index ddb2add..de13332 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,10 @@ OpenAI-compatible API server for running local LLMs on Apple Silicon via [MLX](h ## Supported Models -| Alias | Model | Capabilities | -|-------|-------|-------------| -| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | Vision, tool use (`tool_code` blocks) | -| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | Vision, tool use (`<tool_call>` tags) | +| Alias | Model | Context | Capabilities | +|-------|-------|---------|-------------| +| `gemma` | `mlx-community/gemma-3-4b-it-4bit` | 128k | Vision, tool use (`tool_code` blocks) | +| `qwen` | `mlx-community/Qwen3-VL-4B-Instruct-4bit` | 256k | Vision, tool use (`<tool_call>` tags) | ## Quick Start @@ -30,7 +30,7 @@ The server starts at `http://127.0.0.1:1234`. 
Standard OpenAI-compatible endpoints: -- `GET /v1/models` — lists all available models +- `GET /v1/models` — lists all available models with `context_window` sizes - `POST /v1/chat/completions` — chat completions (streaming and non-streaming) - `GET /health` — health check @@ -67,6 +67,16 @@ Pass images as base64 data URIs or URLs in the `image_url` content part: } ``` +### Context Window Management + +Each model's context window is read from its HuggingFace config (`max_position_embeddings`) and reported in `/v1/models` via the `context_window` field. Clients can use this to manage conversation length proactively. + +If a request exceeds the context window, the server: + +1. Automatically summarizes older messages (keeping system messages and the last 6 messages intact) +2. Retries with the compressed conversation +3. Returns an OpenAI-compatible `context_length_exceeded` error if it still doesn't fit + ### Tool Use Pass tools in the `tools` field (OpenAI format). The server handles model-specific formatting and parses tool calls from the output automatically. 
@@ -94,4 +104,4 @@ mlx_server/ - Offline-first: if the model is cached locally (`~/.cache/huggingface/hub/`), no network requests are made - Thread lock on generation — MLX models aren't safe for concurrent generation - KV prefix caching for multi-turn conversations -- 128k context window via native model capabilities +- Context window read from each model's config (Gemma 3 4B: 128k, Qwen3-VL 4B: 256k) with automatic summarization fallback diff --git a/mlx_server/engine.py b/mlx_server/engine.py index a635d63..2048298 100644 --- a/mlx_server/engine.py +++ b/mlx_server/engine.py @@ -295,6 +295,62 @@ class InferenceEngine: def is_gemma(self) -> bool: return "gemma" in self._model_type + @property + def context_length(self) -> int: + """Max context length from the model config.""" + if self.config is None: + return 0 + # VLMs nest the LLM config under text_config + text_cfg = getattr(self.config, "text_config", self.config) + return getattr(text_cfg, "max_position_embeddings", 0) + + def count_tokens(self, text: str) -> int: + """Count tokens in a text string. Thread-safe, no lock needed.""" + tokenizer = self._get_tokenizer() + return len(tokenizer.encode(text)) + + def summarize_messages(self, messages: list[dict]) -> str: + """Summarize a list of conversation messages into a concise text. + + Calls generate() internally (acquires and releases the lock). 
+ """ + # Build a readable transcript from the messages + transcript_lines = [] + for msg in messages: + role = msg.get("role", "unknown") + content = self._get_text_content(msg.get("content")) + # Include tool call info if present + if msg.get("tool_calls"): + tool_names = [ + tc.get("function", tc).get("name", "?") + for tc in msg["tool_calls"] + ] + content += f" [called tools: {', '.join(tool_names)}]" + if content.strip(): + transcript_lines.append(f"{role}: {content.strip()}") + + transcript = "\n".join(transcript_lines) + + summary_instruction = [{ + "role": "user", + "content": ( + "Summarize the following conversation concisely. " + "Preserve key facts, decisions, tool results, and context " + "needed to continue the conversation naturally. " + "Be brief but complete.\n\n" + f"\n{transcript}\n" + ), + }] + + prompt, _ = self.build_prompt(summary_instruction, tools=None) + summary_text, _, _ = self.generate( + prompt=prompt, + images=None, + max_tokens=1024, + temperature=0.2, + ) + return summary_text.strip() + # ------------------------------------------------------------------ # Image helpers # ------------------------------------------------------------------ @@ -1029,6 +1085,22 @@ class ModelManager: self._current_model = target return self._engine + def get_context_length(self, model_id: str) -> int | None: + """Get context length for a model from its cached config, without loading it.""" + local_path = _resolve_local_model_path(model_id) + if local_path is None: + return None + config_file = local_path / "config.json" + if not config_file.is_file(): + return None + try: + config = json.loads(config_file.read_text()) + # VLMs nest under text_config + text_cfg = config.get("text_config", config) + return text_cfg.get("max_position_embeddings") + except Exception: + return None + def preload(self, model: str | None = None) -> None: """Eagerly load a model at startup.""" self.get_engine(model) diff --git a/mlx_server/main.py b/mlx_server/main.py index 
# Number of recent messages to always preserve when summarizing
_KEEP_RECENT = 6


# ------------------------------------------------------------------
# Context window management
# ------------------------------------------------------------------


def _manage_context(
    e: InferenceEngine,
    messages: list[dict],
    tools: list[dict] | None,
    max_tokens: int,
) -> list[dict]:
    """Check if messages fit in the context window; summarize if needed.

    The prompt is built and tokenized with the engine.  If
    ``prompt_tokens + max_tokens`` exceeds ``e.context_length``, the older
    non-system messages are replaced by a model-generated summary while
    system messages and the most recent ``_KEEP_RECENT`` messages are kept.

    Args:
        e: Engine providing ``context_length``, ``build_prompt``,
            ``count_tokens`` and ``summarize_messages``.
        messages: Chat messages as dicts.
        tools: Tool definitions passed through to ``build_prompt``.
        max_tokens: Tokens reserved for the completion.

    Returns:
        *messages* unchanged when it already fits (or the context size is
        unknown), otherwise a new, shorter message list.

    Raises:
        HTTPException: 400 ``context_length_exceeded`` when the conversation
            cannot be made to fit.
    """
    context_length = e.context_length
    if context_length <= 0:
        return messages  # unknown context size, skip check

    prompt, _ = e.build_prompt(messages, tools)
    prompt_tokens = e.count_tokens(prompt)
    if prompt_tokens + max_tokens <= context_length:
        return messages

    if max_tokens >= context_length:
        # The completion budget alone fills the window; no summary can help,
        # so fail before spending a generation on one.
        _raise_context_exceeded(prompt_tokens, max_tokens, context_length)

    # --- Need to summarize ---
    logger.info(
        "Context window pressure: %d prompt tokens + %d max_tokens = %d "
        "(limit %d). Attempting summarization.",
        prompt_tokens, max_tokens, prompt_tokens + max_tokens, context_length,
    )

    # Split messages: system | middle (summarizable) | recent (kept)
    system_msgs = [m for m in messages if m.get("role") == "system"]
    non_system = [m for m in messages if m.get("role") != "system"]

    split = len(non_system) - _KEEP_RECENT
    # Do not start the kept tail on a tool result: walk back so the assistant
    # turn that issued the call stays with its result instead of being
    # folded into the summary.
    while 0 < split < len(non_system) and non_system[split].get("role") == "tool":
        split -= 1

    if split <= 0:
        # Nothing old enough to summarize.
        _raise_context_exceeded(prompt_tokens, max_tokens, context_length)

    middle = non_system[:split]
    recent = non_system[split:]

    # Generate summary of the middle messages
    summary_text = e.summarize_messages(middle)

    summary_msg = {
        "role": "user",
        "content": f"[Summary of earlier conversation]\n{summary_text}",
    }
    ack_msg = {
        "role": "assistant",
        "content": "Understood, I have the context from our earlier conversation.",
    }
    new_messages = system_msgs + [summary_msg]
    # The acknowledgement exists only to keep user/assistant turns
    # alternating; if the kept tail already opens with an assistant turn,
    # adding it would create two consecutive assistant messages.
    if not recent or recent[0].get("role") != "assistant":
        new_messages.append(ack_msg)
    new_messages += recent

    # Re-check fit
    new_prompt, _ = e.build_prompt(new_messages, tools)
    new_prompt_tokens = e.count_tokens(new_prompt)

    if new_prompt_tokens + max_tokens > context_length:
        logger.warning(
            "Still over context limit after summarization: %d + %d = %d (limit %d)",
            new_prompt_tokens, max_tokens, new_prompt_tokens + max_tokens, context_length,
        )
        _raise_context_exceeded(new_prompt_tokens, max_tokens, context_length)

    logger.info(
        "Summarization reduced prompt from %d to %d tokens (saved %d).",
        prompt_tokens, new_prompt_tokens, prompt_tokens - new_prompt_tokens,
    )
    return new_messages


def _raise_context_exceeded(prompt_tokens: int, max_tokens: int, context_length: int):
    """Raise an OpenAI-style ``context_length_exceeded`` error.

    Args:
        prompt_tokens: Tokens in the prompt that did not fit.
        max_tokens: Tokens requested for the completion.
        context_length: The model's context limit.

    Raises:
        HTTPException: Always; status 400 with an ``error`` object carrying
            ``message``, ``type`` and ``code``.
    """
    # NOTE(review): FastAPI's default handler serializes HTTPException as
    # {"detail": <detail>}, so this object arrives nested under "detail"
    # unless the app registers a custom handler (not visible here) - confirm
    # clients expecting a top-level "error" key are handled.
    raise HTTPException(
        status_code=400,
        detail={
            "error": {
                "message": (
                    f"This model's maximum context length is {context_length} tokens. "
                    f"However, your messages resulted in {prompt_tokens} tokens and "
                    f"{max_tokens} tokens were requested for the completion "
                    f"({prompt_tokens + max_tokens} total). "
                    f"Please reduce the length of the messages or completion."
                ),
                "type": "invalid_request_error",
                "code": "context_length_exceeded",
            }
        },
    )