initial commit

This commit is contained in:
2026-03-17 09:14:27 +01:00
commit df81afe8d7
10 changed files with 1389 additions and 0 deletions

278
mlx_server/main.py Normal file
View File

@@ -0,0 +1,278 @@
"""OpenAI-compatible API server for Gemma 3 4B via MLX."""
from __future__ import annotations
import argparse
import json
import logging
import time
import uuid
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from .engine import DEFAULT_MODEL, InferenceEngine
from .models import (
ChatCompletionChunk,
ChatCompletionRequest,
ChatCompletionResponse,
Choice,
ChoiceMessage,
DeltaMessage,
ModelInfo,
ModelListResponse,
StreamChoice,
ToolCall,
FunctionCall,
UsageInfo,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# The ASGI application object; the route decorators below register on it.
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B")
# CORS is left fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins are combined with allow_credentials=True.
# FastAPI's CORS docs say wildcards must not be used when credentials are
# allowed -- confirm whether credentialed cross-origin requests are needed.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Process-wide inference engine. Stays None until main() assigns it;
# get_engine() answers 503 while it is still None.
engine: InferenceEngine | None = None
def get_engine() -> InferenceEngine:
    """Return the module-level inference engine.

    Raises:
        HTTPException: with status 503 and detail ``"Model not loaded"``
            while the global ``engine`` is still ``None``.
    """
    loaded = engine
    if loaded is not None:
        return loaded
    raise HTTPException(status_code=503, detail="Model not loaded")
def _make_id() -> str:
    """Return a fresh completion id: ``chatcmpl-`` plus 12 hex characters."""
    suffix = uuid.uuid4().hex[:12]
    return "chatcmpl-" + suffix
# ------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------
@app.get("/v1/models")
async def list_models() -> ModelListResponse:
    """Return a model list with one entry whose id is the engine's ``model_path``.

    Propagates the 503 raised by ``get_engine()`` when no engine is loaded.
    """
    served = ModelInfo(id=get_engine().model_path)
    return ModelListResponse(data=[served])
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle a chat completion request, streamed or not.

    The request's messages and tools are dumped to plain dicts and handed to
    the engine to build the prompt. Sampling parameters missing from the
    request fall back to ``temperature=0.7``, ``top_p=0.9`` and
    ``max_tokens=4096``. A string ``stop`` is wrapped into a one-element list.

    When ``request.stream`` is truthy an ``EventSourceResponse`` driven by
    ``_stream_response`` is returned. Otherwise the full text is generated,
    tool calls are parsed out of it when tools were supplied, and a
    ``ChatCompletionResponse`` with usage counts is returned.

    Propagates the 503 raised by ``get_engine()`` when no engine is loaded.
    """
    inference = get_engine()

    # The engine is given plain dicts rather than pydantic models.
    message_dicts = [msg.model_dump(exclude_none=True) for msg in request.messages]
    tools = (
        [tool.model_dump(exclude_none=True) for tool in request.tools]
        if request.tools
        else None
    )
    prompt, images = inference.build_prompt(message_dicts, tools)

    stop = request.stop
    if isinstance(stop, str):
        stop = [stop]

    # Server-side defaults apply only when the client left a field unset.
    temperature = 0.7 if request.temperature is None else request.temperature
    top_p = 0.9 if request.top_p is None else request.top_p
    max_tokens = 4096 if request.max_tokens is None else request.max_tokens

    if request.stream:
        events = _stream_response(
            inference,
            prompt,
            images,
            max_tokens,
            temperature,
            top_p,
            stop,
            tools,
            request.model,
        )
        return EventSourceResponse(events, media_type="text/event-stream")

    # Non-streaming path.
    # NOTE(review): generate() is called synchronously inside this coroutine;
    # if it is long-running it will hold the event loop -- confirm.
    text, prompt_tokens, completion_tokens = inference.generate(
        prompt=prompt,
        images=images or None,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=stop,
    )

    finish_reason = "stop"
    tool_calls = None
    parsed = None
    if tools:
        clean_text, parsed = inference.parse_tool_calls(text, tools)
    if parsed:
        tool_calls = [
            ToolCall(
                index=position,
                id=call["id"],
                type="function",
                function=FunctionCall(
                    name=call["function"]["name"],
                    arguments=call["function"]["arguments"],
                ),
            )
            for position, call in enumerate(parsed)
        ]
        # Only the text left after stripping tool calls is kept; empty -> None.
        text = clean_text or None
        finish_reason = "tool_calls"

    if tool_calls:
        content = text or None
    else:
        content = text

    message = ChoiceMessage(role="assistant", content=content, tool_calls=tool_calls)
    usage = UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    return ChatCompletionResponse(
        id=_make_id(),
        model=request.model,
        choices=[Choice(message=message, finish_reason=finish_reason)],
        usage=usage,
    )
async def _stream_response(
    e: InferenceEngine,
    prompt: str,
    images: list[str] | None,
    max_tokens: int,
    temperature: float,
    top_p: float,
    stop: list[str] | None,
    tools: list[dict] | None,
    model_name: str,
):
    """Yield the server-sent-event payloads of one streamed chat completion.

    Every yielded item is a dict with a single ``"data"`` key. Events are
    produced in this order:

    1. a chunk whose delta only carries ``role="assistant"``;
    2. one content chunk per non-empty, non-final piece from
       ``e.stream_generate``;
    3. when ``tools`` is given and tool calls are parsed from the accumulated
       text, one chunk per tool call;
    4. a closing chunk with the finish reason (``"stop"`` or ``"tool_calls"``)
       and token usage;
    5. the literal ``"[DONE]"`` marker.

    All chunks share one id, one ``created`` timestamp and ``model_name``.
    """
    request_id = _make_id()
    created = int(time.time())

    def event(choice, **chunk_fields):
        # Wrap a single choice in a chunk carrying the shared id/created/model.
        chunk = ChatCompletionChunk(
            id=request_id,
            created=created,
            model=model_name,
            choices=[choice],
            **chunk_fields,
        )
        return {"data": chunk.model_dump_json()}

    # Opening chunk: announces the assistant role, no content yet.
    yield event(StreamChoice(delta=DeltaMessage(role="assistant")))

    accumulated = ""
    prompt_tokens = gen_tokens = 0
    # The loop targets double as the running usage counters: after the loop
    # they hold the values from the last item the engine produced.
    for piece, is_final, prompt_tokens, gen_tokens in e.stream_generate(
        prompt=prompt,
        images=images or None,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=stop,
    ):
        accumulated += piece
        if is_final or not piece:
            # Final and empty pieces are accumulated but not sent as content.
            continue
        yield event(StreamChoice(delta=DeltaMessage(content=piece)))

    # Tool calls are parsed once, from the complete generated text.
    finish_reason = "stop"
    parsed = None
    if tools:
        _, parsed = e.parse_tool_calls(accumulated, tools)
    if parsed:
        finish_reason = "tool_calls"
        for position, call in enumerate(parsed):
            tool_call = ToolCall(
                index=position,
                id=call["id"],
                type="function",
                function=FunctionCall(
                    name=call["function"]["name"],
                    arguments=call["function"]["arguments"],
                ),
            )
            yield event(StreamChoice(delta=DeltaMessage(tool_calls=[tool_call])))

    # Closing chunk: empty delta, finish reason and usage totals.
    usage = UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=gen_tokens,
        total_tokens=prompt_tokens + gen_tokens,
    )
    yield event(
        StreamChoice(delta=DeltaMessage(), finish_reason=finish_reason),
        usage=usage,
    )
    yield {"data": "[DONE]"}
# ------------------------------------------------------------------
# Health / utility
# ------------------------------------------------------------------
@app.get("/health")
async def health():
    """Return a fixed ``{"status": "ok"}`` payload; the engine is not consulted."""
    payload = {"status": "ok"}
    return payload
# ------------------------------------------------------------------
# Entrypoint
# ------------------------------------------------------------------
def main():
    """Parse CLI options, load the model and serve the API with uvicorn.

    Options:
        --model: model path handed to ``InferenceEngine`` (default
            ``DEFAULT_MODEL``).
        --host / --port: bind address (default ``127.0.0.1:1234``).
        --log-level: one of critical, error, warning, info, debug, trace
            (case-insensitive, default ``info``).

    The loaded engine is stored in the module-level ``engine`` global so that
    ``get_engine()`` can hand it to the request handlers.
    """
    global engine

    parser = argparse.ArgumentParser(description="MLX Server OpenAI-compatible API")
    parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="HuggingFace model path")
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=1234)
    # The same value feeds both stdlib logging and uvicorn. Restricting it to
    # uvicorn's lower-case level names (and lower-casing the input) makes
    # argparse reject bad values up front instead of failing later inside
    # getattr(logging, ...) or uvicorn.run().
    parser.add_argument(
        "--log-level",
        type=str.lower,
        default="info",
        choices=("critical", "error", "warning", "info", "debug", "trace"),
        help="Log verbosity (case-insensitive)",
    )
    args = parser.parse_args()

    logging.basicConfig(
        # "trace" is the only accepted choice with no stdlib logging
        # counterpart; fall back to DEBUG for it.
        level=getattr(logging, args.log_level.upper(), logging.DEBUG),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Load the model before serving so requests never see a half-initialised
    # engine.
    engine = InferenceEngine(model_path=args.model)
    engine.load()

    uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)


if __name__ == "__main__":
    main()