"""Warm-up helpers for an OpenAI-compatible llama.cpp server.

Sends a tiny chat completion so the server finishes loading the model
before real traffic arrives, retrying while the model is still loading.
"""

import asyncio
import logging
import time
from pathlib import Path

import httpx

log = logging.getLogger("llamacpp_warmup")


def _is_loading_error(response: httpx.Response) -> bool:
    """Return True if a 503 response indicates the model is still loading."""
    if response.status_code != 503:
        return False
    try:
        payload = response.json()
    except Exception:
        return False
    message = ""
    if isinstance(payload, dict):
        error = payload.get("error")
        if isinstance(error, dict):
            message = str(error.get("message") or "")
        else:
            message = str(payload.get("message") or "")
    return "loading model" in message.lower()


def resolve_warmup_prompt(override: str | None, fallback_path: str) -> str:
    """Pick the warmup prompt: explicit override, then file contents, then "ok"."""
    if override:
        prompt = override.strip()
        if prompt:
            return prompt
    try:
        prompt = Path(fallback_path).read_text(encoding="utf-8").strip()
        if prompt:
            return prompt
    except Exception as exc:
        log.warning("Failed to read warmup prompt from %s: %s", fallback_path, exc)
    return "ok"


async def run_warmup(base_url: str, model_id: str, prompt: str, timeout_s: float) -> None:
    """Issue one tiny chat completion; raise if the model is still loading."""
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 8,  # keep the warmup generation short
        "temperature": 0,
    }
    async with httpx.AsyncClient(base_url=base_url, timeout=timeout_s) as client:
        resp = await client.post("/v1/chat/completions", json=payload)
        if resp.status_code == 503 and _is_loading_error(resp):
            raise RuntimeError("llama.cpp still loading model")
        resp.raise_for_status()


async def run_warmup_with_retry(
    base_url: str,
    model_id: str,
    prompt: str,
    timeout_s: float,
    interval_s: float = 3.0,
) -> None:
    """Retry run_warmup until it succeeds or roughly timeout_s elapses."""
    # A monotonic clock keeps the deadline stable if the wall clock jumps.
    # NOTE: timeout_s also bounds each individual request, so one slow request
    # can consume the whole retry budget.
    deadline = time.monotonic() + timeout_s
    last_exc: Exception | None = None
    while time.monotonic() < deadline:
        try:
            await run_warmup(base_url, model_id, prompt, timeout_s=timeout_s)
            return
        except Exception as exc:
            last_exc = exc
            await asyncio.sleep(interval_s)
    if last_exc:
        raise last_exc
    # The loop never ran (non-positive timeout): fail loudly, not silently.
    raise TimeoutError(f"warmup not attempted within {timeout_s}s")
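

# Example entry point (a minimal sketch: the server URL, model id, prompt
# path, and timeout below are illustrative assumptions, not values this
# module ships with).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    warmup_prompt = resolve_warmup_prompt(None, "warmup_prompt.txt")  # assumed path
    asyncio.run(
        run_warmup_with_retry(
            "http://127.0.0.1:8080",  # assumed llama.cpp server address
            "default",                # assumed model id
            warmup_prompt,
            timeout_s=120.0,
        )
    )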