feat: proper support for context size

This commit is contained in:
2026-03-17 12:34:11 +01:00
parent 540b187593
commit cc4f937d9a
5 changed files with 201 additions and 11 deletions

View File

@@ -295,6 +295,62 @@ class InferenceEngine:
def is_gemma(self) -> bool:
return "gemma" in self._model_type
@property
def context_length(self) -> int:
"""Max context length from the model config."""
if self.config is None:
return 0
# VLMs nest the LLM config under text_config
text_cfg = getattr(self.config, "text_config", self.config)
return getattr(text_cfg, "max_position_embeddings", 0)
def count_tokens(self, text: str) -> int:
"""Count tokens in a text string. Thread-safe, no lock needed."""
tokenizer = self._get_tokenizer()
return len(tokenizer.encode(text))
def summarize_messages(self, messages: list[dict]) -> str:
"""Summarize a list of conversation messages into a concise text.
Calls generate() internally (acquires and releases the lock).
"""
# Build a readable transcript from the messages
transcript_lines = []
for msg in messages:
role = msg.get("role", "unknown")
content = self._get_text_content(msg.get("content"))
# Include tool call info if present
if msg.get("tool_calls"):
tool_names = [
tc.get("function", tc).get("name", "?")
for tc in msg["tool_calls"]
]
content += f" [called tools: {', '.join(tool_names)}]"
if content.strip():
transcript_lines.append(f"{role}: {content.strip()}")
transcript = "\n".join(transcript_lines)
summary_instruction = [{
"role": "user",
"content": (
"Summarize the following conversation concisely. "
"Preserve key facts, decisions, tool results, and context "
"needed to continue the conversation naturally. "
"Be brief but complete.\n\n"
f"<conversation>\n{transcript}\n</conversation>"
),
}]
prompt, _ = self.build_prompt(summary_instruction, tools=None)
summary_text, _, _ = self.generate(
prompt=prompt,
images=None,
max_tokens=1024,
temperature=0.2,
)
return summary_text.strip()
# ------------------------------------------------------------------
# Image helpers
# ------------------------------------------------------------------
@@ -1029,6 +1085,22 @@ class ModelManager:
self._current_model = target
return self._engine
def get_context_length(self, model_id: str) -> int | None:
"""Get context length for a model from its cached config, without loading it."""
local_path = _resolve_local_model_path(model_id)
if local_path is None:
return None
config_file = local_path / "config.json"
if not config_file.is_file():
return None
try:
config = json.loads(config_file.read_text())
# VLMs nest under text_config
text_cfg = config.get("text_config", config)
return text_cfg.get("max_position_embeddings")
except Exception:
return None
def preload(self, model: str | None = None) -> None:
"""Eagerly load a model at startup."""
self.get_engine(model)