75 lines
2.1 KiB
Python
75 lines
2.1 KiB
Python
import asyncio
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
|
|
|
|
log = logging.getLogger("llamacpp_warmup")
|
|
|
|
|
|
def _is_loading_error(response: httpx.Response) -> bool:
    """Return True when *response* is a 503 whose JSON body reports the model is still loading.

    llama.cpp answers 503 with either ``{"error": {"message": ...}}`` or a
    flat ``{"message": ...}`` while the model is being loaded; anything else
    (non-503, non-JSON, non-dict payload) is not a loading error.
    """
    if response.status_code != 503:
        return False

    try:
        body = response.json()
    except Exception:
        # Body was not valid JSON — cannot be the structured loading error.
        return False

    if not isinstance(body, dict):
        return False

    err = body.get("error")
    if isinstance(err, dict):
        text = str(err.get("message") or "")
    else:
        text = str(body.get("message") or "")

    return "loading model" in text.lower()
|
|
|
|
|
|
def resolve_warmup_prompt(override: str | None, fallback_path: str) -> str:
|
|
if override:
|
|
prompt = override.strip()
|
|
if prompt:
|
|
return prompt
|
|
try:
|
|
prompt = Path(fallback_path).read_text(encoding="utf-8").strip()
|
|
if prompt:
|
|
return prompt
|
|
except Exception as exc:
|
|
log.warning("Failed to read warmup prompt from %s: %s", fallback_path, exc)
|
|
return "ok"
|
|
|
|
|
|
async def run_warmup(base_url: str, model_id: str, prompt: str, timeout_s: float) -> None:
    """Send one tiny chat completion to force llama.cpp to load the model.

    Raises RuntimeError when the server answers 503 with its "loading model"
    error body, and lets ``raise_for_status`` surface any other error status.
    """
    request_body = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 8,  # a few tokens are enough to prove the model is live
        "temperature": 0,
    }
    async with httpx.AsyncClient(base_url=base_url, timeout=timeout_s) as client:
        response = await client.post("/v1/chat/completions", json=request_body)
    if response.status_code == 503 and _is_loading_error(response):
        raise RuntimeError("llama.cpp still loading model")
    response.raise_for_status()
|
|
|
|
|
|
async def run_warmup_with_retry(
    base_url: str,
    model_id: str,
    prompt: str,
    timeout_s: float,
    interval_s: float = 3.0,
) -> None:
    """Retry :func:`run_warmup` until it succeeds or the time budget runs out.

    Args:
        base_url: Base URL of the llama.cpp server.
        model_id: Model identifier to warm up.
        prompt: Warmup prompt text.
        timeout_s: Overall retry budget in seconds; also reused as the
            per-request timeout of each individual attempt.
        interval_s: Pause between failed attempts.

    Raises:
        Exception: re-raises the last attempt's exception when no attempt
            succeeded before the deadline. Returns silently if no attempt
            was ever made (non-positive ``timeout_s``), matching the
            original behavior.
    """
    # Use a monotonic clock for the deadline: time.time() is wall-clock and
    # NTP/DST adjustments could stretch or collapse the retry window.
    deadline = time.monotonic() + timeout_s
    last_exc: Exception | None = None
    while time.monotonic() < deadline:
        try:
            await run_warmup(base_url, model_id, prompt, timeout_s=timeout_s)
            return
        except Exception as exc:
            last_exc = exc
            # Never sleep past the deadline; the previous version slept a
            # full interval even when the budget was already exhausted.
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            await asyncio.sleep(min(interval_s, remaining))
    if last_exc:
        raise last_exc
|