feat: hot swapping of models
This commit is contained in:
@@ -274,6 +274,19 @@ class InferenceEngine:
|
||||
self._model_type = getattr(self.config, "model_type", "").lower()
|
||||
logger.info("Model loaded successfully (type=%s).", self._model_type)
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Release model weights and caches to free memory."""
|
||||
logger.info("Unloading model %s ...", self.model_path)
|
||||
self._prompt_cache.clear()
|
||||
self.model = None
|
||||
self.processor = None
|
||||
self.config = None
|
||||
self._model_type = ""
|
||||
# Force garbage collection + clear MLX cache to reclaim memory
|
||||
import gc
|
||||
gc.collect()
|
||||
mx.metal.clear_cache()
|
||||
|
||||
@property
|
||||
def is_qwen(self) -> bool:
|
||||
return "qwen" in self._model_type
|
||||
@@ -952,3 +965,70 @@ class InferenceEngine:
|
||||
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
|
||||
|
||||
return clean_text, tool_calls
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""Registry of available models with on-demand loading and swapping.
|
||||
|
||||
Only one model is loaded in memory at a time. When a request targets a
|
||||
different model, the current one is unloaded first.
|
||||
"""
|
||||
|
||||
def __init__(self, default_model: str = DEFAULT_MODEL):
|
||||
self._lock = threading.Lock()
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._current_model: str | None = None
|
||||
self._default_model = default_model
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
"""All model IDs that clients can request."""
|
||||
return list(MODEL_ALIASES.values())
|
||||
|
||||
@property
|
||||
def available_aliases(self) -> dict[str, str]:
|
||||
"""Short alias -> full HuggingFace model path."""
|
||||
return dict(MODEL_ALIASES)
|
||||
|
||||
def resolve_model(self, requested: str) -> str:
|
||||
"""Resolve a model string to a full HuggingFace model path.
|
||||
|
||||
Accepts aliases ('gemma', 'qwen') or full paths.
|
||||
"""
|
||||
if requested in MODEL_ALIASES:
|
||||
return MODEL_ALIASES[requested]
|
||||
if requested in MODEL_ALIASES.values():
|
||||
return requested
|
||||
# Accept partial matches (e.g. 'gemma-3-4b-it' matches the gemma entry)
|
||||
for alias, full_path in MODEL_ALIASES.items():
|
||||
if requested in full_path or requested in alias:
|
||||
return full_path
|
||||
# Unknown model — return as-is and let loading fail if invalid
|
||||
return requested
|
||||
|
||||
def get_engine(self, requested_model: str | None = None) -> InferenceEngine:
|
||||
"""Return an engine for the requested model, swapping if necessary."""
|
||||
target = self.resolve_model(requested_model) if requested_model else self._default_model
|
||||
|
||||
with self._lock:
|
||||
if self._engine is not None and self._current_model == target:
|
||||
return self._engine
|
||||
|
||||
# Need to swap
|
||||
if self._engine is not None:
|
||||
logger.info(
|
||||
"Swapping model: %s -> %s", self._current_model, target
|
||||
)
|
||||
self._engine.unload()
|
||||
self._engine = None
|
||||
self._current_model = None
|
||||
|
||||
engine = InferenceEngine(model_path=target)
|
||||
engine.load()
|
||||
self._engine = engine
|
||||
self._current_model = target
|
||||
return self._engine
|
||||
|
||||
def preload(self, model: str | None = None) -> None:
|
||||
"""Eagerly load a model at startup."""
|
||||
self.get_engine(model)
|
||||
|
||||
@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from .engine import DEFAULT_MODEL, InferenceEngine
|
||||
from .engine import DEFAULT_MODEL, ModelManager
|
||||
from .models import (
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionRequest,
|
||||
@@ -41,13 +41,13 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
engine: InferenceEngine | None = None
|
||||
manager: ModelManager | None = None
|
||||
|
||||
|
||||
def get_engine() -> InferenceEngine:
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
return engine
|
||||
def get_engine(requested_model: str | None = None):
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||
return manager.get_engine(requested_model)
|
||||
|
||||
|
||||
def _make_id() -> str:
|
||||
@@ -61,13 +61,16 @@ def _make_id() -> str:
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models() -> ModelListResponse:
|
||||
e = get_engine()
|
||||
return ModelListResponse(data=[ModelInfo(id=e.model_path)])
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||
return ModelListResponse(
|
||||
data=[ModelInfo(id=model_id) for model_id in manager.available_models]
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest):
|
||||
e = get_engine()
|
||||
e = get_engine(request.model)
|
||||
|
||||
# Convert pydantic messages to dicts
|
||||
messages = [m.model_dump(exclude_none=True) for m in request.messages]
|
||||
@@ -144,7 +147,7 @@ async def chat_completions(request: ChatCompletionRequest):
|
||||
|
||||
|
||||
async def _stream_response(
|
||||
e: InferenceEngine,
|
||||
e,
|
||||
prompt: str,
|
||||
images: list[str] | None,
|
||||
max_tokens: int,
|
||||
@@ -288,9 +291,9 @@ def main():
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
global engine
|
||||
engine = InferenceEngine(model_path=args.model)
|
||||
engine.load()
|
||||
global manager
|
||||
manager = ModelManager(default_model=args.model)
|
||||
manager.preload(args.model)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user