# Tests for the chat-completions proxy endpoints (non-stream, stream, tool normalization).
import json
|
|
import pytest
|
|
import httpx
|
|
|
|
|
|
@pytest.mark.parametrize("case", list(range(120)))
def test_chat_completions_non_stream(api_client, respx_mock, case):
    """A non-streaming chat completion round-trips through the proxy.

    Mocks the upstream model listing and completion endpoints, posts a
    varied payload per parametrized *case*, and checks the proxied reply.
    """
    respx_mock.get("http://llama.test/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [{"id": "model-a.gguf"}]})
    )
    upstream_reply = {
        "id": f"chatcmpl-{case}",
        "choices": [{"message": {"content": "ok"}}],
    }
    respx_mock.post("http://llama.test/v1/chat/completions").mock(
        return_value=httpx.Response(200, json=upstream_reply)
    )

    request_body = {
        "model": "model-a.gguf",
        "messages": [{"role": "user", "content": f"hello {case}"}],
        # Vary temperature across cases to exercise different values in [0.0, 0.9].
        "temperature": (case % 10) / 10,
    }
    resp = api_client.post("/v1/chat/completions", json=request_body)

    assert resp.status_code == 200
    assert resp.json()["choices"][0]["message"]["content"] == "ok"
|
|
|
|
|
|
@pytest.mark.parametrize("case", list(range(120)))
def test_chat_completions_stream(api_client, respx_mock, case):
    """A streaming chat completion relays SSE bytes through the proxy."""
    respx_mock.get("http://llama.test/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [{"id": "model-a.gguf"}]})
    )

    def sse_reply(request):
        # Minimal SSE body: one data frame terminated by a blank line.
        return httpx.Response(
            200,
            content=b"data: {\"id\": \"chunk\"}\n\n",
            headers={"Content-Type": "text/event-stream"},
        )

    respx_mock.post("http://llama.test/v1/chat/completions").mock(side_effect=sse_reply)

    request_body = {
        "model": "model-a.gguf",
        "messages": [{"role": "user", "content": f"hello {case}"}],
        "stream": True,
    }
    with api_client.stream("POST", "/v1/chat/completions", json=request_body) as resp:
        assert resp.status_code == 200
        streamed = b"".join(resp.iter_bytes())
        assert b"data:" in streamed
|
|
|
|
|
|
def test_chat_completions_tools_normalize(api_client, respx_mock):
    """Flattened tool definitions are normalized to the nested format.

    The client sends a tool with ``name``/``parameters`` at the top level;
    the upstream handler asserts the proxy forwards it nested under a
    ``function`` key before replying.
    """
    respx_mock.get("http://llama.test/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [{"id": "model-a.gguf"}]})
    )

    def handler(request):
        # httpx.Request has no .json() helper (only Response does), so the
        # body must be decoded explicitly; calling request.json() raised
        # AttributeError and masked the normalization check.
        data = json.loads(request.content)
        tools = data.get("tools") or []
        assert tools
        assert tools[0].get("function", {}).get("name") == "format_final_json_response"
        return httpx.Response(
            200,
            json={"id": "chatcmpl-tools", "choices": [{"message": {"content": "ok"}}]},
        )

    respx_mock.post("http://llama.test/v1/chat/completions").mock(side_effect=handler)

    payload = {
        "model": "model-a.gguf",
        "messages": [{"role": "user", "content": "hello"}],
        # Deliberately flattened (no "function" wrapper) to exercise
        # the proxy's normalization.
        "tools": [
            {
                "type": "function",
                "name": "format_final_json_response",
                "parameters": {"type": "object"},
            }
        ],
        "tool_choice": {"type": "function", "name": "format_final_json_response"},
    }

    resp = api_client.post("/v1/chat/completions", json=payload)
    assert resp.status_code == 200
|