feat: proper support for context size
This commit is contained in:
@@ -295,6 +295,62 @@ class InferenceEngine:
|
||||
def is_gemma(self) -> bool:
    """Report whether the model type string contains ``"gemma"``."""
    model_type = self._model_type
    return "gemma" in model_type
|
||||
|
||||
@property
def context_length(self) -> int:
    """Max context length from the model config.

    Returns:
        The ``max_position_embeddings`` value of the language-model config,
        or ``0`` when no config is loaded or the value is missing or null.
    """
    cfg = self.config
    if cfg is None:
        return 0

    # VLMs nest the LLM config under text_config. A null text_config is
    # treated the same as an absent one so the top-level value is still used.
    text_cfg = getattr(cfg, "text_config", None)
    if text_cfg is None:
        text_cfg = cfg

    # The attribute may exist but be None; honour the ``-> int`` contract.
    limit = getattr(text_cfg, "max_position_embeddings", None)
    return limit if limit is not None else 0
|
||||
|
||||
def count_tokens(self, text: str) -> int:
    """Return the number of tokens the tokenizer produces for *text*.

    The count is the length of whatever ``encode`` returns. The original
    author notes this is thread-safe and needs no lock.
    """
    token_ids = self._get_tokenizer().encode(text)
    return len(token_ids)
|
||||
|
||||
def summarize_messages(self, messages: list[dict]) -> str:
    """Summarize a list of conversation messages into a concise text.

    Each message becomes one ``role: text`` transcript line; names of any
    tool calls are appended as ``[called tools: ...]``. Messages with no
    text and no tool calls are skipped. The transcript is wrapped in a
    summarization instruction, turned into a prompt via build_prompt(),
    and passed to generate().

    Per the original note, generate() acquires and releases the lock
    itself, so this method does no locking of its own.

    Args:
        messages: Chat messages as dicts with optional ``role``,
            ``content`` and ``tool_calls`` keys.

    Returns:
        The stripped summary text, or ``""`` when there is nothing to
        summarize (no generation is run in that case).
    """
    # Build a readable transcript from the messages
    transcript_lines = []
    for msg in messages:
        role = msg.get("role", "unknown")
        # NOTE(review): _get_text_content is assumed to return a str; a None
        # result is treated as empty text instead of crashing on ``+=``.
        content = self._get_text_content(msg.get("content")) or ""

        # Include tool call info if present
        tool_calls = msg.get("tool_calls")
        if tool_calls:
            tool_names = []
            for tc in tool_calls:
                # The name lives under "function" or directly on the call;
                # a null "function" or a non-dict entry falls back to "?".
                fn = (tc.get("function") or tc) if isinstance(tc, dict) else None
                name = fn.get("name") if isinstance(fn, dict) else None
                tool_names.append(str(name) if name else "?")
            content += f" [called tools: {', '.join(tool_names)}]"

        content = content.strip()
        if content:
            transcript_lines.append(f"{role}: {content}")

    # Nothing to summarize: skip the generation call entirely.
    if not transcript_lines:
        return ""

    transcript = "\n".join(transcript_lines)

    summary_instruction = [{
        "role": "user",
        "content": (
            "Summarize the following conversation concisely. "
            "Preserve key facts, decisions, tool results, and context "
            "needed to continue the conversation naturally. "
            "Be brief but complete.\n\n"
            f"<conversation>\n{transcript}\n</conversation>"
        ),
    }]

    prompt, _ = self.build_prompt(summary_instruction, tools=None)
    summary_text, _, _ = self.generate(
        prompt=prompt,
        images=None,
        max_tokens=1024,
        temperature=0.2,
    )
    return summary_text.strip()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Image helpers
|
||||
# ------------------------------------------------------------------
|
||||
@@ -1029,6 +1085,22 @@ class ModelManager:
|
||||
self._current_model = target
|
||||
return self._engine
|
||||
|
||||
def get_context_length(self, model_id: str) -> int | None:
    """Get context length for a model from its cached config, without loading it.

    Reads ``config.json`` from the directory returned by
    ``_resolve_local_model_path(model_id)`` and looks up
    ``max_position_embeddings``, preferring a nested ``text_config``.

    Args:
        model_id: Identifier passed to ``_resolve_local_model_path``.

    Returns:
        The context length as an int, or ``None`` when the model has no
        local path, the config file is missing, unreadable or malformed,
        or it holds no integer ``max_position_embeddings``.
    """
    local_path = _resolve_local_model_path(model_id)
    if local_path is None:
        return None

    config_file = local_path / "config.json"
    if not config_file.is_file():
        return None

    # Only I/O and parse failures are expected here. JSONDecodeError and
    # UnicodeDecodeError are both ValueError subclasses.
    try:
        config = json.loads(config_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        return None
    if not isinstance(config, dict):
        return None

    # VLMs nest under text_config; a null or non-object entry is ignored
    # and the top-level config is used instead.
    text_cfg = config.get("text_config")
    if not isinstance(text_cfg, dict):
        text_cfg = config

    limit = text_cfg.get("max_position_embeddings")
    # bool is an int subclass; reject it along with any other non-int value.
    if isinstance(limit, bool) or not isinstance(limit, int):
        return None
    return limit
|
||||
|
||||
def preload(self, model: str | None = None) -> None:
|
||||
"""Eagerly load a model at startup."""
|
||||
self.get_engine(model)
|
||||
|
||||
Reference in New Issue
Block a user