142 lines
5.8 KiB
Python
142 lines
5.8 KiB
Python
import asyncio
|
|
import fnmatch
|
|
import logging
|
|
import os
|
|
import time
|
|
import uuid
|
|
from dataclasses import asdict, dataclass, field
|
|
from pathlib import Path
|
|
from typing import Dict, Optional
|
|
|
|
import httpx
|
|
|
|
from app.config import AppConfig
|
|
from app.logging_utils import configure_logging
|
|
from app.restart import RestartPlan, trigger_restart
|
|
|
|
# Logging setup runs as an import-time side effect (see app.logging_utils for
# what it configures); the module-wide logger is created right after it.
configure_logging()
log = logging.getLogger(name="download_manager")
|
|
|
|
|
|
@dataclass
class DownloadStatus:
    """Mutable record describing a single download, serialisable via ``asdict``.

    ``status`` holds one of the string states used by the download manager in
    this module: ``"queued"``, ``"downloading"``, ``"completed"``,
    ``"cancelled"`` or ``"error"``.
    """

    download_id: str  # hex id generated by the manager (uuid4)
    url: str  # source URL of the download
    filename: str  # target file name inside the download directory
    status: str  # current state, see class docstring
    bytes_total: Optional[int] = field(default=None)  # Content-Length, None if unknown
    bytes_downloaded: int = field(default=0)  # bytes written so far
    started_at: float = field(default_factory=time.time)  # epoch seconds at record creation
    finished_at: Optional[float] = field(default=None)  # epoch seconds when finished, if recorded
    error: Optional[str] = field(default=None)  # str() of the failure when status is "error"
|
|
|
|
|
|
class DownloadManager:
    """Run and track HTTP downloads of model files into ``cfg.download_dir``.

    Every download runs in its own asyncio task. At most
    ``cfg.download_max_concurrent`` transfers are active at once; the rest stay
    ``"queued"`` until a slot frees up. State changes are published through the
    optional ``broadcaster`` (any object with an awaitable ``publish(dict)``).
    """

    # Abort connections that cannot be established, or that stop delivering
    # data, instead of holding a concurrency slot forever. There is deliberately
    # no overall deadline: large model files can legitimately take hours.
    _CONNECT_TIMEOUT_S = 30.0
    _READ_TIMEOUT_S = 300.0

    def __init__(self, cfg: AppConfig, broadcaster=None) -> None:
        self.cfg = cfg
        self._downloads: Dict[str, DownloadStatus] = {}
        self._tasks: Dict[str, asyncio.Task] = {}
        self._semaphore = asyncio.Semaphore(cfg.download_max_concurrent)
        self._broadcaster = broadcaster

    async def _emit(self, payload: dict) -> None:
        """Publish *payload* to the broadcaster, if one is configured.

        Publishing is best-effort: a failing broadcaster is logged but must not
        abort (and delete) the download that produced the event.
        """
        if not self._broadcaster:
            return
        try:
            await self._broadcaster.publish(payload)
        except Exception:
            log.warning("Failed to publish %s event", payload.get("type"), exc_info=True)

    def list_downloads(self) -> Dict[str, dict]:
        """Return a snapshot of all known downloads, keyed by download id."""
        return {k: asdict(v) for k, v in self._downloads.items()}

    def get(self, download_id: str) -> Optional[DownloadStatus]:
        """Return the live status record for *download_id*, or None if unknown."""
        return self._downloads.get(download_id)

    def _is_allowed(self, url: str) -> bool:
        """Return True if *url* matches the allowlist (an empty allowlist allows all)."""
        if not self.cfg.download_allowlist:
            return True
        return any(fnmatch.fnmatch(url, pattern) for pattern in self.cfg.download_allowlist)

    @staticmethod
    def _is_safe_filename(name: str) -> bool:
        """Return True if *name* is a plain file name that stays inside the download dir."""
        if name in ("", ".", ".."):
            return False
        if "/" in name or "\\" in name or "\x00" in name:
            return False
        # Catches anything else the platform parses as a path (e.g. "C:x" on Windows).
        return Path(name).name == name

    async def start(self, url: str, filename: Optional[str] = None) -> DownloadStatus:
        """Queue a download of *url* and return its live status record.

        Args:
            url: Source URL; must match ``cfg.download_allowlist`` when that is set.
            filename: Target file name inside ``cfg.download_dir``. When omitted,
                the last path segment of *url* is used, falling back to a random
                ``model-<hex>.gguf`` name when that segment is unusable.

        Raises:
            ValueError: if *url* is not allowed, or *filename* is not a plain
                file name (contains a path separator, is ``.``/``..``, ...).
        """
        if not self._is_allowed(url):
            raise ValueError("url not allowed by allowlist")
        if filename:
            # Caller-supplied names are untrusted: refuse anything that could
            # write outside the download directory (e.g. "../x" or "/etc/x").
            if not self._is_safe_filename(filename):
                raise ValueError("invalid filename")
        else:
            # Strip the fragment and query string before taking the last segment.
            derived = os.path.basename(url.split("#", 1)[0].split("?", 1)[0])
            filename = derived if self._is_safe_filename(derived) else f"model-{uuid.uuid4().hex}.gguf"
        log.info("Download requested url=%s filename=%s", url, filename)
        download_id = uuid.uuid4().hex
        status = DownloadStatus(download_id=download_id, url=url, filename=filename, status="queued")
        self._downloads[download_id] = status
        task = asyncio.create_task(self._run_download(status))
        self._tasks[download_id] = task
        await self._emit({"type": "download_status", "download": asdict(status)})
        return status

    async def cancel(self, download_id: str) -> bool:
        """Request cancellation of a download task.

        The task itself records the final ``"cancelled"`` state and emits
        ``download_cancelled`` once it processes the cancellation.

        Returns:
            True if a task exists for *download_id*, False otherwise.
        """
        task = self._tasks.get(download_id)
        if task is None:
            return False
        task.cancel()
        status = self._downloads.get(download_id)
        if status:
            log.info("Download cancel requested id=%s filename=%s", download_id, status.filename)
            await self._emit({"type": "download_status", "download": asdict(status)})
        return True

    async def _run_download(self, status: DownloadStatus) -> None:
        """Task body: fetch ``status.url`` into the download dir and update *status*.

        The body is written to a per-download ``.partial`` file and moved into
        place only once complete; the partial file is removed on failure or
        cancellation. Cancellation is re-raised after cleanup, as asyncio expects.
        """
        base = Path(self.cfg.download_dir)
        # Include the id so concurrent downloads of the same name never share a temp file.
        tmp_path = base / f".{status.filename}.{status.download_id}.partial"
        final_path = base / status.filename

        try:
            async with self._semaphore:
                # Only now is a transfer slot held; until then the status stays "queued".
                status.status = "downloading"
                base.mkdir(parents=True, exist_ok=True)
                await self._stream_to_file(status, tmp_path)
                tmp_path.replace(final_path)
        except asyncio.CancelledError:
            status.status = "cancelled"
            status.finished_at = time.time()
            tmp_path.unlink(missing_ok=True)
            log.info("Download cancelled id=%s filename=%s", status.download_id, status.filename)
            await self._emit({"type": "download_cancelled", "download": asdict(status)})
            raise
        except Exception as exc:
            status.status = "error"
            status.error = str(exc)
            status.finished_at = time.time()
            tmp_path.unlink(missing_ok=True)
            log.error("Download error id=%s filename=%s error=%s", status.download_id, status.filename, exc)
            await self._emit({"type": "download_error", "download": asdict(status)})
            return

        status.status = "completed"
        status.finished_at = time.time()
        log.info("Download completed id=%s filename=%s", status.download_id, status.filename)
        await self._emit({"type": "download_completed", "download": asdict(status)})
        # Outside the download try-block: a restart problem must not relabel a
        # finished download as failed or cancelled.
        if self.cfg.reload_on_new_model:
            await self._restart_for_new_model(status)

    async def _stream_to_file(self, status: DownloadStatus, dest: Path) -> None:
        """Stream the HTTP body of ``status.url`` into *dest*, emitting throttled progress."""
        timeout = httpx.Timeout(None, connect=self._CONNECT_TIMEOUT_S, read=self._READ_TIMEOUT_S)
        last_emit = float("-inf")  # guarantees the first chunk is reported
        async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
            async with client.stream("GET", status.url) as resp:
                resp.raise_for_status()
                length = resp.headers.get("content-length")
                if length:
                    status.bytes_total = int(length)
                with dest.open("wb") as f:
                    async for chunk in resp.aiter_bytes():
                        if not chunk:
                            continue
                        f.write(chunk)
                        status.bytes_downloaded += len(chunk)
                        # Monotonic clock: throttling must not be affected by wall-clock jumps.
                        now = time.monotonic()
                        if now - last_emit >= 1:
                            last_emit = now
                            await self._emit({"type": "download_progress", "download": asdict(status)})

    async def _restart_for_new_model(self, status: DownloadStatus) -> None:
        """Trigger the configured restart so the new model can be picked up.

        Failures are logged only; the downloaded file is already in place.
        """
        try:
            plan = RestartPlan(
                method=self.cfg.restart_method,
                command=self.cfg.restart_command,
                url=self.cfg.restart_url,
                allowed_container=self.cfg.allowed_container,
            )
            await trigger_restart(
                plan,
                payload={
                    "reason": "new_model",
                    "model_id": status.filename,
                    "llamacpp_args": self.cfg.llamacpp_args,
                    "llamacpp_extra_args": self.cfg.llamacpp_extra_args,
                },
            )
        except Exception:
            log.exception("Restart after download failed id=%s filename=%s", status.download_id, status.filename)
|