Initial commit
llamaCpp.Wrapper.app/api_app.py (new file, 309 lines)
@@ -0,0 +1,309 @@
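"""OpenAI-compatible FastAPI wrapper around a llama.cpp server.

Incoming OpenAI-style requests are proxied to llama.cpp; when a request names a
model that is not currently loaded, the wrapper switches models via the TrueNAS
middleware or by restarting the backing container, then runs a warmup prompt.

Assuming the package is importable as ``app``, the application can likely be
served with an ASGI server using the factory pattern, e.g.
``uvicorn app.api_app:create_api_app --factory``.
"""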
import asyncio
import logging
import time
from pathlib import Path
from typing import Any, Dict

from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
import httpx

from app.config import load_config
from app.llamacpp_client import proxy_json, proxy_raw, proxy_stream
from app.logging_utils import configure_logging
from app.model_registry import find_model, resolve_model, scan_models
from app.openai_translate import responses_to_chat_payload, chat_to_responses, normalize_chat_payload
from app.restart import RestartPlan, trigger_restart
from app.stream_transform import stream_chat_to_responses
from app.truenas_middleware import TrueNASConfig, get_active_model_id, switch_model
from app.warmup import resolve_warmup_prompt, run_warmup_with_retry


configure_logging()
log = logging.getLogger("api_app")


def _model_list_payload(model_dir: str) -> Dict[str, Any]:
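    """Build an OpenAI-style model list payload from the models discovered in model_dir."""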
    data = []
    for model in scan_models(model_dir):
        data.append({
            "id": model.model_id,
            "object": "model",
            "created": model.created,
            "owned_by": "llama.cpp",
        })
    return {"object": "list", "data": data}


def _requires_json_mode(payload: Dict[str, Any]) -> bool:
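    """Return True if the request asks for JSON output via response_format or return_format."""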
    response_format = payload.get("response_format")
    if isinstance(response_format, dict) and response_format.get("type") == "json_object":
        return True
    if payload.get("return_format") == "json":
        return True
    return False


def _apply_json_fallback(payload: Dict[str, Any]) -> Dict[str, Any]:
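    """Strip JSON-mode flags and instead inject a leading system message (replacing any existing one) that demands a JSON-only reply."""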
    payload = dict(payload)
    payload.pop("response_format", None)
    payload.pop("return_format", None)
    messages = payload.get("messages")
    if isinstance(messages, list):
        system_msg = {"role": "system", "content": "Respond only with a valid JSON object."}
        if not messages or messages[0].get("role") != "system":
            payload["messages"] = [system_msg, *messages]
        else:
            payload["messages"] = [system_msg, *messages[1:]]
    return payload


async def _proxy_json_with_retry(
    base_url: str,
    path: str,
    method: str,
    headers: Dict[str, str],
    payload: Dict[str, Any],
    timeout_s: float,
    delay_s: float = 3.0,
) -> httpx.Response:
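    """Proxy a JSON request, retrying while llama.cpp is still loading a model or unreachable, until timeout_s elapses."""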
    deadline = time.time() + timeout_s
    attempt = 0
    last_exc: Exception | None = None
    while time.time() < deadline:
        attempt += 1
        try:
            resp = await proxy_json(base_url, path, method, headers, payload, timeout_s)
            if resp.status_code == 503:
                try:
                    data = resp.json()
                except Exception:
                    data = {}
                message = ""
                if isinstance(data, dict):
                    err = data.get("error")
                    if isinstance(err, dict):
                        message = str(err.get("message") or "")
                    else:
                        message = str(data.get("message") or "")
                if "loading model" in message.lower():
                    log.warning("llama.cpp still loading model, retrying (attempt %s)", attempt)
                    await asyncio.sleep(delay_s)
                    continue
            return resp
        except httpx.RequestError as exc:
            last_exc = exc
            log.warning("Proxy request failed (attempt %s): %s", attempt, exc)
            await asyncio.sleep(delay_s)
    if last_exc:
        raise last_exc
    raise RuntimeError("proxy retry deadline exceeded")


async def _get_active_model_from_truenas(cfg: TrueNASConfig) -> str:
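    """Return the active model id reported by TrueNAS, or an empty string if the lookup fails."""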
    try:
        return await get_active_model_id(cfg)
    except Exception as exc:
        log.warning("Failed to read active model from TrueNAS config: %s", exc)
        return ""


async def _wait_for_active_model(cfg: TrueNASConfig, model_id: str, timeout_s: float) -> None:
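    """Poll TrueNAS every two seconds until the active model equals model_id; raise RuntimeError on timeout."""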
    deadline = asyncio.get_event_loop().time() + timeout_s
    while asyncio.get_event_loop().time() < deadline:
        active = await _get_active_model_from_truenas(cfg)
        if active == model_id:
            return
        await asyncio.sleep(2)
    raise RuntimeError(f"active model did not switch to {model_id}")


async def _ensure_model_loaded(model_id: str, model_dir: str) -> str:
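    """Resolve model_id (honouring aliases) and make sure llama.cpp is serving it.

    Returns immediately if the model is already active; otherwise switches via the
    TrueNAS middleware when configured, or falls back to a container restart, and
    runs a warmup prompt before returning the resolved model id.
    """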
    cfg = load_config()
    model = resolve_model(model_dir, model_id, cfg.model_aliases)
    if not model:
        log.warning("Requested model not found: %s", model_id)
        raise HTTPException(status_code=404, detail="model not found")
    if model.model_id != model_id:
        log.info("Resolved model alias %s -> %s", model_id, model.model_id)

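    # If the TrueNAS middleware is configured, short-circuit when the requested model is already active.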
    truenas_cfg = None
    if cfg.truenas_ws_url and cfg.truenas_api_key:
        truenas_cfg = TrueNASConfig(
            ws_url=cfg.truenas_ws_url,
            api_key=cfg.truenas_api_key,
            api_user=cfg.truenas_api_user,
            app_name=cfg.truenas_app_name,
            verify_ssl=cfg.truenas_verify_ssl,
        )
        active_id = await _get_active_model_from_truenas(truenas_cfg)
        if active_id and active_id == model.model_id:
            return model.model_id

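    # Switch the model through the TrueNAS middleware, wait for it to become active, then warm it up.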
    if truenas_cfg:
        log.info("Switching model via API model=%s args=%s extra_args=%s", model.model_id, cfg.llamacpp_args, cfg.llamacpp_extra_args)
        try:
            model_path = str(Path(cfg.model_container_dir) / model.model_id)
            await switch_model(
                truenas_cfg,
                model_path,
                cfg.llamacpp_args,
                cfg.llamacpp_extra_args,
            )
            await _wait_for_active_model(truenas_cfg, model.model_id, cfg.switch_timeout_s)
        except Exception as exc:
            log.exception("TrueNAS model switch failed")
            raise HTTPException(status_code=500, detail=f"model switch failed: {exc}")
        warmup_prompt = resolve_warmup_prompt(None, cfg.warmup_prompt_path)
        log.info("Running warmup prompt after model switch: model=%s prompt_len=%s", model.model_id, len(warmup_prompt))
        await run_warmup_with_retry(cfg.base_url, model.model_id, warmup_prompt, timeout_s=cfg.switch_timeout_s)
        return model.model_id

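    # No TrueNAS middleware configured: fall back to restarting the llama.cpp container with the new model.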
    plan = RestartPlan(
        method=cfg.restart_method,
        command=cfg.restart_command,
        url=cfg.restart_url,
        allowed_container=cfg.allowed_container,
    )
    log.info("Triggering restart for model=%s method=%s", model.model_id, cfg.restart_method)
    payload = {
        "model_id": model.model_id,
        "model_path": str(Path(cfg.model_container_dir) / model.model_id),
        "gpu_count": cfg.gpu_count_runtime or cfg.agents.gpu_count,
        "llamacpp_args": cfg.llamacpp_args,
        "llamacpp_extra_args": cfg.llamacpp_extra_args,
    }
    await trigger_restart(plan, payload=payload)
    warmup_prompt = resolve_warmup_prompt(None, cfg.warmup_prompt_path)
    log.info("Running warmup prompt after restart: model=%s prompt_len=%s", model.model_id, len(warmup_prompt))
    await run_warmup_with_retry(cfg.base_url, model.model_id, warmup_prompt, timeout_s=cfg.switch_timeout_s)
    return model.model_id


def create_api_app() -> FastAPI:
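    """Build the FastAPI application wiring the OpenAI-compatible routes to the llama.cpp backend."""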
    cfg = load_config()
    app = FastAPI(title="llama.cpp OpenAI Wrapper", version="0.1.0")
    router = APIRouter()

    @app.middleware("http")
    async def log_requests(request: Request, call_next):
        log.info("Request %s %s", request.method, request.url.path)
        return await call_next(request)

    @app.exception_handler(Exception)
    async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
        log.exception("Unhandled error")
        return JSONResponse(status_code=500, content={"detail": str(exc)})

    @router.get("/health")
    async def health() -> Dict[str, Any]:
        return {
            "status": "ok",
            "base_url": cfg.base_url,
            "model_dir": cfg.model_dir,
            "agents": {
                "image": cfg.agents.image,
                "container_name": cfg.agents.container_name,
                "network": cfg.agents.network,
                "gpu_count": cfg.agents.gpu_count,
            },
            "gpu_count_runtime": cfg.gpu_count_runtime,
        }

    @router.get("/v1/models")
    async def list_models() -> Dict[str, Any]:
        log.info("Listing models")
        return _model_list_payload(cfg.model_dir)

    @router.get("/v1/models/{model_id}")
    async def get_model(model_id: str) -> Dict[str, Any]:
        log.info("Get model %s", model_id)
        model = resolve_model(cfg.model_dir, model_id, cfg.model_aliases) or find_model(cfg.model_dir, model_id)
        if not model:
            raise HTTPException(status_code=404, detail="model not found")
        return {
            "id": model.model_id,
            "object": "model",
            "created": model.created,
            "owned_by": "llama.cpp",
        }

    @router.post("/v1/chat/completions")
    async def chat_completions(request: Request) -> Response:
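        """Proxy an OpenAI chat completion to llama.cpp, ensuring the requested model is loaded first."""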
        payload = await request.json()
        payload = normalize_chat_payload(payload)
        model_id = payload.get("model")
        log.info("Chat completions model=%s stream=%s", model_id, bool(payload.get("stream")))
        if model_id:
            resolved = await _ensure_model_loaded(model_id, cfg.model_dir)
            payload["model"] = resolved
        stream = bool(payload.get("stream"))
        if stream and _requires_json_mode(payload):
            payload = _apply_json_fallback(payload)
        if stream:
            streamer = proxy_stream(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), payload, cfg.proxy_timeout_s)
            return StreamingResponse(streamer, media_type="text/event-stream")
        resp = await _proxy_json_with_retry(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), payload, cfg.proxy_timeout_s)
        if resp.status_code >= 500 and _requires_json_mode(payload):
            log.info("Retrying chat completion with JSON fallback prompt")
            fallback_payload = _apply_json_fallback(payload)
            resp = await _proxy_json_with_retry(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), fallback_payload, cfg.proxy_timeout_s)
        try:
            return JSONResponse(status_code=resp.status_code, content=resp.json())
        except Exception:
            return Response(
                status_code=resp.status_code,
                content=resp.content,
                media_type=resp.headers.get("content-type"),
            )

    @router.post("/v1/responses")
    async def responses(request: Request) -> Response:
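        """Serve the OpenAI Responses API by translating to and from chat completions against llama.cpp."""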
        payload = await request.json()
        chat_payload, model_id = responses_to_chat_payload(payload)
        log.info("Responses model=%s stream=%s", model_id, bool(chat_payload.get("stream")))
        if model_id:
            resolved = await _ensure_model_loaded(model_id, cfg.model_dir)
            chat_payload["model"] = resolved
        stream = bool(chat_payload.get("stream"))
        if stream and _requires_json_mode(chat_payload):
            chat_payload = _apply_json_fallback(chat_payload)
        if stream:
            streamer = stream_chat_to_responses(
                cfg.base_url,
                dict(request.headers),
                chat_payload,
                cfg.proxy_timeout_s,
            )
            return StreamingResponse(streamer, media_type="text/event-stream")
        resp = await _proxy_json_with_retry(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), chat_payload, cfg.proxy_timeout_s)
        if resp.status_code >= 500 and _requires_json_mode(chat_payload):
            log.info("Retrying responses with JSON fallback prompt")
            fallback_payload = _apply_json_fallback(chat_payload)
            resp = await _proxy_json_with_retry(cfg.base_url, "/v1/chat/completions", "POST", dict(request.headers), fallback_payload, cfg.proxy_timeout_s)
        resp.raise_for_status()
        return JSONResponse(status_code=200, content=chat_to_responses(resp.json(), model_id))

    @router.post("/v1/embeddings")
    async def embeddings(request: Request) -> Response:
        payload = await request.json()
        log.info("Embeddings")
        resp = await _proxy_json_with_retry(cfg.base_url, "/v1/embeddings", "POST", dict(request.headers), payload, cfg.proxy_timeout_s)
        try:
            return JSONResponse(status_code=resp.status_code, content=resp.json())
        except Exception:
            return Response(
                status_code=resp.status_code,
                content=resp.content,
                media_type=resp.headers.get("content-type"),
            )

    @router.api_route("/proxy/llamacpp/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
    async def passthrough(path: str, request: Request) -> Response:
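        """Forward any request under /proxy/llamacpp/ verbatim to the llama.cpp server."""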
        body = await request.body()
        resp = await proxy_raw(cfg.base_url, f"/{path}", request.method, dict(request.headers), body, cfg.proxy_timeout_s)
        return Response(status_code=resp.status_code, content=resp.content, headers=dict(resp.headers))

    app.include_router(router)
    return app