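"""End-to-end tests for the llama.cpp wrapper, its UI, and the TrueNAS app.

Endpoints and credentials come from environment variables (WRAPPER_BASE,
UI_BASE, TRUENAS_WS_URL, TRUENAS_API_KEY, TRUENAS_APP_NAME, MODEL_REQUEST),
with the defaults below. The async tests rely on the pytest-asyncio plugin.
"""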
import json
import os
import ssl
import time
from typing import Dict, List

import pytest
import requests
import websockets

WRAPPER_BASE = os.getenv("WRAPPER_BASE", "http://192.168.1.2:9093")
UI_BASE = os.getenv("UI_BASE", "http://192.168.1.2:9094")
TRUENAS_WS_URL = os.getenv("TRUENAS_WS_URL", "wss://192.168.1.2/websocket")
TRUENAS_API_KEY = os.getenv("TRUENAS_API_KEY", "")
TRUENAS_APP_NAME = os.getenv("TRUENAS_APP_NAME", "llamacpp")
MODEL_REQUEST = os.getenv("MODEL_REQUEST", "")


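# _rpc_call speaks the TrueNAS middleware websocket protocol. Illustrative
# frame exchange (server replies are sketched, not verbatim):
#   -> {"msg": "connect", "version": "1", "support": ["1"]}
#   <- {"msg": "connected", ...}
#   -> {"id": 1, "msg": "method", "method": "auth.login_with_api_key", "params": ["<key>"]}
#   <- {"id": 1, "result": true}
#   -> {"id": 2, "msg": "method", "method": "app.config", "params": ["llamacpp"]}
#   <- {"id": 2, "result": {...}}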
async def _rpc_call(method: str, params: List | None = None):
    """Call a TrueNAS middleware method over its DDP-style websocket API."""
    if not TRUENAS_API_KEY:
        pytest.skip("TRUENAS_API_KEY not set")
    # TrueNAS typically serves a self-signed certificate, so skip verification.
    ssl_ctx = ssl.create_default_context()
    ssl_ctx.check_hostname = False
    ssl_ctx.verify_mode = ssl.CERT_NONE
    async with websockets.connect(TRUENAS_WS_URL, ssl=ssl_ctx) as ws:
        await ws.send(json.dumps({"msg": "connect", "version": "1", "support": ["1"]}))
        connected = json.loads(await ws.recv())
        if connected.get("msg") != "connected":
            raise RuntimeError("failed to connect")
        await ws.send(json.dumps({"id": 1, "msg": "method", "method": "auth.login_with_api_key", "params": [TRUENAS_API_KEY]}))
        auth = json.loads(await ws.recv())
        if not auth.get("result"):
            raise RuntimeError("auth failed")
        await ws.send(json.dumps({"id": 2, "msg": "method", "method": method, "params": params or []}))
        # Skip unrelated frames until the reply carrying our request id arrives.
        while True:
            raw = json.loads(await ws.recv())
            if raw.get("id") != 2:
                continue
            if raw.get("msg") == "error":
                raise RuntimeError(raw.get("error"))
            return raw.get("result")


def _get_models() -> List[str]:
    _wait_for_http(WRAPPER_BASE + "/health")
    resp = requests.get(WRAPPER_BASE + "/v1/models", timeout=30)
    resp.raise_for_status()
    data = resp.json().get("data") or []
    return [m.get("id") for m in data if m.get("id")]


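# Shared helper: choose the model under test. MODEL_REQUEST pins a specific
# model; otherwise the first discovered model is used. This collapses the
# selection logic repeated by the chat tests below.
def _pick_model() -> str:
    models = _get_models()
    assert models, "no models"
    if MODEL_REQUEST:
        assert MODEL_REQUEST in models, f"MODEL_REQUEST not found: {MODEL_REQUEST}"
        return MODEL_REQUEST
    return models[0]

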
def _assert_chat_ok(resp_json: Dict) -> str:
    choices = resp_json.get("choices") or []
    assert choices, "no choices"
    message = choices[0].get("message") or {}
    text = message.get("content") or ""
    assert text.strip(), "empty content"
    return text


def _wait_for_http(url: str, timeout_s: float = 90) -> None:
    deadline = time.time() + timeout_s
    last_err = None
    while time.time() < deadline:
        try:
            resp = requests.get(url, timeout=5)
            if resp.status_code == 200:
                return
            last_err = f"status {resp.status_code}"
        except Exception as exc:
            last_err = str(exc)
        time.sleep(2)
    raise RuntimeError(f"service not ready: {url} ({last_err})")


def _post_with_retry(url: str, payload: Dict, timeout_s: float = 300, retries: int = 6, delay_s: float = 5.0):
    # Retry on connection errors and non-200 responses; return the last
    # response (even if non-200) or re-raise the last exception.
    last = None
    for _ in range(retries):
        try:
            resp = requests.post(url, json=payload, timeout=timeout_s)
            if resp.status_code == 200:
                return resp
            last = resp
        except requests.exceptions.RequestException as exc:
            last = exc
        time.sleep(delay_s)
    if isinstance(last, Exception):
        raise last
    return last


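# The next test asserts multi-GPU flags on the app's command line.
# Illustrative shape (binary and values are examples, not from a live config):
#   ["llama-server", "--model", "/models/example.gguf",
#    "--tensor-split", "0.5,0.5", "--split-mode", "layer"]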
@pytest.mark.asyncio
async def test_active_model_and_multi_gpu_flags():
    cfg = await _rpc_call("app.config", [TRUENAS_APP_NAME])
    command = cfg.get("command") or []
    assert "--model" in command
    assert "--tensor-split" in command
    split_idx = command.index("--tensor-split") + 1
    split = command[split_idx]
    assert "," in split, f"tensor-split missing commas: {split}"
    assert "--split-mode" in command


def test_models_listed():
    models = _get_models()
    assert models, "no models discovered"


def test_chat_completions_switch_and_prompts():
    model_id = _pick_model()
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": "Say OK."}],
        "max_tokens": 12,
        "temperature": 0,
    }
    for _ in range(3):
        resp = _post_with_retry(WRAPPER_BASE + "/v1/chat/completions", payload)
        assert resp.status_code == 200
        _assert_chat_ok(resp.json())


def test_tools_flat_format():
    model_id = _pick_model()
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": "Say OK and do not call tools."}],
        "tools": [
            {
                "type": "function",
                "name": "format_final_json_response",
                "description": "format output",
                "parameters": {
                    "type": "object",
                    "properties": {"ok": {"type": "boolean"}},
                    "required": ["ok"],
                },
            }
        ],
        "max_tokens": 12,
    }
    resp = _post_with_retry(WRAPPER_BASE + "/v1/chat/completions", payload)
    assert resp.status_code == 200
    _assert_chat_ok(resp.json())


def test_functions_payload_normalized():
    model_id = _pick_model()
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": "Say OK and do not call tools."}],
        "functions": [
            {
                "name": "format_final_json_response",
                "description": "format output",
                "parameters": {
                    "type": "object",
                    "properties": {"ok": {"type": "boolean"}},
                    "required": ["ok"],
                },
            }
        ],
        "max_tokens": 12,
    }
    resp = _post_with_retry(WRAPPER_BASE + "/v1/chat/completions", payload)
    assert resp.status_code == 200
    _assert_chat_ok(resp.json())


def test_return_format_json():
    model_id = _pick_model()
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": "Return JSON with key ok true."}],
        "return_format": "json",
        "max_tokens": 32,
        "temperature": 0,
    }
    resp = _post_with_retry(WRAPPER_BASE + "/v1/chat/completions", payload)
    assert resp.status_code == 200
    text = _assert_chat_ok(resp.json())
    parsed = json.loads(text)
    assert isinstance(parsed, dict)


def test_responses_endpoint():
    model_id = _pick_model()
    payload = {
        "model": model_id,
        "input": "Say OK.",
        "max_output_tokens": 16,
    }
    resp = _post_with_retry(WRAPPER_BASE + "/v1/responses", payload)
    assert resp.status_code == 200
    output = resp.json().get("output") or []
    assert output, "responses output empty"
    content = output[0].get("content") or []
    text = content[0].get("text") if content else ""
    assert text and text.strip()


@pytest.mark.asyncio
async def test_model_switch_applied_to_truenas():
    target = _pick_model()
    # Switch via the UI API, then confirm through the TrueNAS RPC that the
    # app's command line now points at the requested model file.
    resp = requests.post(UI_BASE + "/ui/api/switch-model", json={"model_id": target, "warmup_prompt": "warmup"}, timeout=600)
    assert resp.status_code == 200
    cfg = await _rpc_call("app.config", [TRUENAS_APP_NAME])
    command = cfg.get("command") or []
    assert "--model" in command
    model_path = command[command.index("--model") + 1]
    assert model_path.endswith(target)


def test_invalid_model_rejected():
    models = _get_models()
    assert models, "no models"
    payload = {
        "model": "modelx-q8:4b",
        "messages": [{"role": "user", "content": "Say OK."}],
        "max_tokens": 8,
        "temperature": 0,
    }
    resp = requests.post(WRAPPER_BASE + "/v1/chat/completions", json=payload, timeout=60)
    assert resp.status_code == 404


def test_llamacpp_logs_streaming():
    logs = ""
    for _ in range(5):
        try:
            resp = requests.get(UI_BASE + "/ui/api/llamacpp-logs", timeout=10)
            if resp.status_code == 200:
                logs = resp.json().get("logs") or ""
                if logs.strip():
                    break
        except requests.exceptions.ReadTimeout:
            pass
        time.sleep(2)
    assert logs.strip(), "no logs returned"

    # Force a log line before streaming.
    try:
        requests.get(WRAPPER_BASE + "/proxy/llamacpp/health", timeout=5)
    except Exception:
        pass

    # Stream endpoint may not emit immediately, so validate that the endpoint responds.
    with requests.get(UI_BASE + "/ui/api/llamacpp-logs/stream", stream=True, timeout=(5, 5)) as resp:
        assert resp.status_code == 200
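
# Example invocation (file name and model value are illustrative):
#   pip install pytest pytest-asyncio requests websockets
#   export TRUENAS_API_KEY=...            # required for the TrueNAS RPC tests
#   export MODEL_REQUEST=example.gguf     # optional: pin the model under test
#   pytest -v test_llamacpp_truenas.py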