feat: proper support for context size
This commit is contained in:
@@ -295,6 +295,62 @@ class InferenceEngine:
|
||||
def is_gemma(self) -> bool:
|
||||
return "gemma" in self._model_type
|
||||
|
||||
@property
|
||||
def context_length(self) -> int:
|
||||
"""Max context length from the model config."""
|
||||
if self.config is None:
|
||||
return 0
|
||||
# VLMs nest the LLM config under text_config
|
||||
text_cfg = getattr(self.config, "text_config", self.config)
|
||||
return getattr(text_cfg, "max_position_embeddings", 0)
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""Count tokens in a text string. Thread-safe, no lock needed."""
|
||||
tokenizer = self._get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
def summarize_messages(self, messages: list[dict]) -> str:
|
||||
"""Summarize a list of conversation messages into a concise text.
|
||||
|
||||
Calls generate() internally (acquires and releases the lock).
|
||||
"""
|
||||
# Build a readable transcript from the messages
|
||||
transcript_lines = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = self._get_text_content(msg.get("content"))
|
||||
# Include tool call info if present
|
||||
if msg.get("tool_calls"):
|
||||
tool_names = [
|
||||
tc.get("function", tc).get("name", "?")
|
||||
for tc in msg["tool_calls"]
|
||||
]
|
||||
content += f" [called tools: {', '.join(tool_names)}]"
|
||||
if content.strip():
|
||||
transcript_lines.append(f"{role}: {content.strip()}")
|
||||
|
||||
transcript = "\n".join(transcript_lines)
|
||||
|
||||
summary_instruction = [{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Summarize the following conversation concisely. "
|
||||
"Preserve key facts, decisions, tool results, and context "
|
||||
"needed to continue the conversation naturally. "
|
||||
"Be brief but complete.\n\n"
|
||||
f"<conversation>\n{transcript}\n</conversation>"
|
||||
),
|
||||
}]
|
||||
|
||||
prompt, _ = self.build_prompt(summary_instruction, tools=None)
|
||||
summary_text, _, _ = self.generate(
|
||||
prompt=prompt,
|
||||
images=None,
|
||||
max_tokens=1024,
|
||||
temperature=0.2,
|
||||
)
|
||||
return summary_text.strip()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Image helpers
|
||||
# ------------------------------------------------------------------
|
||||
@@ -1029,6 +1085,22 @@ class ModelManager:
|
||||
self._current_model = target
|
||||
return self._engine
|
||||
|
||||
def get_context_length(self, model_id: str) -> int | None:
|
||||
"""Get context length for a model from its cached config, without loading it."""
|
||||
local_path = _resolve_local_model_path(model_id)
|
||||
if local_path is None:
|
||||
return None
|
||||
config_file = local_path / "config.json"
|
||||
if not config_file.is_file():
|
||||
return None
|
||||
try:
|
||||
config = json.loads(config_file.read_text())
|
||||
# VLMs nest under text_config
|
||||
text_cfg = config.get("text_config", config)
|
||||
return text_cfg.get("max_position_embeddings")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def preload(self, model: str | None = None) -> None:
|
||||
"""Eagerly load a model at startup."""
|
||||
self.get_engine(model)
|
||||
|
||||
@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from .engine import DEFAULT_MODEL, ModelManager
|
||||
from .engine import DEFAULT_MODEL, InferenceEngine, ModelManager
|
||||
from .models import (
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionRequest,
|
||||
@@ -43,6 +43,9 @@ app.add_middleware(
|
||||
|
||||
manager: ModelManager | None = None
|
||||
|
||||
# Number of recent messages to always preserve when summarizing
|
||||
_KEEP_RECENT = 6
|
||||
|
||||
|
||||
def get_engine(requested_model: str | None = None):
|
||||
if manager is None:
|
||||
@@ -54,6 +57,101 @@ def _make_id() -> str:
|
||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Context window management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _manage_context(
|
||||
e: InferenceEngine,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None,
|
||||
max_tokens: int,
|
||||
) -> list[dict]:
|
||||
"""Check if messages fit in the context window; summarize if needed.
|
||||
|
||||
Returns the (possibly summarized) message list. Raises HTTPException
|
||||
with an OpenAI-compatible error if the conversation cannot fit.
|
||||
"""
|
||||
context_length = e.context_length
|
||||
if context_length <= 0:
|
||||
return messages # unknown context size, skip check
|
||||
|
||||
prompt, _ = e.build_prompt(messages, tools)
|
||||
prompt_tokens = e.count_tokens(prompt)
|
||||
available = context_length - max_tokens
|
||||
|
||||
if prompt_tokens <= available:
|
||||
return messages
|
||||
|
||||
# --- Need to summarize ---
|
||||
logger.info(
|
||||
"Context window pressure: %d prompt tokens + %d max_tokens = %d "
|
||||
"(limit %d). Attempting summarization.",
|
||||
prompt_tokens, max_tokens, prompt_tokens + max_tokens, context_length,
|
||||
)
|
||||
|
||||
# Split messages: system | middle (summarizable) | recent (kept)
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
non_system = [m for m in messages if m.get("role") != "system"]
|
||||
|
||||
if len(non_system) <= _KEEP_RECENT:
|
||||
_raise_context_exceeded(prompt_tokens, max_tokens, context_length)
|
||||
|
||||
recent = non_system[-_KEEP_RECENT:]
|
||||
middle = non_system[:-_KEEP_RECENT]
|
||||
|
||||
# Generate summary of the middle messages
|
||||
summary_text = e.summarize_messages(middle)
|
||||
|
||||
summary_msg = {
|
||||
"role": "user",
|
||||
"content": f"[Summary of earlier conversation]\n{summary_text}",
|
||||
}
|
||||
ack_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Understood, I have the context from our earlier conversation.",
|
||||
}
|
||||
new_messages = system_msgs + [summary_msg, ack_msg] + recent
|
||||
|
||||
# Re-check fit
|
||||
new_prompt, _ = e.build_prompt(new_messages, tools)
|
||||
new_prompt_tokens = e.count_tokens(new_prompt)
|
||||
|
||||
if new_prompt_tokens + max_tokens > context_length:
|
||||
logger.warning(
|
||||
"Still over context limit after summarization: %d + %d = %d (limit %d)",
|
||||
new_prompt_tokens, max_tokens, new_prompt_tokens + max_tokens, context_length,
|
||||
)
|
||||
_raise_context_exceeded(new_prompt_tokens, max_tokens, context_length)
|
||||
|
||||
logger.info(
|
||||
"Summarization reduced prompt from %d to %d tokens (saved %d).",
|
||||
prompt_tokens, new_prompt_tokens, prompt_tokens - new_prompt_tokens,
|
||||
)
|
||||
return new_messages
|
||||
|
||||
|
||||
def _raise_context_exceeded(prompt_tokens: int, max_tokens: int, context_length: int):
|
||||
"""Raise an OpenAI-compatible context_length_exceeded error."""
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": {
|
||||
"message": (
|
||||
f"This model's maximum context length is {context_length} tokens. "
|
||||
f"However, your messages resulted in {prompt_tokens} tokens and "
|
||||
f"{max_tokens} tokens were requested for the completion "
|
||||
f"({prompt_tokens + max_tokens} total). "
|
||||
f"Please reduce the length of the messages or completion."
|
||||
),
|
||||
"type": "invalid_request_error",
|
||||
"code": "context_length_exceeded",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ------------------------------------------------------------------
|
||||
@@ -64,7 +162,13 @@ async def list_models() -> ModelListResponse:
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||
return ModelListResponse(
|
||||
data=[ModelInfo(id=model_id) for model_id in manager.available_models]
|
||||
data=[
|
||||
ModelInfo(
|
||||
id=model_id,
|
||||
context_window=manager.get_context_length(model_id),
|
||||
)
|
||||
for model_id in manager.available_models
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -78,8 +182,6 @@ async def chat_completions(request: ChatCompletionRequest):
|
||||
if request.tools:
|
||||
tools = [t.model_dump(exclude_none=True) for t in request.tools]
|
||||
|
||||
prompt, images = e.build_prompt(messages, tools)
|
||||
|
||||
stop = request.stop
|
||||
if isinstance(stop, str):
|
||||
stop = [stop]
|
||||
@@ -88,6 +190,11 @@ async def chat_completions(request: ChatCompletionRequest):
|
||||
top_p = request.top_p if request.top_p is not None else 0.9
|
||||
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
|
||||
|
||||
# Context window management: summarize if needed, error if impossible
|
||||
messages = _manage_context(e, messages, tools, max_tokens)
|
||||
|
||||
prompt, images = e.build_prompt(messages, tools)
|
||||
|
||||
if request.stream:
|
||||
return EventSourceResponse(
|
||||
_stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model),
|
||||
|
||||
@@ -137,6 +137,7 @@ class ModelInfo(BaseModel):
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "local"
|
||||
context_window: int | None = None
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user