feat: hot swapping of models

This commit is contained in:
2026-03-17 11:58:24 +01:00
parent cc6e761ed4
commit ef83c24b0b
2 changed files with 96 additions and 13 deletions

View File

@@ -274,6 +274,19 @@ class InferenceEngine:
self._model_type = getattr(self.config, "model_type", "").lower()
logger.info("Model loaded successfully (type=%s).", self._model_type)
def unload(self) -> None:
"""Release model weights and caches to free memory."""
logger.info("Unloading model %s ...", self.model_path)
self._prompt_cache.clear()
self.model = None
self.processor = None
self.config = None
self._model_type = ""
# Force garbage collection + clear MLX cache to reclaim memory
import gc
gc.collect()
mx.metal.clear_cache()
@property
def is_qwen(self) -> bool:
return "qwen" in self._model_type
@@ -952,3 +965,70 @@ class InferenceEngine:
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
return clean_text, tool_calls
class ModelManager:
"""Registry of available models with on-demand loading and swapping.
Only one model is loaded in memory at a time. When a request targets a
different model, the current one is unloaded first.
"""
def __init__(self, default_model: str = DEFAULT_MODEL):
self._lock = threading.Lock()
self._engine: InferenceEngine | None = None
self._current_model: str | None = None
self._default_model = default_model
@property
def available_models(self) -> list[str]:
"""All model IDs that clients can request."""
return list(MODEL_ALIASES.values())
@property
def available_aliases(self) -> dict[str, str]:
"""Short alias -> full HuggingFace model path."""
return dict(MODEL_ALIASES)
def resolve_model(self, requested: str) -> str:
"""Resolve a model string to a full HuggingFace model path.
Accepts aliases ('gemma', 'qwen') or full paths.
"""
if requested in MODEL_ALIASES:
return MODEL_ALIASES[requested]
if requested in MODEL_ALIASES.values():
return requested
# Accept partial matches (e.g. 'gemma-3-4b-it' matches the gemma entry)
for alias, full_path in MODEL_ALIASES.items():
if requested in full_path or requested in alias:
return full_path
# Unknown model — return as-is and let loading fail if invalid
return requested
def get_engine(self, requested_model: str | None = None) -> InferenceEngine:
"""Return an engine for the requested model, swapping if necessary."""
target = self.resolve_model(requested_model) if requested_model else self._default_model
with self._lock:
if self._engine is not None and self._current_model == target:
return self._engine
# Need to swap
if self._engine is not None:
logger.info(
"Swapping model: %s -> %s", self._current_model, target
)
self._engine.unload()
self._engine = None
self._current_model = None
engine = InferenceEngine(model_path=target)
engine.load()
self._engine = engine
self._current_model = target
return self._engine
def preload(self, model: str | None = None) -> None:
"""Eagerly load a model at startup."""
self.get_engine(model)

View File

@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from .engine import DEFAULT_MODEL, InferenceEngine
from .engine import DEFAULT_MODEL, ModelManager
from .models import (
ChatCompletionChunk,
ChatCompletionRequest,
@@ -41,13 +41,13 @@ app.add_middleware(
allow_headers=["*"],
)
engine: InferenceEngine | None = None
manager: ModelManager | None = None
def get_engine() -> InferenceEngine:
if engine is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return engine
def get_engine(requested_model: str | None = None):
if manager is None:
raise HTTPException(status_code=503, detail="Server not initialized")
return manager.get_engine(requested_model)
def _make_id() -> str:
@@ -61,13 +61,16 @@ def _make_id() -> str:
@app.get("/v1/models")
async def list_models() -> ModelListResponse:
e = get_engine()
return ModelListResponse(data=[ModelInfo(id=e.model_path)])
if manager is None:
raise HTTPException(status_code=503, detail="Server not initialized")
return ModelListResponse(
data=[ModelInfo(id=model_id) for model_id in manager.available_models]
)
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
e = get_engine()
e = get_engine(request.model)
# Convert pydantic messages to dicts
messages = [m.model_dump(exclude_none=True) for m in request.messages]
@@ -144,7 +147,7 @@ async def chat_completions(request: ChatCompletionRequest):
async def _stream_response(
e: InferenceEngine,
e,
prompt: str,
images: list[str] | None,
max_tokens: int,
@@ -288,9 +291,9 @@ def main():
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
global engine
engine = InferenceEngine(model_path=args.model)
engine.load()
global manager
manager = ModelManager(default_model=args.model)
manager.preload(args.model)
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)