feat: hot swapping of models

This commit is contained in:
2026-03-17 11:58:24 +01:00
parent cc6e761ed4
commit ef83c24b0b
2 changed files with 96 additions and 13 deletions

View File

@@ -274,6 +274,19 @@ class InferenceEngine:
self._model_type = getattr(self.config, "model_type", "").lower()
logger.info("Model loaded successfully (type=%s).", self._model_type)
def unload(self) -> None:
"""Release model weights and caches to free memory."""
logger.info("Unloading model %s ...", self.model_path)
self._prompt_cache.clear()
self.model = None
self.processor = None
self.config = None
self._model_type = ""
# Force garbage collection + clear MLX cache to reclaim memory
import gc
gc.collect()
mx.metal.clear_cache()
@property
def is_qwen(self) -> bool:
return "qwen" in self._model_type
@@ -952,3 +965,70 @@ class InferenceEngine:
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
return clean_text, tool_calls
class ModelManager:
"""Registry of available models with on-demand loading and swapping.
Only one model is loaded in memory at a time. When a request targets a
different model, the current one is unloaded first.
"""
def __init__(self, default_model: str = DEFAULT_MODEL):
self._lock = threading.Lock()
self._engine: InferenceEngine | None = None
self._current_model: str | None = None
self._default_model = default_model
@property
def available_models(self) -> list[str]:
"""All model IDs that clients can request."""
return list(MODEL_ALIASES.values())
@property
def available_aliases(self) -> dict[str, str]:
"""Short alias -> full HuggingFace model path."""
return dict(MODEL_ALIASES)
def resolve_model(self, requested: str) -> str:
"""Resolve a model string to a full HuggingFace model path.
Accepts aliases ('gemma', 'qwen') or full paths.
"""
if requested in MODEL_ALIASES:
return MODEL_ALIASES[requested]
if requested in MODEL_ALIASES.values():
return requested
# Accept partial matches (e.g. 'gemma-3-4b-it' matches the gemma entry)
for alias, full_path in MODEL_ALIASES.items():
if requested in full_path or requested in alias:
return full_path
# Unknown model — return as-is and let loading fail if invalid
return requested
def get_engine(self, requested_model: str | None = None) -> InferenceEngine:
"""Return an engine for the requested model, swapping if necessary."""
target = self.resolve_model(requested_model) if requested_model else self._default_model
with self._lock:
if self._engine is not None and self._current_model == target:
return self._engine
# Need to swap
if self._engine is not None:
logger.info(
"Swapping model: %s -> %s", self._current_model, target
)
self._engine.unload()
self._engine = None
self._current_model = None
engine = InferenceEngine(model_path=target)
engine.load()
self._engine = engine
self._current_model = target
return self._engine
def preload(self, model: str | None = None) -> None:
"""Eagerly load a model at startup."""
self.get_engine(model)