feat: hot swapping of models
This commit is contained in:
@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from .engine import DEFAULT_MODEL, InferenceEngine
|
||||
from .engine import DEFAULT_MODEL, ModelManager
|
||||
from .models import (
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionRequest,
|
||||
@@ -41,13 +41,13 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
engine: InferenceEngine | None = None
|
||||
manager: ModelManager | None = None
|
||||
|
||||
|
||||
def get_engine() -> InferenceEngine:
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
return engine
|
||||
def get_engine(requested_model: str | None = None):
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||
return manager.get_engine(requested_model)
|
||||
|
||||
|
||||
def _make_id() -> str:
|
||||
@@ -61,13 +61,16 @@ def _make_id() -> str:
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models() -> ModelListResponse:
|
||||
e = get_engine()
|
||||
return ModelListResponse(data=[ModelInfo(id=e.model_path)])
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||
return ModelListResponse(
|
||||
data=[ModelInfo(id=model_id) for model_id in manager.available_models]
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest):
|
||||
e = get_engine()
|
||||
e = get_engine(request.model)
|
||||
|
||||
# Convert pydantic messages to dicts
|
||||
messages = [m.model_dump(exclude_none=True) for m in request.messages]
|
||||
@@ -144,7 +147,7 @@ async def chat_completions(request: ChatCompletionRequest):
|
||||
|
||||
|
||||
async def _stream_response(
|
||||
e: InferenceEngine,
|
||||
e,
|
||||
prompt: str,
|
||||
images: list[str] | None,
|
||||
max_tokens: int,
|
||||
@@ -288,9 +291,9 @@ def main():
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
global engine
|
||||
engine = InferenceEngine(model_path=args.model)
|
||||
engine.load()
|
||||
global manager
|
||||
manager = ModelManager(default_model=args.model)
|
||||
manager.preload(args.model)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user