import asyncio
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, StreamingResponse

from app.config import load_config
from app.docker_logs import docker_container_logs
from app.download_manager import DownloadManager
from app.logging_utils import configure_logging
from app.model_registry import scan_models
from app.truenas_middleware import (
    TrueNASConfig,
    get_active_model_id,
    get_app_command,
    get_app_logs,
    switch_model,
    update_command_flags,
)
from app.warmup import resolve_warmup_prompt, run_warmup_with_retry

configure_logging()
log = logging.getLogger("ui_app")


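# In-process pub/sub for server-sent events: each connected UI client gets its own
# asyncio.Queue, and publish() fans a payload out to every registered queue.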
class EventBroadcaster:
    def __init__(self) -> None:
        self._queues: set[asyncio.Queue] = set()

    def connect(self) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue()
        self._queues.add(queue)
        return queue

    def disconnect(self, queue: asyncio.Queue) -> None:
        self._queues.discard(queue)

    async def publish(self, payload: dict) -> None:
        for queue in list(self._queues):
            queue.put_nowait(payload)


def _static_path() -> Path:
    return Path(__file__).parent / "ui_static"


async def _fetch_active_model(truenas_cfg: Optional[TrueNASConfig]) -> Optional[str]:
    if not truenas_cfg:
        return None
    try:
        return await get_active_model_id(truenas_cfg)
    except Exception as exc:
        log.warning("Failed to read active model from TrueNAS config: %s", exc)
        return None


def _model_list(model_dir: str, active_model: Optional[str]) -> Dict[str, Any]:
    data = []
    for model in scan_models(model_dir):
        data.append({
            "id": model.model_id,
            "size": model.size,
            "active": model.model_id == active_model,
        })
    return {"models": data, "active_model": active_model}


def create_ui_app() -> FastAPI:
    cfg = load_config()
    app = FastAPI(title="llama.cpp Model Manager", version="0.1.0")
    broadcaster = EventBroadcaster()
    manager = DownloadManager(cfg, broadcaster=broadcaster)

    truenas_cfg = None
    if cfg.truenas_ws_url and cfg.truenas_api_key:
        truenas_cfg = TrueNASConfig(
            ws_url=cfg.truenas_ws_url,
            api_key=cfg.truenas_api_key,
            api_user=cfg.truenas_api_user,
            app_name=cfg.truenas_app_name,
            verify_ssl=cfg.truenas_verify_ssl,
        )

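    # Background poller: every 3 seconds, read the active model from the TrueNAS app
    # config and broadcast an event to connected UI clients whenever it changes.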
    async def monitor_active_model() -> None:
        last_model = None
        while True:
            current = await _fetch_active_model(truenas_cfg)
            if current and current != last_model:
                last_model = current
                await broadcaster.publish({"type": "active_model", "model_id": current})
            await asyncio.sleep(3)

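    # Log source fallback: prefer the TrueNAS app logs, and fall back to reading the
    # llama.cpp container logs directly via docker when TrueNAS returns nothing.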
    async def _fetch_logs() -> str:
        logs = ""
        if truenas_cfg:
            try:
                logs = await asyncio.wait_for(get_app_logs(truenas_cfg, tail_lines=200), timeout=5)
            except asyncio.TimeoutError:
                logs = ""
        if not logs and cfg.llamacpp_container_name:
            try:
                logs = await asyncio.wait_for(
                    docker_container_logs(cfg.llamacpp_container_name, tail_lines=200),
                    timeout=10,
                )
            except asyncio.TimeoutError:
                logs = ""
        return logs

@app.on_event("startup")
async def start_tasks() -> None:
asyncio.create_task(monitor_active_model())
@app.middleware("http")
async def log_requests(request: Request, call_next):
log.info("UI request %s %s", request.method, request.url.path)
return await call_next(request)
@app.get("/health")
async def health() -> Dict[str, Any]:
return {"status": "ok", "model_dir": cfg.model_dir}
@app.get("/")
async def index() -> HTMLResponse:
return FileResponse(_static_path() / "index.html")
@app.get("/ui/styles.css")
async def styles() -> FileResponse:
return FileResponse(_static_path() / "styles.css")
@app.get("/ui/app.js")
async def app_js() -> FileResponse:
return FileResponse(_static_path() / "app.js")
@app.get("/ui/api/models")
async def list_models() -> JSONResponse:
active_model = await _fetch_active_model(truenas_cfg)
log.info("UI list models active=%s", active_model)
return JSONResponse(_model_list(cfg.model_dir, active_model))
@app.get("/ui/api/downloads")
async def list_downloads() -> JSONResponse:
log.info("UI list downloads")
return JSONResponse({"downloads": manager.list_downloads()})
@app.post("/ui/api/downloads")
async def start_download(request: Request) -> JSONResponse:
payload = await request.json()
url = payload.get("url")
filename = payload.get("filename")
log.info("UI download start url=%s filename=%s", url, filename)
if not url:
raise HTTPException(status_code=400, detail="url is required")
try:
status = await manager.start(url, filename=filename)
except ValueError as exc:
raise HTTPException(status_code=403, detail=str(exc))
return JSONResponse({"download": status.__dict__})
@app.delete("/ui/api/downloads/{download_id}")
async def cancel_download(download_id: str) -> JSONResponse:
log.info("UI download cancel id=%s", download_id)
ok = await manager.cancel(download_id)
if not ok:
raise HTTPException(status_code=404, detail="download not found")
return JSONResponse({"status": "cancelled"})
@app.get("/ui/api/events")
async def events() -> StreamingResponse:
queue = broadcaster.connect()
async def event_stream():
try:
while True:
payload = await queue.get()
data = json.dumps(payload, separators=(",", ":"))
yield f"data: {data}\n\n".encode("utf-8")
finally:
broadcaster.disconnect(queue)
return StreamingResponse(event_stream(), media_type="text/event-stream")
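    # Model switch flow: rewrite the container command via TrueNAS, run the warmup
    # prompt with retries, then verify the server answers a tiny chat completion
    # before reporting success. Each failure stage publishes a model_switch_failed event.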
@app.post("/ui/api/switch-model")
async def switch_model_ui(request: Request) -> JSONResponse:
payload = await request.json()
model_id = payload.get("model_id")
warmup_override = payload.get("warmup_prompt") or ""
if not model_id:
raise HTTPException(status_code=400, detail="model_id is required")
model_path = Path(cfg.model_dir) / model_id
if not model_path.exists():
raise HTTPException(status_code=404, detail="model not found")
if not truenas_cfg:
raise HTTPException(status_code=500, detail="TrueNAS credentials not configured")
try:
container_model_path = str(Path(cfg.model_container_dir) / model_id)
await switch_model(truenas_cfg, container_model_path, cfg.llamacpp_args, cfg.llamacpp_extra_args)
except Exception as exc:
await broadcaster.publish({"type": "model_switch_failed", "model_id": model_id, "error": str(exc)})
raise HTTPException(status_code=500, detail=f"model switch failed: {exc}")
warmup_prompt = resolve_warmup_prompt(warmup_override, cfg.warmup_prompt_path)
log.info("UI warmup after switch model=%s prompt_len=%s", model_id, len(warmup_prompt))
try:
await run_warmup_with_retry(cfg.base_url, model_id, warmup_prompt, timeout_s=cfg.switch_timeout_s)
except Exception as exc:
await broadcaster.publish({"type": "model_switch_failed", "model_id": model_id, "error": str(exc)})
raise HTTPException(status_code=500, detail=f"model switch warmup failed: {exc}")
try:
async with httpx.AsyncClient(base_url=cfg.base_url, timeout=120) as client:
resp = await client.post(
"/v1/chat/completions",
json={
"model": model_id,
"messages": [{"role": "user", "content": "ok"}],
"max_tokens": 4,
"temperature": 0,
},
)
resp.raise_for_status()
except Exception as exc:
await broadcaster.publish({"type": "model_switch_failed", "model_id": model_id, "error": str(exc)})
raise HTTPException(status_code=500, detail=f"model switch verification failed: {exc}")
await broadcaster.publish({"type": "model_switched", "model_id": model_id})
log.info("UI model switched model=%s", model_id)
return JSONResponse({"status": "ok", "model_id": model_id})
@app.get("/ui/api/llamacpp-config")
async def get_llamacpp_config() -> JSONResponse:
active_model = await _fetch_active_model(truenas_cfg)
log.info("UI get llama.cpp config active=%s", active_model)
params: Dict[str, Optional[str]] = {}
command_raw = []
if truenas_cfg:
command_raw = await get_app_command(truenas_cfg)
flag_map = {
"--ctx-size": "ctx_size",
"--n-gpu-layers": "n_gpu_layers",
"--tensor-split": "tensor_split",
"--split-mode": "split_mode",
"--cache-type-k": "cache_type_k",
"--cache-type-v": "cache_type_v",
"--flash-attn": "flash_attn",
"--temp": "temp",
"--top-k": "top_k",
"--top-p": "top_p",
"--repeat-penalty": "repeat_penalty",
"--repeat-last-n": "repeat_last_n",
"--frequency-penalty": "frequency_penalty",
"--presence-penalty": "presence_penalty",
}
if isinstance(command_raw, list):
for flag, key in flag_map.items():
if flag in command_raw:
idx = command_raw.index(flag)
if idx + 1 < len(command_raw):
params[key] = command_raw[idx + 1]
known_flags = set(flag_map.keys()) | {"--model"}
extra = []
if isinstance(command_raw, list):
skip_next = False
for item in command_raw:
if skip_next:
skip_next = False
continue
if item in known_flags:
skip_next = True
continue
extra.append(item)
return JSONResponse(
{
"active_model": active_model,
"params": params,
"extra_args": " ".join(extra),
}
)
@app.post("/ui/api/llamacpp-config")
async def update_llamacpp_config(request: Request) -> JSONResponse:
payload = await request.json()
params = payload.get("params") or {}
extra_args = payload.get("extra_args") or ""
warmup_override = payload.get("warmup_prompt") or ""
log.info("UI save llama.cpp config params=%s extra_args=%s", params, extra_args)
if not truenas_cfg:
raise HTTPException(status_code=500, detail="TrueNAS credentials not configured")
flags = {
"--ctx-size": params.get("ctx_size"),
"--n-gpu-layers": params.get("n_gpu_layers"),
"--tensor-split": params.get("tensor_split"),
"--split-mode": params.get("split_mode"),
"--cache-type-k": params.get("cache_type_k"),
"--cache-type-v": params.get("cache_type_v"),
"--flash-attn": params.get("flash_attn"),
"--temp": params.get("temp"),
"--top-k": params.get("top_k"),
"--top-p": params.get("top_p"),
"--repeat-penalty": params.get("repeat_penalty"),
"--repeat-last-n": params.get("repeat_last_n"),
"--frequency-penalty": params.get("frequency_penalty"),
"--presence-penalty": params.get("presence_penalty"),
}
try:
await update_command_flags(truenas_cfg, flags, extra_args)
except Exception as exc:
log.exception("UI update llama.cpp config failed")
raise HTTPException(status_code=500, detail=f"config update failed: {exc}")
active_model = await _fetch_active_model(truenas_cfg)
if active_model:
warmup_prompt = resolve_warmup_prompt(warmup_override, cfg.warmup_prompt_path)
log.info("UI warmup after config update model=%s prompt_len=%s", active_model, len(warmup_prompt))
try:
await run_warmup_with_retry(cfg.base_url, active_model, warmup_prompt, timeout_s=cfg.switch_timeout_s)
except Exception as exc:
raise HTTPException(status_code=500, detail=f"config warmup failed: {exc}")
await broadcaster.publish({"type": "llamacpp_config_updated"})
return JSONResponse({"status": "ok"})
@app.get("/ui/api/llamacpp-logs")
async def get_llamacpp_logs() -> JSONResponse:
logs = await _fetch_logs()
return JSONResponse({"logs": logs})
@app.get("/ui/api/llamacpp-logs/stream")
async def stream_llamacpp_logs() -> StreamingResponse:
async def event_stream():
last_lines: list[str] = []
while True:
logs = await _fetch_logs()
lines = logs.splitlines()
if last_lines:
last_tail = last_lines[-1]
idx = -1
for i in range(len(lines) - 1, -1, -1):
if lines[i] == last_tail:
idx = i
break
if idx >= 0:
lines = lines[idx + 1 :]
if lines:
last_lines = (last_lines + lines)[-200:]
data = json.dumps({"type": "logs", "lines": lines}, separators=(",", ":"))
yield f"data: {data}\n\n".encode("utf-8")
await asyncio.sleep(2)
return StreamingResponse(event_stream(), media_type="text/event-stream")
return app
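

# One possible way to serve this factory locally (a sketch, assuming this module is
# importable as app.ui_app and that uvicorn is installed; adjust the import path to
# match the actual package layout):
#
#   uvicorn "app.ui_app:create_ui_app" --factory --host 0.0.0.0 --port 8080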