Files
codex_truenas_helper/llamaCpp.Wrapper.app/download_manager.py
Rushabh Gosar 5d1a0ee72b Initial commit
2026-01-07 16:54:39 -08:00

142 lines
5.8 KiB
Python

import asyncio
import fnmatch
import logging
import os
import time
import uuid
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Optional
from urllib.parse import urlsplit

import httpx

from app.config import AppConfig
from app.logging_utils import configure_logging
from app.restart import RestartPlan, trigger_restart
# Apply the project's logging setup (app.logging_utils) at import time,
# then bind the named logger used throughout this module.
configure_logging()
log = logging.getLogger(name="download_manager")
@dataclass
class DownloadStatus:
    """State of a single download job as tracked by the download manager."""

    download_id: str  # unique identifier of the job
    url: str  # source URL being fetched
    filename: str  # target file name inside the download directory
    status: str  # lifecycle label, e.g. "queued", "downloading", "completed", "cancelled", "error"
    bytes_total: Optional[int] = field(default=None)  # Content-Length when the server sent one
    bytes_downloaded: int = field(default=0)  # bytes written so far
    started_at: float = field(default_factory=time.time)  # epoch seconds at record creation
    finished_at: Optional[float] = field(default=None)  # epoch seconds when the job ended
    error: Optional[str] = field(default=None)  # error text when status == "error"
class DownloadManager:
    """Start, track and cancel background model downloads.

    Each download runs as its own asyncio task. At most
    ``cfg.download_max_concurrent`` transfers run at once; the others wait on a
    semaphore and keep the status "queued". State changes are pushed to the
    optional *broadcaster* via ``await broadcaster.publish(payload)``.
    """

    def __init__(self, cfg: AppConfig, broadcaster=None) -> None:
        self.cfg = cfg
        self._downloads: Dict[str, DownloadStatus] = {}
        self._tasks: Dict[str, asyncio.Task] = {}
        # Caps the number of simultaneous transfers; queued tasks wait here.
        self._semaphore = asyncio.Semaphore(cfg.download_max_concurrent)
        self._broadcaster = broadcaster

    async def _emit(self, payload: dict) -> None:
        """Publish *payload* to the broadcaster, if one was supplied."""
        if self._broadcaster:
            await self._broadcaster.publish(payload)

    def list_downloads(self) -> Dict[str, dict]:
        """Return every known download (in any state) as plain dicts keyed by id."""
        return {k: asdict(v) for k, v in self._downloads.items()}

    def get(self, download_id: str) -> Optional[DownloadStatus]:
        """Return the live status record for *download_id*, or None if unknown."""
        return self._downloads.get(download_id)

    def _is_allowed(self, url: str) -> bool:
        """Return True if *url* matches an allowlist glob, or no allowlist is set."""
        if not self.cfg.download_allowlist:
            return True
        # NOTE(review): redirects are followed during the fetch, so only the
        # initial URL is checked against the allowlist.
        return any(fnmatch.fnmatch(url, pattern) for pattern in self.cfg.download_allowlist)

    @staticmethod
    def _validate_filename(filename: str) -> str:
        """Return *filename* unchanged if it is a bare file name, else raise ValueError.

        Directory components are rejected because the name is joined onto
        ``download_dir``: ``Path(base) / "/abs"`` discards *base* entirely and
        ``"../x"`` escapes it.
        """
        if filename in ("", ".", "..") or os.path.basename(filename) != filename:
            raise ValueError("invalid filename")
        return filename

    async def start(self, url: str, filename: Optional[str] = None) -> DownloadStatus:
        """Queue a download of *url* and return its live status record.

        If *filename* is omitted (or empty) it is taken from the last path
        segment of the URL, falling back to a random ``model-<hex>.gguf`` name.

        Raises:
            ValueError: if *url* is not allowed by the allowlist, or *filename*
                contains directory components.
        """
        if not self._is_allowed(url):
            raise ValueError("url not allowed by allowlist")
        if filename:
            filename = self._validate_filename(filename)
        else:
            # urlsplit drops both the query string and any #fragment.
            derived = os.path.basename(urlsplit(url).path)
            if derived in ("", ".", ".."):
                derived = f"model-{uuid.uuid4().hex}.gguf"
            filename = derived
        log.info("Download requested url=%s filename=%s", url, filename)
        download_id = uuid.uuid4().hex
        status = DownloadStatus(download_id=download_id, url=url, filename=filename, status="queued")
        self._downloads[download_id] = status
        self._tasks[download_id] = asyncio.create_task(self._run_download(status))
        await self._emit({"type": "download_status", "download": asdict(status)})
        return status

    async def cancel(self, download_id: str) -> bool:
        """Cancel *download_id*, wait for its cleanup, then emit its final status.

        Returns True if the id is known, even when the download had already
        finished; in that case its state is left unchanged. Returns False for
        unknown ids.
        """
        task = self._tasks.get(download_id)
        if task is None:
            return False
        status = self._downloads.get(download_id)
        if task.cancel():
            if status is not None:
                log.info("Download cancel requested id=%s filename=%s", download_id, status.filename)
            # Let the task run its cleanup (partial-file removal, cancelled
            # event) so the status emitted below is final rather than stale.
            await asyncio.wait({task})
            if status is not None and status.status == "queued":
                # The task was cancelled before its coroutine started, so the
                # cleanup handler in _run_download never executed.
                status.status = "cancelled"
                status.finished_at = time.time()
        if status is not None:
            await self._emit({"type": "download_status", "download": asdict(status)})
        return True

    async def _run_download(self, status: DownloadStatus) -> None:
        """Task body: fetch one download, then record and publish its outcome.

        Data is streamed into a hidden ``.<filename>.<id>.partial`` file, which
        is renamed to ``<filename>`` only after the transfer succeeds.
        Cancellation is re-raised after cleanup so the task ends up cancelled.
        """
        base = Path(self.cfg.download_dir)
        # The id keeps concurrent downloads of the same filename from sharing a temp file.
        tmp_path = base / f".{status.filename}.{status.download_id}.partial"
        final_path = base / status.filename
        try:
            async with self._semaphore:
                # Only now is a transfer slot held; until here the job stays "queued".
                status.status = "downloading"
                # Inside the try so a failure is reported instead of lost in the task.
                base.mkdir(parents=True, exist_ok=True)
                await self._stream_to_file(status, tmp_path)
            tmp_path.replace(final_path)
        except asyncio.CancelledError:
            status.status = "cancelled"
            status.finished_at = time.time()
            self._discard_partial(tmp_path)
            log.info("Download cancelled id=%s filename=%s", status.download_id, status.filename)
            await self._emit({"type": "download_cancelled", "download": asdict(status)})
            raise
        except Exception as exc:
            status.status = "error"
            status.error = str(exc)
            status.finished_at = time.time()
            self._discard_partial(tmp_path)
            log.error(
                "Download error id=%s filename=%s error=%s",
                status.download_id,
                status.filename,
                exc,
            )
            await self._emit({"type": "download_error", "download": asdict(status)})
            return

        status.status = "completed"
        status.finished_at = time.time()
        log.info("Download completed id=%s filename=%s", status.download_id, status.filename)
        await self._emit({"type": "download_completed", "download": asdict(status)})
        if self.cfg.reload_on_new_model:
            await self._restart_for_new_model(status)

    async def _stream_to_file(self, status: DownloadStatus, dest: Path) -> None:
        """Stream ``status.url`` into *dest*, updating the byte counters on *status*.

        Progress events are throttled to at most one per second.

        Raises:
            httpx.HTTPStatusError: via ``resp.raise_for_status()`` on an
                unsuccessful response.
        """
        last_emit = float("-inf")
        # NOTE(review): timeout=None means a stalled connection is never timed
        # out and keeps its semaphore slot; consider a read timeout.
        async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
            async with client.stream("GET", status.url) as resp:
                resp.raise_for_status()
                length = resp.headers.get("content-length")
                if length:
                    status.bytes_total = int(length)
                # NOTE(review): aiter_bytes yields decoded content, so with a
                # Content-Encoding the counters may not match Content-Length.
                with dest.open("wb") as fh:
                    async for chunk in resp.aiter_bytes():
                        if not chunk:
                            continue
                        fh.write(chunk)
                        status.bytes_downloaded += len(chunk)
                        # Monotonic clock: immune to wall-clock adjustments.
                        now = time.monotonic()
                        if now - last_emit >= 1:
                            last_emit = now
                            await self._emit({"type": "download_progress", "download": asdict(status)})

    @staticmethod
    def _discard_partial(path: Path) -> None:
        """Best-effort removal of a partial file; failures are logged, not raised."""
        try:
            path.unlink(missing_ok=True)
        except OSError:
            log.warning("Could not remove partial file %s", path, exc_info=True)

    async def _restart_for_new_model(self, status: DownloadStatus) -> None:
        """Trigger the configured restart after a successful download.

        A failure here is logged but does not change the download's status,
        because the file has already been moved into place.
        """
        try:
            plan = RestartPlan(
                method=self.cfg.restart_method,
                command=self.cfg.restart_command,
                url=self.cfg.restart_url,
                allowed_container=self.cfg.allowed_container,
            )
            await trigger_restart(
                plan,
                payload={
                    "reason": "new_model",
                    "model_id": status.filename,
                    "llamacpp_args": self.cfg.llamacpp_args,
                    "llamacpp_extra_args": self.cfg.llamacpp_extra_args,
                },
            )
        except Exception:
            log.exception(
                "Restart after download failed id=%s filename=%s",
                status.download_id,
                status.filename,
            )