Files
codex_truenas_helper/llamaCpp.Wrapper.app/api_app.py
Rushabh Gosar 5d1a0ee72b Initial commit
2026-01-07 16:54:39 -08:00

310 lines
13 KiB
Python

import asyncio
import logging
import time
from pathlib import Path
from typing import Any, Dict
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
from app.config import load_config
from app.llamacpp_client import proxy_json, proxy_raw, proxy_stream
from app.logging_utils import configure_logging
from app.model_registry import find_model, resolve_model, scan_models
from app.openai_translate import responses_to_chat_payload, chat_to_responses, normalize_chat_payload
from app.restart import RestartPlan, trigger_restart
from app.stream_transform import stream_chat_to_responses
from app.truenas_middleware import TrueNASConfig, get_active_model_id, switch_model
from app.warmup import resolve_warmup_prompt, run_warmup_with_retry
# Module import side effects: apply the project's logging setup from
# app.logging_utils, then create the logger used throughout this module.
configure_logging()
log: logging.Logger = logging.getLogger("api_app")
def _model_list_payload(model_dir: str) -> Dict[str, Any]:
    """Build an OpenAI-style ``/v1/models`` list body for the models in *model_dir*.

    Every model returned by ``scan_models`` becomes one entry with ``id``,
    ``object``, ``created`` and ``owned_by`` keys; the entries are wrapped as
    ``{"object": "list", "data": [...]}``.
    """
    entries = [
        {
            "id": found.model_id,
            "object": "model",
            "created": found.created,
            "owned_by": "llama.cpp",
        }
        for found in scan_models(model_dir)
    ]
    return {"object": "list", "data": entries}
def _requires_json_mode(payload: Dict[str, Any]) -> bool:
    """Return True when *payload* asks for a JSON-object response.

    Two request shapes are recognised: OpenAI-style
    ``response_format={"type": "json_object"}`` and a ``return_format="json"``
    field.
    """
    fmt = payload.get("response_format")
    wants_json_object = isinstance(fmt, dict) and fmt.get("type") == "json_object"
    return bool(wants_json_object or payload.get("return_format") == "json")
def _apply_json_fallback(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *payload* that asks for JSON via the prompt instead of a format flag.

    ``response_format`` and ``return_format`` are removed. When ``messages`` is a
    list, the instruction "Respond only with a valid JSON object." is added to
    the system prompt:

    * if the first message is a system message, the instruction is merged into
      its content (appended to a string, or added as a text part to a list of
      content parts), so the caller's system prompt is preserved;
    * otherwise a new system message carrying the instruction is prepended.

    Neither *payload* nor its message dicts are mutated.
    """
    instruction = "Respond only with a valid JSON object."
    payload = dict(payload)
    payload.pop("response_format", None)
    payload.pop("return_format", None)
    messages = payload.get("messages")
    if not isinstance(messages, list):
        return payload

    first = messages[0] if messages else None
    if isinstance(first, dict) and first.get("role") == "system":
        # Merge rather than replace: replacing would silently drop the
        # caller's own system prompt.
        content = first.get("content")
        if isinstance(content, str) and content.strip():
            merged_content: Any = f"{content}\n\n{instruction}"
        elif isinstance(content, list):
            merged_content = [*content, {"type": "text", "text": instruction}]
        else:
            merged_content = instruction
        payload["messages"] = [{**first, "content": merged_content}, *messages[1:]]
    else:
        payload["messages"] = [{"role": "system", "content": instruction}, *messages]
    return payload
def _is_loading_model_response(resp: httpx.Response) -> bool:
    """Return True if *resp*'s JSON body carries a "loading model" error message.

    The message is read from ``{"error": {"message": ...}}`` or, when ``error``
    is not an object, from a top-level ``{"message": ...}``. Bodies that are not
    JSON objects never match.
    """
    try:
        data = resp.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError
        return False
    if not isinstance(data, dict):
        return False
    err = data.get("error")
    message = err.get("message") if isinstance(err, dict) else data.get("message")
    return "loading model" in str(message or "").lower()


async def _proxy_json_with_retry(
    base_url: str,
    path: str,
    method: str,
    headers: Dict[str, str],
    payload: Dict[str, Any],
    timeout_s: float,
    delay_s: float = 3.0,
) -> httpx.Response:
    """Proxy a JSON request via ``proxy_json``, retrying while llama.cpp loads a model.

    Two outcomes are retried every *delay_s* seconds until *timeout_s* seconds
    have elapsed: a 503 response whose error message contains "loading model",
    and any ``httpx.RequestError``. Any other response is returned immediately.
    Each individual attempt is given the full *timeout_s*.

    Returns:
        The first response that is not a "loading model" 503. If the deadline
        expires and the most recent outcome was such a 503, that response is
        returned so the caller can relay it instead of a generic error.

    Raises:
        httpx.RequestError: the most recent transport error, if the deadline
            expired after a failed attempt.
        RuntimeError: if no attempt was made (e.g. ``timeout_s <= 0``).
    """
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout_s
    attempt = 0
    last_exc: Exception | None = None
    last_resp: httpx.Response | None = None
    while time.monotonic() < deadline:
        attempt += 1
        try:
            resp = await proxy_json(base_url, path, method, headers, payload, timeout_s)
        except httpx.RequestError as exc:
            last_exc, last_resp = exc, None
            log.warning("Proxy request failed (attempt %s): %s", attempt, exc)
        else:
            if resp.status_code != 503 or not _is_loading_model_response(resp):
                return resp
            last_exc, last_resp = None, resp
            log.warning("llama.cpp still loading model, retrying (attempt %s)", attempt)
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        await asyncio.sleep(min(delay_s, remaining))
    if last_resp is not None:
        return last_resp
    if last_exc is not None:
        raise last_exc
    raise RuntimeError("proxy retry deadline exceeded")
async def _get_active_model_from_truenas(cfg: "TrueNASConfig | None") -> str:
    """Return the model id TrueNAS reports as active, or "" when it is unknown.

    Best-effort lookup:

    * returns "" immediately when *cfg* is None, i.e. TrueNAS integration is
      not configured, instead of issuing a call that cannot succeed;
    * any exception from ``get_active_model_id`` is logged as a warning and
      mapped to "".
    """
    if cfg is None:
        return ""
    try:
        return (await get_active_model_id(cfg)) or ""
    except Exception as exc:  # best-effort: callers treat "" as "unknown"
        log.warning("Failed to read active model from TrueNAS config: %s", exc)
        return ""
async def _wait_for_active_model(
    cfg: TrueNASConfig,
    model_id: str,
    timeout_s: float,
    poll_interval_s: float = 2.0,
) -> None:
    """Poll TrueNAS until it reports *model_id* as the active model.

    The active model is checked at least once, then every *poll_interval_s*
    seconds, with a final check at the deadline.

    Raises:
        RuntimeError: if *model_id* is not reported as active within
            *timeout_s* seconds.
    """
    # get_running_loop() is the supported way to reach the loop from inside a
    # coroutine; get_event_loop() is discouraged in that context.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_s
    while True:
        if await _get_active_model_from_truenas(cfg) == model_id:
            return
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        # Cap the sleep so the last poll lands on the deadline, not after it.
        await asyncio.sleep(min(poll_interval_s, remaining))
    raise RuntimeError(f"active model did not switch to {model_id}")
def _truenas_config_from(cfg: Any) -> "TrueNASConfig | None":
    """Build the TrueNAS connection config from app config, or None if not configured.

    TrueNAS integration counts as configured only when both ``truenas_ws_url``
    and ``truenas_api_key`` are set.
    """
    if not (cfg.truenas_ws_url and cfg.truenas_api_key):
        return None
    return TrueNASConfig(
        ws_url=cfg.truenas_ws_url,
        api_key=cfg.truenas_api_key,
        api_user=cfg.truenas_api_user,
        app_name=cfg.truenas_app_name,
        verify_ssl=cfg.truenas_verify_ssl,
    )


async def _warmup_loaded_model(cfg: Any, model_id: str, reason: str) -> None:
    """Run the configured warmup prompt against llama.cpp for *model_id*.

    *reason* is used only in the log line (e.g. "model switch", "restart").
    """
    warmup_prompt = resolve_warmup_prompt(None, cfg.warmup_prompt_path)
    log.info(
        "Running warmup prompt after %s: model=%s prompt_len=%s",
        reason,
        model_id,
        len(warmup_prompt),
    )
    await run_warmup_with_retry(cfg.base_url, model_id, warmup_prompt, timeout_s=cfg.switch_timeout_s)


async def _ensure_model_loaded(model_id: str, model_dir: str) -> str:
    """Ensure llama.cpp is serving the model requested as *model_id*; return its resolved id.

    *model_id* may be an alias; it is resolved with ``cfg.model_aliases``.

    * With TrueNAS configured: if TrueNAS already reports the resolved model as
      active, return immediately. Otherwise call ``switch_model``, poll until
      the model is active, and then run the warmup prompt.
    * Without TrueNAS: trigger the configured restart plan with the model path
      and llama.cpp arguments, and then run the warmup prompt.

    Raises:
        HTTPException: 404 if the model cannot be resolved; 500 if the TrueNAS
            switch, or waiting for it to take effect, fails.
    """
    cfg = load_config()
    model = resolve_model(model_dir, model_id, cfg.model_aliases)
    if not model:
        log.warning("Requested model not found: %s", model_id)
        raise HTTPException(status_code=404, detail="model not found")
    if model.model_id != model_id:
        log.info("Resolved model alias %s -> %s", model_id, model.model_id)

    truenas_cfg = _truenas_config_from(cfg)
    if truenas_cfg is not None:
        # Only TrueNAS can report the active model; skip the switch when it matches.
        active_id = await _get_active_model_from_truenas(truenas_cfg)
        if active_id and active_id == model.model_id:
            return model.model_id

    model_path = str(Path(cfg.model_container_dir) / model.model_id)

    if truenas_cfg is not None:
        log.info(
            "Switching model via API model=%s args=%s extra_args=%s",
            model.model_id,
            cfg.llamacpp_args,
            cfg.llamacpp_extra_args,
        )
        try:
            await switch_model(truenas_cfg, model_path, cfg.llamacpp_args, cfg.llamacpp_extra_args)
            await _wait_for_active_model(truenas_cfg, model.model_id, cfg.switch_timeout_s)
        except Exception as exc:
            log.exception("TrueNAS model switch failed")
            raise HTTPException(status_code=500, detail=f"model switch failed: {exc}") from exc
        await _warmup_loaded_model(cfg, model.model_id, "model switch")
        return model.model_id

    # NOTE(review): this path has no way to learn which model is currently
    # loaded, so every call with a model id triggers a restart. Confirm that
    # trigger_restart is cheap or deduplicates, or add an active-model check.
    plan = RestartPlan(
        method=cfg.restart_method,
        command=cfg.restart_command,
        url=cfg.restart_url,
        allowed_container=cfg.allowed_container,
    )
    log.info("Triggering restart for model=%s method=%s", model.model_id, cfg.restart_method)
    payload = {
        "model_id": model.model_id,
        "model_path": model_path,
        "gpu_count": cfg.gpu_count_runtime or cfg.agents.gpu_count,
        "llamacpp_args": cfg.llamacpp_args,
        "llamacpp_extra_args": cfg.llamacpp_extra_args,
    }
    await trigger_restart(plan, payload=payload)
    await _warmup_loaded_model(cfg, model.model_id, "restart")
    return model.model_id
# Response headers that must not be copied verbatim onto a rebuilt response:
# hop-by-hop headers (RFC 9110 section 7.6.1), plus body-framing headers that
# no longer describe the body once httpx has decoded it into ``.content``.
_UNFORWARDABLE_RESPONSE_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "content-length",
        "content-encoding",
    }
)


def _forwardable_headers(headers: Any) -> Dict[str, str]:
    """Return *headers* as a plain dict without hop-by-hop or body-framing entries."""
    return {
        name: value
        for name, value in headers.items()
        if name.lower() not in _UNFORWARDABLE_RESPONSE_HEADERS
    }


def _relay_upstream_response(resp: httpx.Response) -> Response:
    """Relay an upstream response, re-serialising the body as JSON when it parses as JSON.

    Falls back to the raw bytes (with the upstream content type) when the body
    is not valid JSON or cannot be re-encoded as strict JSON.
    """
    try:
        return JSONResponse(status_code=resp.status_code, content=resp.json())
    except ValueError:
        return Response(
            status_code=resp.status_code,
            content=resp.content,
            media_type=resp.headers.get("content-type"),
        )


async def _read_json_body(request: Request) -> Dict[str, Any]:
    """Parse *request*'s body as a JSON object.

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON object.
    """
    try:
        payload = await request.json()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="request body must be valid JSON") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")
    return payload


def create_api_app() -> FastAPI:
    """Build the FastAPI app that exposes an OpenAI-compatible API in front of llama.cpp.

    Routes: ``/health``, ``/v1/models``, ``/v1/models/{model_id}``,
    ``/v1/chat/completions``, ``/v1/responses``, ``/v1/embeddings``, and a raw
    passthrough under ``/proxy/llamacpp/{path}``. Configuration is loaded once
    here, while ``_ensure_model_loaded`` reloads it on each call.
    """
    cfg = load_config()
    app = FastAPI(title="llama.cpp OpenAI Wrapper", version="0.1.0")
    router = APIRouter()

    @app.middleware("http")
    async def log_requests(request: Request, call_next):
        log.info("Request %s %s", request.method, request.url.path)
        return await call_next(request)

    @app.exception_handler(Exception)
    async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
        log.exception("Unhandled error")
        return JSONResponse(status_code=500, content={"detail": str(exc)})

    async def complete_chat(
        headers: Dict[str, str],
        payload: Dict[str, Any],
        retry_message: str,
    ) -> httpx.Response:
        """POST a non-streaming chat completion to llama.cpp.

        On a 5xx reply to a JSON-mode request, retry once with the
        prompt-based JSON fallback payload.
        """
        resp = await _proxy_json_with_retry(
            cfg.base_url, "/v1/chat/completions", "POST", headers, payload, cfg.proxy_timeout_s
        )
        if resp.status_code >= 500 and _requires_json_mode(payload):
            log.info(retry_message)
            resp = await _proxy_json_with_retry(
                cfg.base_url,
                "/v1/chat/completions",
                "POST",
                headers,
                _apply_json_fallback(payload),
                cfg.proxy_timeout_s,
            )
        return resp

    @router.get("/health")
    async def health() -> Dict[str, Any]:
        return {
            "status": "ok",
            "base_url": cfg.base_url,
            "model_dir": cfg.model_dir,
            "agents": {
                "image": cfg.agents.image,
                "container_name": cfg.agents.container_name,
                "network": cfg.agents.network,
                "gpu_count": cfg.agents.gpu_count,
            },
            "gpu_count_runtime": cfg.gpu_count_runtime,
        }

    @router.get("/v1/models")
    async def list_models() -> Dict[str, Any]:
        log.info("Listing models")
        return _model_list_payload(cfg.model_dir)

    @router.get("/v1/models/{model_id}")
    async def get_model(model_id: str) -> Dict[str, Any]:
        log.info("Get model %s", model_id)
        model = resolve_model(cfg.model_dir, model_id, cfg.model_aliases) or find_model(cfg.model_dir, model_id)
        if not model:
            raise HTTPException(status_code=404, detail="model not found")
        return {
            "id": model.model_id,
            "object": "model",
            "created": model.created,
            "owned_by": "llama.cpp",
        }

    @router.post("/v1/chat/completions")
    async def chat_completions(request: Request) -> Response:
        payload = normalize_chat_payload(await _read_json_body(request))
        model_id = payload.get("model")
        stream = bool(payload.get("stream"))
        log.info("Chat completions model=%s stream=%s", model_id, stream)
        if model_id:
            payload["model"] = await _ensure_model_loaded(model_id, cfg.model_dir)
        headers = dict(request.headers)
        if stream:
            # Streamed JSON-mode requests always use the prompt-based fallback.
            if _requires_json_mode(payload):
                payload = _apply_json_fallback(payload)
            streamer = proxy_stream(cfg.base_url, "/v1/chat/completions", "POST", headers, payload, cfg.proxy_timeout_s)
            return StreamingResponse(streamer, media_type="text/event-stream")
        resp = await complete_chat(headers, payload, "Retrying chat completion with JSON fallback prompt")
        return _relay_upstream_response(resp)

    @router.post("/v1/responses")
    async def responses(request: Request) -> Response:
        chat_payload, model_id = responses_to_chat_payload(await _read_json_body(request))
        stream = bool(chat_payload.get("stream"))
        log.info("Responses model=%s stream=%s", model_id, stream)
        if model_id:
            chat_payload["model"] = await _ensure_model_loaded(model_id, cfg.model_dir)
        headers = dict(request.headers)
        if stream:
            if _requires_json_mode(chat_payload):
                chat_payload = _apply_json_fallback(chat_payload)
            streamer = stream_chat_to_responses(
                cfg.base_url,
                headers,
                chat_payload,
                cfg.proxy_timeout_s,
            )
            return StreamingResponse(streamer, media_type="text/event-stream")
        resp = await complete_chat(headers, chat_payload, "Retrying responses with JSON fallback prompt")
        if not 200 <= resp.status_code < 300:
            # Relay the upstream status and body instead of collapsing every
            # failure into a generic 500 (the old raise_for_status behaviour).
            return _relay_upstream_response(resp)
        return JSONResponse(status_code=200, content=chat_to_responses(resp.json(), model_id))

    @router.post("/v1/embeddings")
    async def embeddings(request: Request) -> Response:
        payload = await _read_json_body(request)
        log.info("Embeddings")
        resp = await _proxy_json_with_retry(
            cfg.base_url, "/v1/embeddings", "POST", dict(request.headers), payload, cfg.proxy_timeout_s
        )
        return _relay_upstream_response(resp)

    @router.api_route("/proxy/llamacpp/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
    async def passthrough(path: str, request: Request) -> Response:
        body = await request.body()
        # Keep the client's query string; the path parameter alone drops it.
        target = f"/{path}?{request.url.query}" if request.url.query else f"/{path}"
        resp = await proxy_raw(cfg.base_url, target, request.method, dict(request.headers), body, cfg.proxy_timeout_s)
        # resp.content is the decoded body, so upstream framing/encoding headers
        # (content-length, content-encoding, transfer-encoding) would be wrong here.
        return Response(status_code=resp.status_code, content=resp.content, headers=_forwardable_headers(resp.headers))

    app.include_router(router)
    return app