feat: hot swapping of models
This commit is contained in:
@@ -274,6 +274,19 @@ class InferenceEngine:
|
||||
self._model_type = getattr(self.config, "model_type", "").lower()
|
||||
logger.info("Model loaded successfully (type=%s).", self._model_type)
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Release model weights and caches to free memory."""
|
||||
logger.info("Unloading model %s ...", self.model_path)
|
||||
self._prompt_cache.clear()
|
||||
self.model = None
|
||||
self.processor = None
|
||||
self.config = None
|
||||
self._model_type = ""
|
||||
# Force garbage collection + clear MLX cache to reclaim memory
|
||||
import gc
|
||||
gc.collect()
|
||||
mx.metal.clear_cache()
|
||||
|
||||
@property
|
||||
def is_qwen(self) -> bool:
|
||||
return "qwen" in self._model_type
|
||||
@@ -952,3 +965,70 @@ class InferenceEngine:
|
||||
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
|
||||
|
||||
return clean_text, tool_calls
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""Registry of available models with on-demand loading and swapping.
|
||||
|
||||
Only one model is loaded in memory at a time. When a request targets a
|
||||
different model, the current one is unloaded first.
|
||||
"""
|
||||
|
||||
def __init__(self, default_model: str = DEFAULT_MODEL):
|
||||
self._lock = threading.Lock()
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._current_model: str | None = None
|
||||
self._default_model = default_model
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
"""All model IDs that clients can request."""
|
||||
return list(MODEL_ALIASES.values())
|
||||
|
||||
@property
|
||||
def available_aliases(self) -> dict[str, str]:
|
||||
"""Short alias -> full HuggingFace model path."""
|
||||
return dict(MODEL_ALIASES)
|
||||
|
||||
def resolve_model(self, requested: str) -> str:
|
||||
"""Resolve a model string to a full HuggingFace model path.
|
||||
|
||||
Accepts aliases ('gemma', 'qwen') or full paths.
|
||||
"""
|
||||
if requested in MODEL_ALIASES:
|
||||
return MODEL_ALIASES[requested]
|
||||
if requested in MODEL_ALIASES.values():
|
||||
return requested
|
||||
# Accept partial matches (e.g. 'gemma-3-4b-it' matches the gemma entry)
|
||||
for alias, full_path in MODEL_ALIASES.items():
|
||||
if requested in full_path or requested in alias:
|
||||
return full_path
|
||||
# Unknown model — return as-is and let loading fail if invalid
|
||||
return requested
|
||||
|
||||
def get_engine(self, requested_model: str | None = None) -> InferenceEngine:
|
||||
"""Return an engine for the requested model, swapping if necessary."""
|
||||
target = self.resolve_model(requested_model) if requested_model else self._default_model
|
||||
|
||||
with self._lock:
|
||||
if self._engine is not None and self._current_model == target:
|
||||
return self._engine
|
||||
|
||||
# Need to swap
|
||||
if self._engine is not None:
|
||||
logger.info(
|
||||
"Swapping model: %s -> %s", self._current_model, target
|
||||
)
|
||||
self._engine.unload()
|
||||
self._engine = None
|
||||
self._current_model = None
|
||||
|
||||
engine = InferenceEngine(model_path=target)
|
||||
engine.load()
|
||||
self._engine = engine
|
||||
self._current_model = target
|
||||
return self._engine
|
||||
|
||||
def preload(self, model: str | None = None) -> None:
|
||||
"""Eagerly load a model at startup."""
|
||||
self.get_engine(model)
|
||||
|
||||
Reference in New Issue
Block a user