import asyncio
import logging
import time
from pathlib import Path
from typing import Any, Dict

from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
import httpx

from app.config import load_config
from app.llamacpp_client import proxy_json, proxy_raw, proxy_stream
from app.logging_utils import configure_logging
from app.model_registry import find_model, resolve_model, scan_models
from app.openai_translate import responses_to_chat_payload, chat_to_responses, normalize_chat_payload
from app.restart import RestartPlan, trigger_restart
from app.stream_transform import stream_chat_to_responses
from app.truenas_middleware import TrueNASConfig, get_active_model_id, switch_model
from app.warmup import resolve_warmup_prompt, run_warmup_with_retry

configure_logging()
log = logging.getLogger("api_app")


def _model_list_payload(model_dir: str) -> Dict[str, Any]:
    data = []
    for model in scan_models(model_dir):
        data.append({
            "id": model.model_id,
            "object": "model",
            "created": model.created,
            "owned_by": "llama.cpp",
        })
    return {"object": "list", "data": data}


def _requires_json_mode(payload: Dict[str, Any]) -> bool:
    response_format = payload.get("response_format")
    if isinstance(response_format, dict) and response_format.get("type") == "json_object":
        return True
    if payload.get("return_format") == "json":
        return True
    return False


def _apply_json_fallback(payload: Dict[str, Any]) -> Dict[str, Any]:
    payload = dict(payload)
    payload.pop("response_format", None)
    payload.pop("return_format", None)
    messages = payload.get("messages")
    if isinstance(messages, list):
        system_msg = {"role": "system", "content": "Respond only with a valid JSON object."}
        if not messages or messages[0].get("role") != "system":
            payload["messages"] = [system_msg, *messages]
        else:
            payload["messages"] = [system_msg, *messages[1:]]
    return payload


async def _proxy_json_with_retry(
    base_url: str,
    path: str,
    method: str,
    headers: Dict[str, str],
    payload: Dict[str, Any],
    timeout_s: float,
    delay_s: float = 3.0,
) -> httpx.Response:
    deadline = time.time() + timeout_s
    attempt = 0
    last_exc: Exception | None = None
    while time.time() < deadline:
        attempt += 1
        try:
            resp = await proxy_json(base_url, path, method, headers, payload, timeout_s)
            if resp.status_code == 503:
                try:
                    data = resp.json()
                except Exception:
                    data = {}
                message = ""
                if isinstance(data, dict):
                    err = data.get("error")
                    if isinstance(err, dict):
                        message = str(err.get("message") or "")
                    else:
                        message = str(data.get("message") or "")
                if "loading model" in message.lower():
                    log.warning("llama.cpp still loading model, retrying (attempt %s)", attempt)
                    await asyncio.sleep(delay_s)
                    continue
            return resp
        except httpx.RequestError as exc:
            last_exc = exc
            log.warning("Proxy request failed (attempt %s): %s", attempt, exc)
            await asyncio.sleep(delay_s)
    if last_exc:
        raise last_exc
    raise RuntimeError("proxy retry deadline exceeded")


async def _get_active_model_from_truenas(cfg: TrueNASConfig) -> str:
    try:
        return await get_active_model_id(cfg)
    except Exception as exc:
        log.warning("Failed to read active model from TrueNAS config: %s", exc)
        return ""


async def _wait_for_active_model(cfg: TrueNASConfig, model_id: str, timeout_s: float) -> None:
    deadline = asyncio.get_event_loop().time() + timeout_s
    while asyncio.get_event_loop().time() < deadline:
        active = await _get_active_model_from_truenas(cfg)
        if active == model_id:
            return
        await asyncio.sleep(2)
    raise RuntimeError(f"active model did not switch to {model_id}")
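
# Model activation flow (summary of the logic below, added as a descriptive comment):
# _ensure_model_loaded resolves aliases, short-circuits when TrueNAS already reports
# the requested model as active, otherwise switches the model via the TrueNAS API
# (or falls back to a restart plan), waits for the switch to take effect, and runs a
# warmup prompt before returning the resolved model id.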

async def _ensure_model_loaded(model_id: str, model_dir: str) -> str:
    cfg = load_config()
    model = resolve_model(model_dir, model_id, cfg.model_aliases)
    if not model:
        log.warning("Requested model not found: %s", model_id)
        raise HTTPException(status_code=404, detail="model not found")
    if model.model_id != model_id:
        log.info("Resolved model alias %s -> %s", model_id, model.model_id)

    truenas_cfg = None
    if cfg.truenas_ws_url and cfg.truenas_api_key:
        truenas_cfg = TrueNASConfig(
            ws_url=cfg.truenas_ws_url,
            api_key=cfg.truenas_api_key,
            api_user=cfg.truenas_api_user,
            app_name=cfg.truenas_app_name,
            verify_ssl=cfg.truenas_verify_ssl,
        )
        active_id = await _get_active_model_from_truenas(truenas_cfg)
        if active_id and active_id == model.model_id:
            return model.model_id

    if truenas_cfg:
        log.info(
            "Switching model via API model=%s args=%s extra_args=%s",
            model.model_id, cfg.llamacpp_args, cfg.llamacpp_extra_args,
        )
        try:
            model_path = str(Path(cfg.model_container_dir) / model.model_id)
            await switch_model(
                truenas_cfg,
                model_path,
                cfg.llamacpp_args,
                cfg.llamacpp_extra_args,
            )
            await _wait_for_active_model(truenas_cfg, model.model_id, cfg.switch_timeout_s)
        except Exception as exc:
            log.exception("TrueNAS model switch failed")
            raise HTTPException(status_code=500, detail=f"model switch failed: {exc}")
        warmup_prompt = resolve_warmup_prompt(None, cfg.warmup_prompt_path)
        log.info(
            "Running warmup prompt after model switch: model=%s prompt_len=%s",
            model.model_id, len(warmup_prompt),
        )
        await run_warmup_with_retry(cfg.base_url, model.model_id, warmup_prompt, timeout_s=cfg.switch_timeout_s)
        return model.model_id

    plan = RestartPlan(
        method=cfg.restart_method,
        command=cfg.restart_command,
        url=cfg.restart_url,
        allowed_container=cfg.allowed_container,
    )
    log.info("Triggering restart for model=%s method=%s", model.model_id, cfg.restart_method)
    payload = {
        "model_id": model.model_id,
        "model_path": str(Path(cfg.model_container_dir) / model.model_id),
        "gpu_count": cfg.gpu_count_runtime or cfg.agents.gpu_count,
        "llamacpp_args": cfg.llamacpp_args,
        "llamacpp_extra_args": cfg.llamacpp_extra_args,
    }
    await trigger_restart(plan, payload=payload)
    warmup_prompt = resolve_warmup_prompt(None, cfg.warmup_prompt_path)
    log.info(
        "Running warmup prompt after restart: model=%s prompt_len=%s",
        model.model_id, len(warmup_prompt),
    )
    await run_warmup_with_retry(cfg.base_url, model.model_id, warmup_prompt, timeout_s=cfg.switch_timeout_s)
    return model.model_id


def create_api_app() -> FastAPI:
    cfg = load_config()
    app = FastAPI(title="llama.cpp OpenAI Wrapper", version="0.1.0")
    router = APIRouter()

    @app.middleware("http")
    async def log_requests(request: Request, call_next):
        log.info("Request %s %s", request.method, request.url.path)
        return await call_next(request)

    @app.exception_handler(Exception)
    async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
        log.exception("Unhandled error")
        return JSONResponse(status_code=500, content={"detail": str(exc)})

    @router.get("/health")
    async def health() -> Dict[str, Any]:
        return {
            "status": "ok",
            "base_url": cfg.base_url,
            "model_dir": cfg.model_dir,
            "agents": {
                "image": cfg.agents.image,
                "container_name": cfg.agents.container_name,
                "network": cfg.agents.network,
                "gpu_count": cfg.agents.gpu_count,
            },
            "gpu_count_runtime": cfg.gpu_count_runtime,
        }

    @router.get("/v1/models")
    async def list_models() -> Dict[str, Any]:
        log.info("Listing models")
        return _model_list_payload(cfg.model_dir)

    @router.get("/v1/models/{model_id}")
    async def get_model(model_id: str) -> Dict[str, Any]:
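        """Return OpenAI-style metadata for a single model, resolving aliases before a direct lookup."""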
        log.info("Get model %s", model_id)
        model = resolve_model(cfg.model_dir, model_id, cfg.model_aliases) or find_model(cfg.model_dir, model_id)
        if not model:
            raise HTTPException(status_code=404, detail="model not found")
        return {
            "id": model.model_id,
            "object": "model",
            "created": model.created,
            "owned_by": "llama.cpp",
        }

    @router.post("/v1/chat/completions")
    async def chat_completions(request: Request) -> Response:
        payload = await request.json()
        payload = normalize_chat_payload(payload)
        model_id = payload.get("model")
        log.info("Chat completions model=%s stream=%s", model_id, bool(payload.get("stream")))
        if model_id:
            resolved = await _ensure_model_loaded(model_id, cfg.model_dir)
            payload["model"] = resolved
        stream = bool(payload.get("stream"))
        if stream and _requires_json_mode(payload):
            payload = _apply_json_fallback(payload)
        if stream:
            streamer = proxy_stream(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), payload, cfg.proxy_timeout_s)
            return StreamingResponse(streamer, media_type="text/event-stream")
        resp = await _proxy_json_with_retry(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), payload, cfg.proxy_timeout_s)
        if resp.status_code >= 500 and _requires_json_mode(payload):
            log.info("Retrying chat completion with JSON fallback prompt")
            fallback_payload = _apply_json_fallback(payload)
            resp = await _proxy_json_with_retry(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), fallback_payload, cfg.proxy_timeout_s)
        try:
            return JSONResponse(status_code=resp.status_code, content=resp.json())
        except Exception:
            return Response(
                status_code=resp.status_code,
                content=resp.content,
                media_type=resp.headers.get("content-type"),
            )

    @router.post("/v1/responses")
    async def responses(request: Request) -> Response:
        payload = await request.json()
        chat_payload, model_id = responses_to_chat_payload(payload)
        log.info("Responses model=%s stream=%s", model_id, bool(chat_payload.get("stream")))
        if model_id:
            resolved = await _ensure_model_loaded(model_id, cfg.model_dir)
            chat_payload["model"] = resolved
        stream = bool(chat_payload.get("stream"))
        if stream and _requires_json_mode(chat_payload):
            chat_payload = _apply_json_fallback(chat_payload)
        if stream:
            streamer = stream_chat_to_responses(
                cfg.base_url,
                dict(request.headers),
                chat_payload,
                cfg.proxy_timeout_s,
            )
            return StreamingResponse(streamer, media_type="text/event-stream")
        resp = await _proxy_json_with_retry(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), chat_payload, cfg.proxy_timeout_s)
        if resp.status_code >= 500 and _requires_json_mode(chat_payload):
            log.info("Retrying responses with JSON fallback prompt")
            fallback_payload = _apply_json_fallback(chat_payload)
            resp = await _proxy_json_with_retry(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), fallback_payload, cfg.proxy_timeout_s)
        resp.raise_for_status()
        return JSONResponse(status_code=200, content=chat_to_responses(resp.json(), model_id))

    @router.post("/v1/embeddings")
    async def embeddings(request: Request) -> Response:
        payload = await request.json()
        log.info("Embeddings")
        resp = await _proxy_json_with_retry(cfg.base_url, "/v1/embeddings", "POST", dict(request.headers), payload, cfg.proxy_timeout_s)
        try:
            return JSONResponse(status_code=resp.status_code, content=resp.json())
        except Exception:
            return Response(
                status_code=resp.status_code,
                content=resp.content,
                media_type=resp.headers.get("content-type"),
            )
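
    # Raw passthrough to the upstream llama.cpp server. Illustrative call (the
    # wrapper's listen port is not defined in this module, so localhost:8000 is
    # an assumption):
    #
    #   curl http://localhost:8000/proxy/llamacpp/health
    #
    # forwards GET /health to cfg.base_url with the request body and headers unchanged.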

    @router.api_route("/proxy/llamacpp/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
    async def passthrough(path: str, request: Request) -> Response:
        body = await request.body()
        resp = await proxy_raw(cfg.base_url, f"/{path}", request.method, dict(request.headers), body, cfg.proxy_timeout_s)
        return Response(status_code=resp.status_code, content=resp.content, headers=dict(resp.headers))

    app.include_router(router)
    return app
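
# Local-run convenience (an assumption, not part of the original module; host and
# port are arbitrary): build the app via the factory and serve it with uvicorn when
# this file is executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(create_api_app(), host="0.0.0.0", port=8000)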