Initial commit
This commit is contained in:
141
llamaCpp.Wrapper.app/download_manager.py
Normal file
141
llamaCpp.Wrapper.app/download_manager.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import AppConfig
|
||||
from app.logging_utils import configure_logging
|
||||
from app.restart import RestartPlan, trigger_restart
|
||||
|
||||
configure_logging()
|
||||
log = logging.getLogger("download_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadStatus:
|
||||
download_id: str
|
||||
url: str
|
||||
filename: str
|
||||
status: str
|
||||
bytes_total: Optional[int] = None
|
||||
bytes_downloaded: int = 0
|
||||
started_at: float = field(default_factory=time.time)
|
||||
finished_at: Optional[float] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self, cfg: AppConfig, broadcaster=None) -> None:
|
||||
self.cfg = cfg
|
||||
self._downloads: Dict[str, DownloadStatus] = {}
|
||||
self._tasks: Dict[str, asyncio.Task] = {}
|
||||
self._semaphore = asyncio.Semaphore(cfg.download_max_concurrent)
|
||||
self._broadcaster = broadcaster
|
||||
|
||||
async def _emit(self, payload: dict) -> None:
|
||||
if self._broadcaster:
|
||||
await self._broadcaster.publish(payload)
|
||||
|
||||
def list_downloads(self) -> Dict[str, dict]:
|
||||
return {k: asdict(v) for k, v in self._downloads.items()}
|
||||
|
||||
def get(self, download_id: str) -> Optional[DownloadStatus]:
|
||||
return self._downloads.get(download_id)
|
||||
|
||||
def _is_allowed(self, url: str) -> bool:
|
||||
if not self.cfg.download_allowlist:
|
||||
return True
|
||||
return any(fnmatch.fnmatch(url, pattern) for pattern in self.cfg.download_allowlist)
|
||||
|
||||
async def start(self, url: str, filename: Optional[str] = None) -> DownloadStatus:
|
||||
if not self._is_allowed(url):
|
||||
raise ValueError("url not allowed by allowlist")
|
||||
if not filename:
|
||||
filename = os.path.basename(url.split("?")[0]) or f"model-{uuid.uuid4().hex}.gguf"
|
||||
log.info("Download requested url=%s filename=%s", url, filename)
|
||||
download_id = uuid.uuid4().hex
|
||||
status = DownloadStatus(download_id=download_id, url=url, filename=filename, status="queued")
|
||||
self._downloads[download_id] = status
|
||||
task = asyncio.create_task(self._run_download(status))
|
||||
self._tasks[download_id] = task
|
||||
await self._emit({"type": "download_status", "download": asdict(status)})
|
||||
return status
|
||||
|
||||
async def cancel(self, download_id: str) -> bool:
|
||||
task = self._tasks.get(download_id)
|
||||
if task:
|
||||
task.cancel()
|
||||
status = self._downloads.get(download_id)
|
||||
if status:
|
||||
log.info("Download cancelled id=%s filename=%s", download_id, status.filename)
|
||||
await self._emit({"type": "download_status", "download": asdict(status)})
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _run_download(self, status: DownloadStatus) -> None:
|
||||
status.status = "downloading"
|
||||
base = Path(self.cfg.download_dir)
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = base / f".{status.filename}.partial"
|
||||
final_path = base / status.filename
|
||||
last_emit = 0.0
|
||||
|
||||
try:
|
||||
async with self._semaphore:
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
async with client.stream("GET", status.url) as resp:
|
||||
resp.raise_for_status()
|
||||
length = resp.headers.get("content-length")
|
||||
if length:
|
||||
status.bytes_total = int(length)
|
||||
with tmp_path.open("wb") as f:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
status.bytes_downloaded += len(chunk)
|
||||
now = time.time()
|
||||
if now - last_emit >= 1:
|
||||
last_emit = now
|
||||
await self._emit({"type": "download_progress", "download": asdict(status)})
|
||||
if tmp_path.exists():
|
||||
tmp_path.replace(final_path)
|
||||
status.status = "completed"
|
||||
status.finished_at = time.time()
|
||||
log.info("Download completed id=%s filename=%s", status.download_id, status.filename)
|
||||
await self._emit({"type": "download_completed", "download": asdict(status)})
|
||||
if self.cfg.reload_on_new_model:
|
||||
plan = RestartPlan(
|
||||
method=self.cfg.restart_method,
|
||||
command=self.cfg.restart_command,
|
||||
url=self.cfg.restart_url,
|
||||
allowed_container=self.cfg.allowed_container,
|
||||
)
|
||||
await trigger_restart(
|
||||
plan,
|
||||
payload={
|
||||
"reason": "new_model",
|
||||
"model_id": status.filename,
|
||||
"llamacpp_args": self.cfg.llamacpp_args,
|
||||
"llamacpp_extra_args": self.cfg.llamacpp_extra_args,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
status.status = "cancelled"
|
||||
if tmp_path.exists():
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
log.info("Download cancelled id=%s filename=%s", status.download_id, status.filename)
|
||||
await self._emit({"type": "download_cancelled", "download": asdict(status)})
|
||||
except Exception as exc:
|
||||
status.status = "error"
|
||||
status.error = str(exc)
|
||||
if tmp_path.exists():
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
log.info("Download error id=%s filename=%s error=%s", status.download_id, status.filename, exc)
|
||||
await self._emit({"type": "download_error", "download": asdict(status)})
|
||||
Reference in New Issue
Block a user