initial commit
This commit is contained in:
296
test_server.py
Normal file
296
test_server.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""Test script for MLX Server – exercises chat, streaming, vision, and tool use."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
|
||||
import httpx
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
BASE_URL = "http://127.0.0.1:1234/v1"
|
||||
MODEL = "mlx-community/gemma-3-4b-it-4bit"
|
||||
|
||||
|
||||
def test_models():
|
||||
"""Test GET /v1/models."""
|
||||
print("=" * 60)
|
||||
print("TEST: List models")
|
||||
print("=" * 60)
|
||||
r = httpx.get(f"{BASE_URL}/models")
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
print(f"Models: {[m['id'] for m in data['data']]}")
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def test_chat_basic():
|
||||
"""Test basic non-streaming chat."""
|
||||
print("=" * 60)
|
||||
print("TEST: Basic chat (non-streaming)")
|
||||
print("=" * 60)
|
||||
r = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [{"role": "user", "content": "Say exactly: 'The sky is blue.' Nothing else."}],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
msg = data["choices"][0]["message"]["content"]
|
||||
usage = data["usage"]
|
||||
print(f"Response: {msg}")
|
||||
print(f"Usage: {usage}")
|
||||
print(f"Finish reason: {data['choices'][0]['finish_reason']}")
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def test_chat_streaming():
|
||||
"""Test streaming chat."""
|
||||
print("=" * 60)
|
||||
print("TEST: Streaming chat")
|
||||
print("=" * 60)
|
||||
collected = ""
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}],
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.1,
|
||||
"stream": True,
|
||||
},
|
||||
timeout=120,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
for line in response.iter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = line[len("data: "):]
|
||||
if payload == "[DONE]":
|
||||
break
|
||||
chunk = json.loads(payload)
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if delta.get("content"):
|
||||
collected += delta["content"]
|
||||
print(delta["content"], end="", flush=True)
|
||||
if chunk["choices"][0].get("finish_reason"):
|
||||
print(f"\n[finish_reason: {chunk['choices'][0]['finish_reason']}]")
|
||||
if chunk.get("usage") and chunk["usage"].get("total_tokens", 0) > 0:
|
||||
print(f"[usage: {chunk['usage']}]")
|
||||
print(f"Full collected: {collected!r}")
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def _make_test_image() -> str:
|
||||
"""Create a simple test image and return it as a base64 data URI."""
|
||||
img = Image.new("RGB", (200, 200), color=(135, 206, 235))
|
||||
draw = ImageDraw.Draw(img)
|
||||
# Draw a red circle
|
||||
draw.ellipse([50, 50, 150, 150], fill=(255, 0, 0), outline=(0, 0, 0), width=2)
|
||||
# Draw a green triangle
|
||||
draw.polygon([(100, 20), (60, 80), (140, 80)], fill=(0, 180, 0), outline=(0, 0, 0))
|
||||
# Draw yellow text area
|
||||
draw.rectangle([10, 160, 190, 190], fill=(255, 255, 0))
|
||||
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def test_vision():
|
||||
"""Test vision with an image."""
|
||||
print("=" * 60)
|
||||
print("TEST: Vision (image description)")
|
||||
print("=" * 60)
|
||||
image_uri = _make_test_image()
|
||||
print(f"Image: 200x200 PNG with red circle, green triangle, yellow bar")
|
||||
|
||||
r = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe what shapes and colors you see in this image. Be brief."},
|
||||
{"type": "image_url", "image_url": {"url": image_uri}},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": 200,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
msg = data["choices"][0]["message"]["content"]
|
||||
print(f"Response: {msg}")
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def test_tool_use():
|
||||
"""Test tool calling."""
|
||||
print("=" * 60)
|
||||
print("TEST: Tool use")
|
||||
print("=" * 60)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name, e.g. 'London'",
|
||||
},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"description": "Temperature units: 'celsius' or 'fahrenheit'",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Step 1: Ask the model to use the tool
|
||||
print("Step 1: Asking model to get weather for Paris...")
|
||||
r = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."},
|
||||
],
|
||||
"tools": tools,
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
choice = data["choices"][0]
|
||||
print(f"Finish reason: {choice['finish_reason']}")
|
||||
print(f"Content: {choice['message'].get('content')}")
|
||||
print(f"Tool calls: {choice['message'].get('tool_calls')}")
|
||||
|
||||
if choice["message"].get("tool_calls"):
|
||||
tc = choice["message"]["tool_calls"][0]
|
||||
print(f"\nTool call detected:")
|
||||
print(f" ID: {tc['id']}")
|
||||
print(f" Function: {tc['function']['name']}")
|
||||
print(f" Arguments: {tc['function']['arguments']}")
|
||||
|
||||
# Step 2: Send the tool result back
|
||||
print("\nStep 2: Sending mock tool result back...")
|
||||
r2 = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Paris right now? Use the get_weather tool."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": choice["message"].get("content"),
|
||||
"tool_calls": choice["message"]["tool_calls"],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc["id"],
|
||||
"content": json.dumps({"temperature": 18, "condition": "Partly cloudy", "humidity": 65}),
|
||||
},
|
||||
],
|
||||
"tools": tools,
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
r2.raise_for_status()
|
||||
data2 = r2.json()
|
||||
msg2 = data2["choices"][0]["message"]["content"]
|
||||
print(f"Final response: {msg2}")
|
||||
else:
|
||||
print("WARNING: Model did not produce a tool call. Raw response above.")
|
||||
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
def test_multi_turn():
|
||||
"""Test multi-turn conversation."""
|
||||
print("=" * 60)
|
||||
print("TEST: Multi-turn conversation")
|
||||
print("=" * 60)
|
||||
messages = [
|
||||
{"role": "user", "content": "My name is Alice."},
|
||||
]
|
||||
r = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1},
|
||||
timeout=120,
|
||||
)
|
||||
r.raise_for_status()
|
||||
reply1 = r.json()["choices"][0]["message"]["content"]
|
||||
print(f"Turn 1 reply: {reply1}")
|
||||
|
||||
messages.append({"role": "assistant", "content": reply1})
|
||||
messages.append({"role": "user", "content": "What is my name?"})
|
||||
|
||||
r2 = httpx.post(
|
||||
f"{BASE_URL}/chat/completions",
|
||||
json={"model": MODEL, "messages": messages, "max_tokens": 100, "temperature": 0.1},
|
||||
timeout=120,
|
||||
)
|
||||
r2.raise_for_status()
|
||||
reply2 = r2.json()["choices"][0]["message"]["content"]
|
||||
print(f"Turn 2 reply: {reply2}")
|
||||
assert "alice" in reply2.lower(), f"Expected 'Alice' in response, got: {reply2}"
|
||||
print("PASS\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [
|
||||
test_models,
|
||||
test_chat_basic,
|
||||
test_chat_streaming,
|
||||
test_vision,
|
||||
test_tool_use,
|
||||
test_multi_turn,
|
||||
]
|
||||
|
||||
# Allow running a single test by name
|
||||
if len(sys.argv) > 1:
|
||||
name = sys.argv[1]
|
||||
tests = [t for t in tests if name in t.__name__]
|
||||
if not tests:
|
||||
print(f"No test matching '{name}'. Available: models, chat_basic, chat_streaming, vision, tool_use, multi_turn")
|
||||
sys.exit(1)
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
for test in tests:
|
||||
try:
|
||||
test()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"FAIL: {e}\n")
|
||||
failed += 1
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Results: {passed} passed, {failed} failed")
|
||||
print("=" * 60)
|
||||
Reference in New Issue
Block a user