From ef83c24b0b527cda534b0b37142dd1174b4b7e06 Mon Sep 17 00:00:00 2001
From: Chili Palmer
Date: Tue, 17 Mar 2026 11:58:24 +0100
Subject: [PATCH] feat: hot swapping of models

---
Review notes (ignored by git-am, not part of the commit message):
- Line breaks and indentation were reconstructed from a flattened copy of
  this patch; check context-line whitespace against the tree.
- unload(): prefers mx.clear_cache() and falls back to
  mx.metal.clear_cache().
- ModelManager.get_engine(): the default model now goes through
  resolve_model(), so an alias passed as the default no longer causes a
  reload under the unresolved name for requests that name no model.
- main.py keeps the InferenceEngine import and its annotations.
- The blob ids on the "index" lines predate these changes.

 mlx_server/engine.py | 97 ++++++++++++++++++++++++++++++++++++++++++++
 mlx_server/main.py   | 31 +++++++++------
 2 files changed, 116 insertions(+), 12 deletions(-)

diff --git a/mlx_server/engine.py b/mlx_server/engine.py
index ef9d2f6..a635d63 100644
--- a/mlx_server/engine.py
+++ b/mlx_server/engine.py
@@ -274,6 +274,27 @@ class InferenceEngine:
         self._model_type = getattr(self.config, "model_type", "").lower()
         logger.info("Model loaded successfully (type=%s).", self._model_type)
 
+    def unload(self) -> None:
+        """Release model weights and caches to free memory.
+
+        Empties the prompt cache and resets ``model``, ``processor`` and
+        ``config`` to ``None``, then runs a garbage-collection pass and
+        clears the MLX memory cache.
+        """
+        logger.info("Unloading model %s ...", self.model_path)
+        self._prompt_cache.clear()
+        self.model = None
+        self.processor = None
+        self.config = None
+        self._model_type = ""
+        # Force garbage collection + clear MLX cache to reclaim memory
+        import gc
+        gc.collect()
+        # mx.metal.clear_cache() is deprecated in recent MLX releases in
+        # favour of mx.clear_cache(); fall back to it on older versions.
+        clear_cache = getattr(mx, "clear_cache", None) or mx.metal.clear_cache
+        clear_cache()
+
     @property
     def is_qwen(self) -> bool:
         return "qwen" in self._model_type
@@ -952,3 +973,79 @@ class InferenceEngine:
                 logger.warning("Failed to parse tool_call tag %r: %s", match, e)
 
         return clean_text, tool_calls
+
+
+class ModelManager:
+    """Registry of available models with on-demand loading and swapping.
+
+    Only one model is loaded in memory at a time. When a request targets a
+    different model, the current one is unloaded first.
+    """
+
+    def __init__(self, default_model: str = DEFAULT_MODEL):
+        self._lock = threading.Lock()
+        self._engine: InferenceEngine | None = None
+        self._current_model: str | None = None
+        self._default_model = default_model
+
+    @property
+    def available_models(self) -> list[str]:
+        """All model IDs that clients can request."""
+        return list(MODEL_ALIASES.values())
+
+    @property
+    def available_aliases(self) -> dict[str, str]:
+        """Short alias -> full HuggingFace model path."""
+        return dict(MODEL_ALIASES)
+
+    def resolve_model(self, requested: str) -> str:
+        """Resolve a model string to a full HuggingFace model path.
+
+        Accepts aliases ('gemma', 'qwen') or full paths.
+        """
+        if requested in MODEL_ALIASES:
+            return MODEL_ALIASES[requested]
+        if requested in MODEL_ALIASES.values():
+            return requested
+        # Accept partial matches (e.g. 'gemma-3-4b-it' matches the gemma entry)
+        for alias, full_path in MODEL_ALIASES.items():
+            if requested in full_path or requested in alias:
+                return full_path
+        # Unknown model — return as-is and let loading fail if invalid
+        return requested
+
+    def get_engine(self, requested_model: str | None = None) -> InferenceEngine:
+        """Return an engine for the requested model, swapping if necessary.
+
+        A missing or empty ``requested_model`` selects the default model.
+        The default is passed through ``resolve_model`` too, so a default
+        given as an alias compares equal to the loaded full path instead of
+        forcing a reload under the unresolved name.
+        """
+        target = self.resolve_model(requested_model or self._default_model)
+
+        with self._lock:
+            if self._engine is not None and self._current_model == target:
+                return self._engine
+
+            # Need to swap
+            if self._engine is not None:
+                logger.info(
+                    "Swapping model: %s -> %s", self._current_model, target
+                )
+                self._engine.unload()
+                self._engine = None
+                self._current_model = None
+
+            engine = InferenceEngine(model_path=target)
+            engine.load()
+            self._engine = engine
+            self._current_model = target
+            # NOTE(review): the engine leaves the lock here, so a later swap
+            # can unload it while this caller is still using it -- confirm
+            # that requests are serialized elsewhere.
+            return self._engine
+
+    def preload(self, model: str | None = None) -> None:
+        """Eagerly load a model at startup."""
+        self.get_engine(model)
diff --git a/mlx_server/main.py b/mlx_server/main.py
index 4920bca..88cff57 100644
--- a/mlx_server/main.py
+++ b/mlx_server/main.py
@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette.sse import EventSourceResponse
 
-from .engine import DEFAULT_MODEL, InferenceEngine
+from .engine import DEFAULT_MODEL, InferenceEngine, ModelManager
 from .models import (
     ChatCompletionChunk,
     ChatCompletionRequest,
@@ -41,13 +41,17 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-engine: InferenceEngine | None = None
+manager: ModelManager | None = None
 
 
-def get_engine() -> InferenceEngine:
-    if engine is None:
-        raise HTTPException(status_code=503, detail="Model not loaded")
-    return engine
+def get_engine(requested_model: str | None = None) -> InferenceEngine:
+    """Return the engine for ``requested_model`` from the global manager.
+
+    Raises HTTP 503 while the manager has not been created.
+    """
+    if manager is None:
+        raise HTTPException(status_code=503, detail="Server not initialized")
+    return manager.get_engine(requested_model)
 
 
 def _make_id() -> str:
@@ -61,13 +65,16 @@ def _make_id() -> str:
 
 @app.get("/v1/models")
 async def list_models() -> ModelListResponse:
-    e = get_engine()
-    return ModelListResponse(data=[ModelInfo(id=e.model_path)])
+    if manager is None:
+        raise HTTPException(status_code=503, detail="Server not initialized")
+    return ModelListResponse(
+        data=[ModelInfo(id=model_id) for model_id in manager.available_models]
+    )
 
 
 @app.post("/v1/chat/completions")
 async def chat_completions(request: ChatCompletionRequest):
-    e = get_engine()
+    e = get_engine(request.model)
 
     # Convert pydantic messages to dicts
     messages = [m.model_dump(exclude_none=True) for m in request.messages]
@@ -288,9 +295,9 @@ def main():
         format="%(asctime)s %(levelname)s %(name)s: %(message)s",
     )
 
-    global engine
-    engine = InferenceEngine(model_path=args.model)
-    engine.load()
+    global manager
+    manager = ModelManager(default_model=args.model)
+    manager.preload(args.model)
 
     uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
 