feat: hot swapping of models
This commit is contained in:
@@ -274,6 +274,19 @@ class InferenceEngine:
|
|||||||
self._model_type = getattr(self.config, "model_type", "").lower()
|
self._model_type = getattr(self.config, "model_type", "").lower()
|
||||||
logger.info("Model loaded successfully (type=%s).", self._model_type)
|
logger.info("Model loaded successfully (type=%s).", self._model_type)
|
||||||
|
|
||||||
|
def unload(self) -> None:
|
||||||
|
"""Release model weights and caches to free memory."""
|
||||||
|
logger.info("Unloading model %s ...", self.model_path)
|
||||||
|
self._prompt_cache.clear()
|
||||||
|
self.model = None
|
||||||
|
self.processor = None
|
||||||
|
self.config = None
|
||||||
|
self._model_type = ""
|
||||||
|
# Force garbage collection + clear MLX cache to reclaim memory
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_qwen(self) -> bool:
|
def is_qwen(self) -> bool:
|
||||||
return "qwen" in self._model_type
|
return "qwen" in self._model_type
|
||||||
@@ -952,3 +965,70 @@ class InferenceEngine:
|
|||||||
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
|
logger.warning("Failed to parse tool_call tag %r: %s", match, e)
|
||||||
|
|
||||||
return clean_text, tool_calls
|
return clean_text, tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManager:
|
||||||
|
"""Registry of available models with on-demand loading and swapping.
|
||||||
|
|
||||||
|
Only one model is loaded in memory at a time. When a request targets a
|
||||||
|
different model, the current one is unloaded first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, default_model: str = DEFAULT_MODEL):
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._engine: InferenceEngine | None = None
|
||||||
|
self._current_model: str | None = None
|
||||||
|
self._default_model = default_model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_models(self) -> list[str]:
|
||||||
|
"""All model IDs that clients can request."""
|
||||||
|
return list(MODEL_ALIASES.values())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_aliases(self) -> dict[str, str]:
|
||||||
|
"""Short alias -> full HuggingFace model path."""
|
||||||
|
return dict(MODEL_ALIASES)
|
||||||
|
|
||||||
|
def resolve_model(self, requested: str) -> str:
|
||||||
|
"""Resolve a model string to a full HuggingFace model path.
|
||||||
|
|
||||||
|
Accepts aliases ('gemma', 'qwen') or full paths.
|
||||||
|
"""
|
||||||
|
if requested in MODEL_ALIASES:
|
||||||
|
return MODEL_ALIASES[requested]
|
||||||
|
if requested in MODEL_ALIASES.values():
|
||||||
|
return requested
|
||||||
|
# Accept partial matches (e.g. 'gemma-3-4b-it' matches the gemma entry)
|
||||||
|
for alias, full_path in MODEL_ALIASES.items():
|
||||||
|
if requested in full_path or requested in alias:
|
||||||
|
return full_path
|
||||||
|
# Unknown model — return as-is and let loading fail if invalid
|
||||||
|
return requested
|
||||||
|
|
||||||
|
def get_engine(self, requested_model: str | None = None) -> InferenceEngine:
|
||||||
|
"""Return an engine for the requested model, swapping if necessary."""
|
||||||
|
target = self.resolve_model(requested_model) if requested_model else self._default_model
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self._engine is not None and self._current_model == target:
|
||||||
|
return self._engine
|
||||||
|
|
||||||
|
# Need to swap
|
||||||
|
if self._engine is not None:
|
||||||
|
logger.info(
|
||||||
|
"Swapping model: %s -> %s", self._current_model, target
|
||||||
|
)
|
||||||
|
self._engine.unload()
|
||||||
|
self._engine = None
|
||||||
|
self._current_model = None
|
||||||
|
|
||||||
|
engine = InferenceEngine(model_path=target)
|
||||||
|
engine.load()
|
||||||
|
self._engine = engine
|
||||||
|
self._current_model = target
|
||||||
|
return self._engine
|
||||||
|
|
||||||
|
def preload(self, model: str | None = None) -> None:
|
||||||
|
"""Eagerly load a model at startup."""
|
||||||
|
self.get_engine(model)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
from .engine import DEFAULT_MODEL, InferenceEngine
|
from .engine import DEFAULT_MODEL, ModelManager
|
||||||
from .models import (
|
from .models import (
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
@@ -41,13 +41,13 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
engine: InferenceEngine | None = None
|
manager: ModelManager | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_engine() -> InferenceEngine:
|
def get_engine(requested_model: str | None = None):
|
||||||
if engine is None:
|
if manager is None:
|
||||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||||
return engine
|
return manager.get_engine(requested_model)
|
||||||
|
|
||||||
|
|
||||||
def _make_id() -> str:
|
def _make_id() -> str:
|
||||||
@@ -61,13 +61,16 @@ def _make_id() -> str:
|
|||||||
|
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
async def list_models() -> ModelListResponse:
|
async def list_models() -> ModelListResponse:
|
||||||
e = get_engine()
|
if manager is None:
|
||||||
return ModelListResponse(data=[ModelInfo(id=e.model_path)])
|
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||||
|
return ModelListResponse(
|
||||||
|
data=[ModelInfo(id=model_id) for model_id in manager.available_models]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completions(request: ChatCompletionRequest):
|
async def chat_completions(request: ChatCompletionRequest):
|
||||||
e = get_engine()
|
e = get_engine(request.model)
|
||||||
|
|
||||||
# Convert pydantic messages to dicts
|
# Convert pydantic messages to dicts
|
||||||
messages = [m.model_dump(exclude_none=True) for m in request.messages]
|
messages = [m.model_dump(exclude_none=True) for m in request.messages]
|
||||||
@@ -144,7 +147,7 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||||||
|
|
||||||
|
|
||||||
async def _stream_response(
|
async def _stream_response(
|
||||||
e: InferenceEngine,
|
e,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
images: list[str] | None,
|
images: list[str] | None,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
@@ -288,9 +291,9 @@ def main():
|
|||||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
global engine
|
global manager
|
||||||
engine = InferenceEngine(model_path=args.model)
|
manager = ModelManager(default_model=args.model)
|
||||||
engine.load()
|
manager.preload(args.model)
|
||||||
|
|
||||||
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
|
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user