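"""FastAPI backend for the llama.cpp model-manager UI.

Serves the static UI, lists and downloads models, switches the active model
through the TrueNAS middleware, edits llama.cpp launch flags, and streams
download/model events and container logs to the browser over SSE.
"""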
import asyncio
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse

from app.config import load_config
from app.docker_logs import docker_container_logs
from app.download_manager import DownloadManager
from app.logging_utils import configure_logging
from app.model_registry import scan_models
from app.truenas_middleware import (
    TrueNASConfig,
    get_active_model_id,
    get_app_command,
    get_app_logs,
    switch_model,
    update_command_flags,
)
from app.warmup import resolve_warmup_prompt, run_warmup_with_retry

configure_logging()
log = logging.getLogger("ui_app")


class EventBroadcaster:
    """Fan-out of JSON event payloads to every connected SSE client."""

    def __init__(self) -> None:
        self._queues: set[asyncio.Queue] = set()

    def connect(self) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue()
        self._queues.add(queue)
        return queue

    def disconnect(self, queue: asyncio.Queue) -> None:
        self._queues.discard(queue)

    async def publish(self, payload: dict) -> None:
        # Copy the set so a disconnect during iteration cannot invalidate it.
        for queue in list(self._queues):
            queue.put_nowait(payload)


def _static_path() -> Path:
    return Path(__file__).parent / "ui_static"


async def _fetch_active_model(truenas_cfg: Optional[TrueNASConfig]) -> Optional[str]:
    if not truenas_cfg:
        return None
    try:
        return await get_active_model_id(truenas_cfg)
    except Exception as exc:
        log.warning("Failed to read active model from TrueNAS config: %s", exc)
        return None


def _model_list(model_dir: str, active_model: Optional[str]) -> Dict[str, Any]:
    data = []
    for model in scan_models(model_dir):
        data.append({
            "id": model.model_id,
            "size": model.size,
            "active": model.model_id == active_model,
        })
    return {"models": data, "active_model": active_model}


def create_ui_app() -> FastAPI:
    cfg = load_config()
    app = FastAPI(title="llama.cpp Model Manager", version="0.1.0")
    broadcaster = EventBroadcaster()
    manager = DownloadManager(cfg, broadcaster=broadcaster)
    truenas_cfg = None
    if cfg.truenas_ws_url and cfg.truenas_api_key:
        truenas_cfg = TrueNASConfig(
            ws_url=cfg.truenas_ws_url,
            api_key=cfg.truenas_api_key,
            api_user=cfg.truenas_api_user,
            app_name=cfg.truenas_app_name,
            verify_ssl=cfg.truenas_verify_ssl,
        )

    async def monitor_active_model() -> None:
        # Poll the TrueNAS config and broadcast whenever the active model changes.
        last_model = None
        while True:
            current = await _fetch_active_model(truenas_cfg)
            if current and current != last_model:
                last_model = current
                await broadcaster.publish({"type": "active_model", "model_id": current})
            await asyncio.sleep(3)

    async def _fetch_logs() -> str:
        # Prefer TrueNAS app logs; fall back to the docker container logs.
        logs = ""
        if truenas_cfg:
            try:
                logs = await asyncio.wait_for(get_app_logs(truenas_cfg, tail_lines=200), timeout=5)
            except asyncio.TimeoutError:
                logs = ""
        if not logs and cfg.llamacpp_container_name:
            try:
                logs = await asyncio.wait_for(
                    docker_container_logs(cfg.llamacpp_container_name, tail_lines=200),
                    timeout=10,
                )
            except asyncio.TimeoutError:
                logs = ""
        return logs

    @app.on_event("startup")
    async def start_tasks() -> None:
        asyncio.create_task(monitor_active_model())

    @app.middleware("http")
    async def log_requests(request: Request, call_next):
        log.info("UI request %s %s", request.method, request.url.path)
        return await call_next(request)

    @app.get("/health")
    async def health() -> Dict[str, Any]:
        return {"status": "ok", "model_dir": cfg.model_dir}

@app.get("/")
|
|
async def index() -> HTMLResponse:
|
|
return FileResponse(_static_path() / "index.html")
|
|
|
|
@app.get("/ui/styles.css")
|
|
async def styles() -> FileResponse:
|
|
return FileResponse(_static_path() / "styles.css")
|
|
|
|
@app.get("/ui/app.js")
|
|
async def app_js() -> FileResponse:
|
|
return FileResponse(_static_path() / "app.js")
|
|
|
|
@app.get("/ui/api/models")
|
|
async def list_models() -> JSONResponse:
|
|
active_model = await _fetch_active_model(truenas_cfg)
|
|
log.info("UI list models active=%s", active_model)
|
|
return JSONResponse(_model_list(cfg.model_dir, active_model))
|
|
|
|
@app.get("/ui/api/downloads")
|
|
async def list_downloads() -> JSONResponse:
|
|
log.info("UI list downloads")
|
|
return JSONResponse({"downloads": manager.list_downloads()})
|
|
|
|
@app.post("/ui/api/downloads")
|
|
async def start_download(request: Request) -> JSONResponse:
|
|
payload = await request.json()
|
|
url = payload.get("url")
|
|
filename = payload.get("filename")
|
|
log.info("UI download start url=%s filename=%s", url, filename)
|
|
if not url:
|
|
raise HTTPException(status_code=400, detail="url is required")
|
|
try:
|
|
status = await manager.start(url, filename=filename)
|
|
except ValueError as exc:
|
|
raise HTTPException(status_code=403, detail=str(exc))
|
|
return JSONResponse({"download": status.__dict__})
|
|
|
|
@app.delete("/ui/api/downloads/{download_id}")
|
|
async def cancel_download(download_id: str) -> JSONResponse:
|
|
log.info("UI download cancel id=%s", download_id)
|
|
ok = await manager.cancel(download_id)
|
|
if not ok:
|
|
raise HTTPException(status_code=404, detail="download not found")
|
|
return JSONResponse({"status": "cancelled"})
|
|
|
|
@app.get("/ui/api/events")
|
|
async def events() -> StreamingResponse:
|
|
queue = broadcaster.connect()
|
|
|
|
async def event_stream():
|
|
try:
|
|
while True:
|
|
payload = await queue.get()
|
|
data = json.dumps(payload, separators=(",", ":"))
|
|
yield f"data: {data}\n\n".encode("utf-8")
|
|
finally:
|
|
broadcaster.disconnect(queue)
|
|
|
|
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|
@app.post("/ui/api/switch-model")
|
|
async def switch_model_ui(request: Request) -> JSONResponse:
|
|
payload = await request.json()
|
|
model_id = payload.get("model_id")
|
|
warmup_override = payload.get("warmup_prompt") or ""
|
|
if not model_id:
|
|
raise HTTPException(status_code=400, detail="model_id is required")
|
|
|
|
model_path = Path(cfg.model_dir) / model_id
|
|
if not model_path.exists():
|
|
raise HTTPException(status_code=404, detail="model not found")
|
|
|
|
if not truenas_cfg:
|
|
raise HTTPException(status_code=500, detail="TrueNAS credentials not configured")
|
|
|
|
try:
|
|
container_model_path = str(Path(cfg.model_container_dir) / model_id)
|
|
await switch_model(truenas_cfg, container_model_path, cfg.llamacpp_args, cfg.llamacpp_extra_args)
|
|
except Exception as exc:
|
|
await broadcaster.publish({"type": "model_switch_failed", "model_id": model_id, "error": str(exc)})
|
|
raise HTTPException(status_code=500, detail=f"model switch failed: {exc}")
|
|
|
|
warmup_prompt = resolve_warmup_prompt(warmup_override, cfg.warmup_prompt_path)
|
|
log.info("UI warmup after switch model=%s prompt_len=%s", model_id, len(warmup_prompt))
|
|
try:
|
|
await run_warmup_with_retry(cfg.base_url, model_id, warmup_prompt, timeout_s=cfg.switch_timeout_s)
|
|
except Exception as exc:
|
|
await broadcaster.publish({"type": "model_switch_failed", "model_id": model_id, "error": str(exc)})
|
|
raise HTTPException(status_code=500, detail=f"model switch warmup failed: {exc}")
|
|
|
|
try:
|
|
async with httpx.AsyncClient(base_url=cfg.base_url, timeout=120) as client:
|
|
resp = await client.post(
|
|
"/v1/chat/completions",
|
|
json={
|
|
"model": model_id,
|
|
"messages": [{"role": "user", "content": "ok"}],
|
|
"max_tokens": 4,
|
|
"temperature": 0,
|
|
},
|
|
)
|
|
resp.raise_for_status()
|
|
except Exception as exc:
|
|
await broadcaster.publish({"type": "model_switch_failed", "model_id": model_id, "error": str(exc)})
|
|
raise HTTPException(status_code=500, detail=f"model switch verification failed: {exc}")
|
|
|
|
await broadcaster.publish({"type": "model_switched", "model_id": model_id})
|
|
log.info("UI model switched model=%s", model_id)
|
|
return JSONResponse({"status": "ok", "model_id": model_id})
|
|
|
|
@app.get("/ui/api/llamacpp-config")
|
|
async def get_llamacpp_config() -> JSONResponse:
|
|
active_model = await _fetch_active_model(truenas_cfg)
|
|
log.info("UI get llama.cpp config active=%s", active_model)
|
|
params: Dict[str, Optional[str]] = {}
|
|
command_raw = []
|
|
if truenas_cfg:
|
|
command_raw = await get_app_command(truenas_cfg)
|
|
flag_map = {
|
|
"--ctx-size": "ctx_size",
|
|
"--n-gpu-layers": "n_gpu_layers",
|
|
"--tensor-split": "tensor_split",
|
|
"--split-mode": "split_mode",
|
|
"--cache-type-k": "cache_type_k",
|
|
"--cache-type-v": "cache_type_v",
|
|
"--flash-attn": "flash_attn",
|
|
"--temp": "temp",
|
|
"--top-k": "top_k",
|
|
"--top-p": "top_p",
|
|
"--repeat-penalty": "repeat_penalty",
|
|
"--repeat-last-n": "repeat_last_n",
|
|
"--frequency-penalty": "frequency_penalty",
|
|
"--presence-penalty": "presence_penalty",
|
|
}
|
|
if isinstance(command_raw, list):
|
|
for flag, key in flag_map.items():
|
|
if flag in command_raw:
|
|
idx = command_raw.index(flag)
|
|
if idx + 1 < len(command_raw):
|
|
params[key] = command_raw[idx + 1]
|
|
known_flags = set(flag_map.keys()) | {"--model"}
|
|
extra = []
|
|
if isinstance(command_raw, list):
|
|
skip_next = False
|
|
for item in command_raw:
|
|
if skip_next:
|
|
skip_next = False
|
|
continue
|
|
if item in known_flags:
|
|
skip_next = True
|
|
continue
|
|
extra.append(item)
|
|
return JSONResponse(
|
|
{
|
|
"active_model": active_model,
|
|
"params": params,
|
|
"extra_args": " ".join(extra),
|
|
}
|
|
)
|
|
|
|
@app.post("/ui/api/llamacpp-config")
|
|
async def update_llamacpp_config(request: Request) -> JSONResponse:
|
|
payload = await request.json()
|
|
params = payload.get("params") or {}
|
|
extra_args = payload.get("extra_args") or ""
|
|
warmup_override = payload.get("warmup_prompt") or ""
|
|
log.info("UI save llama.cpp config params=%s extra_args=%s", params, extra_args)
|
|
if not truenas_cfg:
|
|
raise HTTPException(status_code=500, detail="TrueNAS credentials not configured")
|
|
flags = {
|
|
"--ctx-size": params.get("ctx_size"),
|
|
"--n-gpu-layers": params.get("n_gpu_layers"),
|
|
"--tensor-split": params.get("tensor_split"),
|
|
"--split-mode": params.get("split_mode"),
|
|
"--cache-type-k": params.get("cache_type_k"),
|
|
"--cache-type-v": params.get("cache_type_v"),
|
|
"--flash-attn": params.get("flash_attn"),
|
|
"--temp": params.get("temp"),
|
|
"--top-k": params.get("top_k"),
|
|
"--top-p": params.get("top_p"),
|
|
"--repeat-penalty": params.get("repeat_penalty"),
|
|
"--repeat-last-n": params.get("repeat_last_n"),
|
|
"--frequency-penalty": params.get("frequency_penalty"),
|
|
"--presence-penalty": params.get("presence_penalty"),
|
|
}
|
|
try:
|
|
await update_command_flags(truenas_cfg, flags, extra_args)
|
|
except Exception as exc:
|
|
log.exception("UI update llama.cpp config failed")
|
|
raise HTTPException(status_code=500, detail=f"config update failed: {exc}")
|
|
active_model = await _fetch_active_model(truenas_cfg)
|
|
if active_model:
|
|
warmup_prompt = resolve_warmup_prompt(warmup_override, cfg.warmup_prompt_path)
|
|
log.info("UI warmup after config update model=%s prompt_len=%s", active_model, len(warmup_prompt))
|
|
try:
|
|
await run_warmup_with_retry(cfg.base_url, active_model, warmup_prompt, timeout_s=cfg.switch_timeout_s)
|
|
except Exception as exc:
|
|
raise HTTPException(status_code=500, detail=f"config warmup failed: {exc}")
|
|
await broadcaster.publish({"type": "llamacpp_config_updated"})
|
|
return JSONResponse({"status": "ok"})
|
|
|
|
@app.get("/ui/api/llamacpp-logs")
|
|
async def get_llamacpp_logs() -> JSONResponse:
|
|
logs = await _fetch_logs()
|
|
return JSONResponse({"logs": logs})
|
|
|
|
@app.get("/ui/api/llamacpp-logs/stream")
|
|
async def stream_llamacpp_logs() -> StreamingResponse:
|
|
async def event_stream():
|
|
last_lines: list[str] = []
|
|
while True:
|
|
logs = await _fetch_logs()
|
|
lines = logs.splitlines()
|
|
if last_lines:
|
|
last_tail = last_lines[-1]
|
|
idx = -1
|
|
for i in range(len(lines) - 1, -1, -1):
|
|
if lines[i] == last_tail:
|
|
idx = i
|
|
break
|
|
if idx >= 0:
|
|
lines = lines[idx + 1 :]
|
|
if lines:
|
|
last_lines = (last_lines + lines)[-200:]
|
|
data = json.dumps({"type": "logs", "lines": lines}, separators=(",", ":"))
|
|
yield f"data: {data}\n\n".encode("utf-8")
|
|
await asyncio.sleep(2)
|
|
|
|
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|
return app
|
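

# A minimal way to serve this app locally for development. This block is an
# illustrative sketch, not part of the original module: it assumes uvicorn is
# installed and that binding to 0.0.0.0:8080 is acceptable. Adjust host/port,
# or run `uvicorn <module_path>:create_ui_app --factory` instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(create_ui_app(), host="0.0.0.0", port=8080)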