import asyncio
import fnmatch
import logging
import os
import time
import uuid
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Optional
from urllib.parse import urlsplit

import httpx

from app.config import AppConfig
from app.logging_utils import configure_logging
from app.restart import RestartPlan, trigger_restart

configure_logging()
log = logging.getLogger("download_manager")

# Minimum number of seconds between "download_progress" events for one download.
_PROGRESS_INTERVAL = 1.0


def _is_plain_filename(name: str) -> bool:
    """Return True if *name* is a single, non-special path component.

    Rejects empty names, ``"."``, ``".."`` and anything containing a
    directory separator (``"../x"``, ``"/etc/x"``, ``"a/b"``). Such names
    would otherwise resolve outside ``download_dir`` when joined onto it.
    """
    return bool(name) and name not in (".", "..") and Path(name).name == name


def _filename_from_url(url: str) -> str:
    """Derive a local file name from the last path segment of *url*.

    The query string and fragment are ignored. When the URL path has no
    usable last segment (empty, ``"."`` or ``".."``), a random
    ``model-<hex>.gguf`` name is returned instead.
    """
    name = os.path.basename(urlsplit(url).path)
    if _is_plain_filename(name):
        return name
    return f"model-{uuid.uuid4().hex}.gguf"


@dataclass
class DownloadStatus:
    """Mutable progress/outcome record for one download; serialised with ``asdict``."""

    download_id: str
    url: str
    filename: str
    # One of "queued", "downloading", "completed", "cancelled", "error".
    status: str
    # Taken from the Content-Length header; None when absent or malformed.
    bytes_total: Optional[int] = None
    bytes_downloaded: int = 0
    # Wall-clock timestamps from time.time().
    started_at: float = field(default_factory=time.time)
    finished_at: Optional[float] = None
    # Text of the failure when status == "error".
    error: Optional[str] = None


class DownloadManager:
    """Run HTTP downloads into ``cfg.download_dir`` as background asyncio tasks.

    At most ``cfg.download_max_concurrent`` transfers run at once; the others
    stay in the "queued" state until a slot frees up. State changes are
    published through the optional *broadcaster*, which must provide an
    awaitable ``publish(payload: dict)``.
    """

    def __init__(self, cfg: AppConfig, broadcaster=None) -> None:
        self.cfg = cfg
        self._downloads: Dict[str, DownloadStatus] = {}
        # Holds only unfinished tasks; a done-callback removes each entry.
        self._tasks: Dict[str, asyncio.Task] = {}
        self._semaphore = asyncio.Semaphore(cfg.download_max_concurrent)
        self._broadcaster = broadcaster

    async def _emit(self, payload: dict) -> None:
        """Publish *payload* to the broadcaster, if one was configured."""
        if self._broadcaster:
            await self._broadcaster.publish(payload)

    def list_downloads(self) -> Dict[str, dict]:
        """Return every known download (in any state) as plain dicts keyed by id."""
        return {k: asdict(v) for k, v in self._downloads.items()}

    def get(self, download_id: str) -> Optional[DownloadStatus]:
        """Return the live status record for *download_id*, or None if unknown."""
        return self._downloads.get(download_id)

    def _is_allowed(self, url: str) -> bool:
        """Return True if *url* matches a glob pattern in ``cfg.download_allowlist``.

        An empty or unset allowlist permits every URL. Only the initial URL
        is checked here, and redirects are followed during the transfer.
        """
        if not self.cfg.download_allowlist:
            return True
        return any(fnmatch.fnmatch(url, pattern) for pattern in self.cfg.download_allowlist)

    async def start(self, url: str, filename: Optional[str] = None) -> DownloadStatus:
        """Queue a download of *url* and return its status record immediately.

        Args:
            url: Source URL; it must pass the allowlist.
            filename: Target file name inside ``download_dir``. When omitted,
                it is derived from the URL's last path segment.

        Returns:
            The live ``DownloadStatus`` (initially "queued"), updated in place
            by the background task.

        Raises:
            ValueError: if *url* is not allowlisted or cannot be parsed, or if
                *filename* is not a bare file name (e.g. contains ``/``, or is
                ``"."``/``".."``).
        """
        if not self._is_allowed(url):
            raise ValueError("url not allowed by allowlist")
        if filename:
            # Untrusted input: a name with directory parts could escape download_dir.
            if not _is_plain_filename(filename):
                raise ValueError("invalid filename")
        else:
            filename = _filename_from_url(url)
        log.info("Download requested url=%s filename=%s", url, filename)
        download_id = uuid.uuid4().hex
        status = DownloadStatus(
            download_id=download_id,
            url=url,
            filename=filename,
            status="queued",
        )
        self._downloads[download_id] = status
        task = asyncio.create_task(self._run_download(status))
        self._tasks[download_id] = task
        # Drop finished tasks so _tasks does not grow for the process lifetime.
        task.add_done_callback(lambda _task: self._tasks.pop(download_id, None))
        await self._emit({"type": "download_status", "download": asdict(status)})
        return status

    async def cancel(self, download_id: str) -> bool:
        """Request cancellation of an in-flight download.

        Returns:
            True if a running or queued download was asked to cancel. False if
            the id is unknown or the download has already finished.
        """
        task = self._tasks.get(download_id)
        status = self._downloads.get(download_id)
        if task is None or status is None or task.done():
            return False
        task.cancel()
        if status.status == "queued":
            # A task cancelled before its first step never enters _run_download,
            # so its CancelledError handler cannot record the outcome. Record it here.
            status.status = "cancelled"
            status.finished_at = time.time()
        log.info("Download cancel requested id=%s filename=%s", download_id, status.filename)
        await self._emit({"type": "download_status", "download": asdict(status)})
        return True

    async def _run_download(self, status: DownloadStatus) -> None:
        """Task body: transfer the file, publish the outcome, optionally restart.

        Data is written to a hidden per-download ``.partial`` file. It is
        renamed to the final name only after the transfer completes, so a
        half-written file never appears under the real name.
        ``CancelledError`` is re-raised after cleanup so the task ends in the
        cancelled state.
        """
        base = Path(self.cfg.download_dir)
        # Including download_id keeps two concurrent downloads of the same
        # filename from writing into the same partial file.
        tmp_path = base / f".{status.filename}.{status.download_id}.partial"
        final_path = base / status.filename
        try:
            async with self._semaphore:
                # A transfer slot is held from here on; until now the status stays "queued".
                status.status = "downloading"
                base.mkdir(parents=True, exist_ok=True)
                await self._stream_to_file(status, tmp_path)
            tmp_path.replace(final_path)
        except asyncio.CancelledError:
            status.status = "cancelled"
            status.finished_at = time.time()
            self._discard_partial(tmp_path)
            log.info("Download cancelled id=%s filename=%s", status.download_id, status.filename)
            await self._emit({"type": "download_cancelled", "download": asdict(status)})
            raise
        except Exception as exc:
            status.status = "error"
            # Some exceptions stringify to "" (e.g. bare timeouts); keep something readable.
            status.error = str(exc) or type(exc).__name__
            status.finished_at = time.time()
            self._discard_partial(tmp_path)
            log.error(
                "Download error id=%s filename=%s error=%s",
                status.download_id,
                status.filename,
                exc,
            )
            await self._emit({"type": "download_error", "download": asdict(status)})
        else:
            status.status = "completed"
            status.finished_at = time.time()
            log.info("Download completed id=%s filename=%s", status.download_id, status.filename)
            await self._emit({"type": "download_completed", "download": asdict(status)})
            await self._restart_for_new_model(status)

    async def _stream_to_file(self, status: DownloadStatus, tmp_path: Path) -> None:
        """Stream ``status.url`` into *tmp_path*, updating the byte counters on *status*.

        Emits a "download_progress" event at most once per ``_PROGRESS_INTERVAL``
        seconds.

        Raises:
            httpx.HTTPStatusError: on a non-success HTTP response.
            httpx.HTTPError / OSError: on transport or file-system failures.
        """
        # Monotonic clock for throttling; -inf guarantees the first chunk is reported.
        last_emit = float("-inf")
        # NOTE(review): timeout=None means a stalled server keeps holding a
        # concurrency slot until the download is cancelled.
        async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
            async with client.stream("GET", status.url) as resp:
                resp.raise_for_status()
                length = resp.headers.get("content-length")
                if length:
                    try:
                        status.bytes_total = int(length)
                    except ValueError:
                        # Malformed header: keep downloading without a known total.
                        status.bytes_total = None
                # NOTE(review): these are synchronous writes on the event-loop thread.
                with tmp_path.open("wb") as f:
                    async for chunk in resp.aiter_bytes():
                        if not chunk:
                            continue
                        f.write(chunk)
                        status.bytes_downloaded += len(chunk)
                        now = time.monotonic()
                        if now - last_emit >= _PROGRESS_INTERVAL:
                            last_emit = now
                            await self._emit({"type": "download_progress", "download": asdict(status)})

    @staticmethod
    def _discard_partial(tmp_path: Path) -> None:
        """Best-effort removal of a leftover partial file; failures are logged, not raised."""
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            log.warning("Could not remove partial file %s", tmp_path, exc_info=True)

    async def _restart_for_new_model(self, status: DownloadStatus) -> None:
        """Trigger the configured restart after a successful download, if enabled.

        A restart failure is logged and leaves the download's "completed"
        status untouched, because the file itself was saved successfully.
        """
        if not self.cfg.reload_on_new_model:
            return
        plan = RestartPlan(
            method=self.cfg.restart_method,
            command=self.cfg.restart_command,
            url=self.cfg.restart_url,
            allowed_container=self.cfg.allowed_container,
        )
        try:
            await trigger_restart(
                plan,
                payload={
                    "reason": "new_model",
                    "model_id": status.filename,
                    "llamacpp_args": self.cfg.llamacpp_args,
                    "llamacpp_extra_args": self.cfg.llamacpp_extra_args,
                },
            )
        except Exception:
            log.exception(
                "Restart after download failed id=%s filename=%s",
                status.download_id,
                status.filename,
            )