feat: hot swapping of models

This commit is contained in:
2026-03-17 11:58:24 +01:00
parent cc6e761ed4
commit ef83c24b0b
2 changed files with 96 additions and 13 deletions

View File

@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from .engine import DEFAULT_MODEL, InferenceEngine
from .engine import DEFAULT_MODEL, ModelManager
from .models import (
ChatCompletionChunk,
ChatCompletionRequest,
@@ -41,13 +41,13 @@ app.add_middleware(
allow_headers=["*"],
)
# Module-level server state; both names start as None and are assigned in
# main() before uvicorn.run is called. `engine` is the pre-change global,
# `manager` is the ModelManager-typed global that replaces it in this commit.
engine: InferenceEngine | None = None
manager: ModelManager | None = None
# Pre-change accessor (removed by this commit): returns the module-level
# `engine` object.
# Raises HTTPException(503, "Model not loaded") while `engine` is still None.
def get_engine() -> InferenceEngine:
if engine is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return engine
# Post-change accessor: resolves the engine for a request by delegating to the
# module-level `manager`.
# requested_model: model identifier taken from the request, or None. It is
#   forwarded unchanged to manager.get_engine; how None or an unknown id is
#   handled is decided inside ModelManager and is not visible here
#   (NOTE(review): presumably falls back to the default model - confirm).
# Returns whatever manager.get_engine returns.
# Raises HTTPException(503, "Server not initialized") while `manager` is None.
def get_engine(requested_model: str | None = None):
if manager is None:
raise HTTPException(status_code=503, detail="Server not initialized")
return manager.get_engine(requested_model)
def _make_id() -> str:
@@ -61,13 +61,16 @@ def _make_id() -> str:
# GET /v1/models handler: answers with a ModelListResponse.
@app.get("/v1/models")
async def list_models() -> ModelListResponse:
# Pre-change body (the next two lines, removed by this commit): report a
# single ModelInfo whose id is the loaded engine's model_path.
e = get_engine()
return ModelListResponse(data=[ModelInfo(id=e.model_path)])
# Post-change body (remaining lines): raise 503 until `manager` has been set,
# then emit one ModelInfo per id in manager.available_models.
if manager is None:
raise HTTPException(status_code=503, detail="Server not initialized")
return ModelListResponse(
data=[ModelInfo(id=model_id) for model_id in manager.available_models]
)
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
e = get_engine()
e = get_engine(request.model)
# Convert pydantic messages to dicts
messages = [m.model_dump(exclude_none=True) for m in request.messages]
@@ -144,7 +147,7 @@ async def chat_completions(request: ChatCompletionRequest):
async def _stream_response(
e: InferenceEngine,
e,
prompt: str,
images: list[str] | None,
max_tokens: int,
@@ -288,9 +291,9 @@ def main():
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
global engine
engine = InferenceEngine(model_path=args.model)
engine.load()
global manager
manager = ModelManager(default_model=args.model)
manager.preload(args.model)
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)