feat: proper support for context size

This commit is contained in:
2026-03-17 12:34:11 +01:00
parent 540b187593
commit cc4f937d9a
5 changed files with 201 additions and 11 deletions

View File

@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from .engine import DEFAULT_MODEL, ModelManager
from .engine import DEFAULT_MODEL, InferenceEngine, ModelManager
from .models import (
ChatCompletionChunk,
ChatCompletionRequest,
@@ -43,6 +43,9 @@ app.add_middleware(
manager: ModelManager | None = None
# Number of recent messages to always preserve when summarizing
_KEEP_RECENT = 6
def get_engine(requested_model: str | None = None):
if manager is None:
@@ -54,6 +57,101 @@ def _make_id() -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
# ------------------------------------------------------------------
# Context window management
# ------------------------------------------------------------------
def _manage_context(
    e: InferenceEngine,
    messages: list[dict],
    tools: list[dict] | None,
    max_tokens: int,
) -> list[dict]:
    """Fit a conversation into the model's context window, summarizing if needed.

    The prompt is rendered with ``e.build_prompt`` and measured with
    ``e.count_tokens``. When the prompt plus the ``max_tokens`` completion
    budget fits within ``e.context_length`` the input list is returned
    untouched. Otherwise every non-system message except the last
    ``_KEEP_RECENT`` is replaced by a summary obtained from
    ``e.summarize_messages`` and the fit is checked a second time.

    Args:
        e: Engine providing ``context_length``, ``build_prompt``,
            ``count_tokens`` and ``summarize_messages``.
        messages: Chat messages as dicts; only the ``"role"`` key is
            inspected here.
        tools: Tool definitions forwarded to ``e.build_prompt``, or ``None``.
        max_tokens: Number of tokens reserved for the completion.

    Returns:
        ``messages`` itself when it already fits or when the context length
        is unknown (``<= 0``); otherwise a new list made of the system
        messages, a summary/acknowledgement pair, and the most recent
        non-system messages.

    Raises:
        HTTPException: Raised through ``_raise_context_exceeded`` when the
            completion budget leaves no room for a prompt, when there are
            too few messages to summarize, or when the summarized
            conversation still does not fit.
    """
    context_length = e.context_length
    if context_length <= 0:
        return messages  # unknown context size, skip check

    prompt, _ = e.build_prompt(messages, tools)
    prompt_tokens = e.count_tokens(prompt)

    available = context_length - max_tokens
    if prompt_tokens <= available:
        return messages

    if available <= 0:
        # The completion budget alone fills (or exceeds) the window, so no
        # non-empty prompt can ever fit. Fail immediately instead of paying
        # for a summarization pass whose result would be rejected anyway,
        # and report the size of the prompt the caller actually sent.
        logger.warning(
            "max_tokens=%d leaves no room for a prompt (context limit %d).",
            max_tokens, context_length,
        )
        _raise_context_exceeded(prompt_tokens, max_tokens, context_length)

    # --- Need to summarize ---
    logger.info(
        "Context window pressure: %d prompt tokens + %d max_tokens = %d "
        "(limit %d). Attempting summarization.",
        prompt_tokens, max_tokens, prompt_tokens + max_tokens, context_length,
    )

    # Split messages: system | middle (summarizable) | recent (kept).
    # System messages are collected from anywhere in the list and end up
    # first in the rebuilt conversation.
    system_msgs = [m for m in messages if m.get("role") == "system"]
    non_system = [m for m in messages if m.get("role") != "system"]

    if len(non_system) <= _KEEP_RECENT:
        # Nothing older than the protected tail exists, so there is nothing
        # to summarize.
        _raise_context_exceeded(prompt_tokens, max_tokens, context_length)

    # NOTE(review): the cut is purely positional, so `recent` may start with
    # a message whose counterpart was summarized away (e.g. a tool result
    # without its originating call) -- confirm the chat template tolerates it.
    recent = non_system[-_KEEP_RECENT:]
    middle = non_system[:-_KEEP_RECENT]

    # Generate summary of the middle messages
    summary_text = e.summarize_messages(middle)

    # The summary is injected as a user turn followed by a canned assistant
    # acknowledgement.
    summary_msg = {
        "role": "user",
        "content": f"[Summary of earlier conversation]\n{summary_text}",
    }
    ack_msg = {
        "role": "assistant",
        "content": "Understood, I have the context from our earlier conversation.",
    }
    new_messages = system_msgs + [summary_msg, ack_msg] + recent

    # Re-check fit
    new_prompt, _ = e.build_prompt(new_messages, tools)
    new_prompt_tokens = e.count_tokens(new_prompt)
    if new_prompt_tokens + max_tokens > context_length:
        logger.warning(
            "Still over context limit after summarization: %d + %d = %d (limit %d)",
            new_prompt_tokens, max_tokens, new_prompt_tokens + max_tokens, context_length,
        )
        _raise_context_exceeded(new_prompt_tokens, max_tokens, context_length)

    logger.info(
        "Summarization reduced prompt from %d to %d tokens (saved %d).",
        prompt_tokens, new_prompt_tokens, prompt_tokens - new_prompt_tokens,
    )
    return new_messages
def _raise_context_exceeded(prompt_tokens: int, max_tokens: int, context_length: int):
    """Abort the request with a 400 ``context_length_exceeded`` error.

    This helper never returns. The exception's ``detail`` carries an
    ``"error"`` object with ``message``, ``type`` and ``code`` fields.

    Args:
        prompt_tokens: Token count of the rendered prompt.
        max_tokens: Tokens requested for the completion.
        context_length: The model's context window, in tokens.

    Raises:
        HTTPException: Always, with status code 400.
    """
    explanation = (
        f"This model's maximum context length is {context_length} tokens. "
        f"However, your messages resulted in {prompt_tokens} tokens and "
        f"{max_tokens} tokens were requested for the completion "
        f"({prompt_tokens + max_tokens} total). "
        f"Please reduce the length of the messages or completion."
    )
    error_body = {
        "message": explanation,
        "type": "invalid_request_error",
        "code": "context_length_exceeded",
    }
    # NOTE(review): unless a custom exception handler is registered elsewhere,
    # FastAPI serializes `detail` under a top-level "detail" key, so clients
    # would see {"detail": {"error": ...}} -- confirm this is the intended shape.
    raise HTTPException(status_code=400, detail={"error": error_body})
# ------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------
@@ -64,7 +162,13 @@ async def list_models() -> ModelListResponse:
if manager is None:
raise HTTPException(status_code=503, detail="Server not initialized")
return ModelListResponse(
data=[ModelInfo(id=model_id) for model_id in manager.available_models]
data=[
ModelInfo(
id=model_id,
context_window=manager.get_context_length(model_id),
)
for model_id in manager.available_models
]
)
@@ -78,8 +182,6 @@ async def chat_completions(request: ChatCompletionRequest):
if request.tools:
tools = [t.model_dump(exclude_none=True) for t in request.tools]
prompt, images = e.build_prompt(messages, tools)
stop = request.stop
if isinstance(stop, str):
stop = [stop]
@@ -88,6 +190,11 @@ async def chat_completions(request: ChatCompletionRequest):
top_p = request.top_p if request.top_p is not None else 0.9
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
# Context window management: summarize if needed, error if impossible
messages = _manage_context(e, messages, tools, max_tokens)
prompt, images = e.build_prompt(messages, tools)
if request.stream:
return EventSourceResponse(
_stream_response(e, prompt, images, max_tokens, temperature, top_p, stop, tools, request.model),