feat: qwen now works, too
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""OpenAI-compatible API server for Gemma 3 4B via MLX."""
|
||||
"""OpenAI-compatible API server for local LLMs (Gemma 3, Qwen3, …) via MLX."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -31,7 +31,7 @@ from .models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for Gemma 3 4B")
|
||||
app = FastAPI(title="MLX Server", description="OpenAI-compatible API for local LLMs on Apple Silicon")
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -170,6 +170,11 @@ async def _stream_response(
|
||||
prompt_tokens = 0
|
||||
gen_tokens = 0
|
||||
|
||||
# When tools are available we must buffer the full response before
|
||||
# emitting content — otherwise raw tool-call markup (```tool_code```
|
||||
# or <tool_call>) leaks into the streamed text.
|
||||
buffer_for_tools = bool(tools)
|
||||
|
||||
for token_text, is_final, pt, gt in e.stream_generate(
|
||||
prompt=prompt,
|
||||
images=images or None,
|
||||
@@ -182,7 +187,7 @@ async def _stream_response(
|
||||
gen_tokens = gt
|
||||
full_text += token_text
|
||||
|
||||
if not is_final and token_text:
|
||||
if not buffer_for_tools and not is_final and token_text:
|
||||
chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
@@ -191,37 +196,53 @@ async def _stream_response(
|
||||
)
|
||||
yield {"data": chunk.model_dump_json()}
|
||||
|
||||
# Check for tool calls in complete response
|
||||
# --- Post-generation: parse tool calls and emit clean content ------
|
||||
finish_reason = "stop"
|
||||
tool_calls_parsed = []
|
||||
|
||||
if tools:
|
||||
clean_text, parsed = e.parse_tool_calls(full_text, tools)
|
||||
if parsed:
|
||||
finish_reason = "tool_calls"
|
||||
# Emit tool call chunks
|
||||
for i, tc in enumerate(parsed):
|
||||
tc_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[
|
||||
StreamChoice(
|
||||
delta=DeltaMessage(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
index=i,
|
||||
id=tc["id"],
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
]
|
||||
tool_calls_parsed = parsed
|
||||
full_text = clean_text or ""
|
||||
|
||||
# Emit buffered content (when tools were present, this is the cleaned
|
||||
# text with tool-call markup stripped out)
|
||||
if buffer_for_tools and full_text.strip():
|
||||
content_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[StreamChoice(delta=DeltaMessage(content=full_text))],
|
||||
)
|
||||
yield {"data": content_chunk.model_dump_json()}
|
||||
|
||||
# Emit tool call chunks
|
||||
for i, tc in enumerate(tool_calls_parsed):
|
||||
tc_chunk = ChatCompletionChunk(
|
||||
id=request_id,
|
||||
created=created,
|
||||
model=model_name,
|
||||
choices=[
|
||||
StreamChoice(
|
||||
delta=DeltaMessage(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
index=i,
|
||||
id=tc["id"],
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
]
|
||||
)
|
||||
)
|
||||
yield {"data": tc_chunk.model_dump_json()}
|
||||
],
|
||||
)
|
||||
yield {"data": tc_chunk.model_dump_json()}
|
||||
|
||||
# Final chunk with finish reason and usage
|
||||
final_chunk = ChatCompletionChunk(
|
||||
|
||||
Reference in New Issue
Block a user