"""Client helpers for managing a llama.cpp app through the TrueNAS middleware.

The module talks to the middleware over a websocket (one connection per call),
reads the app configuration, rewrites the service command line and pushes the
result back with ``app.update``.
"""

import json
import logging
import shlex
import ssl
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import websockets
import yaml

log = logging.getLogger("truenas_middleware")

# Spellings of the device-selection flag that are stripped from commands.
_DEVICE_FLAGS = ("--device", "-dev")

# Spellings of the model-path flag, long form first.
_MODEL_FLAGS = ("--model", "-m")

# Maps the keys accepted by ``_merge_args`` to their command-line flags.
_ARG_FLAGS = {
    "device": "--device",
    "tensor_split": "--tensor-split",
    "split_mode": "--split-mode",
    "n_gpu_layers": "--n-gpu-layers",
    "ctx_size": "--ctx-size",
    "batch_size": "--batch-size",
    "ubatch_size": "--ubatch-size",
    "cache_type_k": "--cache-type-k",
    "cache_type_v": "--cache-type-v",
    "flash_attn": "--flash-attn",
}


@dataclass
class TrueNASConfig:
    """Connection settings used by every RPC call in this module."""

    # Websocket endpoint; a ``wss://`` prefix triggers the TLS handling in _rpc_call.
    ws_url: str
    # API key sent to ``auth.login_with_api_key`` and, on fallback, ``auth.login_ex``.
    api_key: str
    # Username for the ``auth.login_ex`` fallback; a falsy value disables the fallback.
    api_user: Optional[str]
    # App name passed as the first parameter of ``app.config`` / ``app.update``.
    app_name: str
    # When False (the default) certificate and hostname checks are skipped for wss://.
    verify_ssl: bool = False


def _parse_compose(raw: Any) -> Dict[str, Any]:
    """Return the compose document as a dict.

    Args:
        raw: Either an already-parsed dict or a JSON/YAML string.

    Returns:
        The compose mapping.

    Raises:
        ValueError: If *raw* is neither a dict nor a string, or if the string
            does not decode to a mapping (e.g. blank text or a bare scalar).
    """
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        text = raw.strip()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Not JSON: fall back to YAML (safe_load never builds arbitrary objects).
            parsed = yaml.safe_load(text)
        if not isinstance(parsed, dict):
            # Callers immediately call ``.get`` on the result; fail clearly here
            # instead of with an AttributeError further down.
            raise ValueError("Unsupported compose payload")
        return parsed
    raise ValueError("Unsupported compose payload")


def _command_to_list(command: Any) -> list:
    """Normalise a command to a token list.

    A list is returned as-is (same object), a string is split with shell
    quoting rules, and anything else yields an empty list.
    """
    if isinstance(command, list):
        return command
    if isinstance(command, str):
        return shlex.split(command)
    return []


def _is_flag(token: Any) -> bool:
    """Return True when *token* looks like a flag rather than a flag's value.

    Tokens that parse as numbers (``-1``, ``-0.5``) are treated as values so
    that negative arguments are not mistaken for flags.
    """
    if not isinstance(token, str) or len(token) < 2 or not token.startswith("-"):
        return False
    try:
        float(token)
    except ValueError:
        return True
    return False


def _remove_flag(cmd: list, flag: str) -> None:
    """Delete every occurrence of *flag* from *cmd* in place.

    The token following the flag is removed with it only when it is a value;
    a following flag is left untouched.
    """
    while flag in cmd:
        start = cmd.index(flag)
        stop = start + 1
        if stop < len(cmd) and not _is_flag(cmd[stop]):
            stop += 1
        del cmd[start:stop]


def _extract_command(config: Dict[str, Any], service_name: str = "llamacpp") -> list:
    """Return the command token list stored in an app configuration.

    When the config carries a custom compose document the command of
    *service_name* inside it is used; otherwise the top-level ``command``.
    A missing service or command yields an empty list.
    """
    raw_compose = config.get("custom_compose_config") or config.get("custom_compose_config_string")
    if raw_compose:
        compose = _parse_compose(raw_compose)
        services = compose.get("services") or {}
        svc = services.get(service_name) or {}
        return _command_to_list(svc.get("command"))
    return _command_to_list(config.get("command"))


def _model_id_from_command(cmd: list) -> Optional[str]:
    """Return the file name of the model path in *cmd*, or None if absent.

    Both ``--model`` and its short form ``-m`` are recognised.
    """
    for flag in _MODEL_FLAGS:
        if flag in cmd:
            idx = cmd.index(flag)
            if idx + 1 < len(cmd):
                return Path(cmd[idx + 1]).name
    return None


def _set_arg(cmd: list, flag: str, value: Optional[str]) -> list:
    """Set *flag* to *value* in *cmd* (in place) and return *cmd*.

    A ``None`` value leaves the command untouched. An existing flag has its
    value replaced; if the flag is currently followed by another flag (or by
    nothing) the value is inserted instead of overwriting that neighbour.
    A missing flag is appended together with its value.
    """
    if value is None:
        return cmd
    if flag not in cmd:
        cmd.extend([flag, value])
        return cmd
    value_pos = cmd.index(flag) + 1
    if value_pos < len(cmd) and not _is_flag(cmd[value_pos]):
        cmd[value_pos] = value
    else:
        cmd.insert(value_pos, value)
    return cmd


def _merge_args(cmd: list, args: Dict[str, str]) -> list:
    """Add the known options in *args* to *cmd* without overriding existing flags.

    Keys that are not in ``_ARG_FLAGS`` are ignored, and a flag that is
    already present in *cmd* keeps its current value.
    """
    for key, value in args.items():
        flag = _ARG_FLAGS.get(key)
        if flag and flag not in cmd:
            _set_arg(cmd, flag, value)
    return cmd


def _merge_extra_args(cmd: list, extra: str) -> list:
    """Append free-form extra arguments to *cmd* (in place) and return it.

    *extra* is split with shell quoting rules and processed as flag/value
    pairs:

    * a device flag (``--device`` / ``-dev``) is dropped with its value and a
      warning is logged;
    * a flag that was already in *cmd* before the merge is skipped together
      with its value, so the existing setting wins and no orphaned value is
      left behind;
    * any other flag is appended with its value, even when that value also
      appears elsewhere in the command (e.g. two options both set to 4096);
    * a bare token with no preceding flag is appended unless already present.
    """
    if not extra:
        return cmd
    tokens = shlex.split(extra)
    # Snapshot so that a flag deliberately repeated inside *extra* is kept.
    already_present = list(cmd)
    pos = 0
    while pos < len(tokens):
        token = tokens[pos]
        pos += 1
        if not _is_flag(token):
            if token not in cmd:
                cmd.append(token)
            continue
        group = [token]
        if pos < len(tokens) and not _is_flag(tokens[pos]):
            group.append(tokens[pos])
            pos += 1
        if token in _DEVICE_FLAGS:
            log.warning("Dropping --device from extra args to avoid llama.cpp device errors.")
            continue
        if token in already_present:
            continue
        cmd.extend(group)
    return cmd


def _update_model_command(command: Any, model_path: str, args: Dict[str, str], extra: str) -> list:
    """Build the command for a model switch.

    Works on a copy of *command*: strips any device flags, points the model
    flag at *model_path* (re-using an existing ``-m`` if that is the only
    spelling present), then merges *args* and *extra*.

    Returns:
        A new token list; the input is not modified.
    """
    cmd = list(_command_to_list(command))
    for flag in _DEVICE_FLAGS:
        _remove_flag(cmd, flag)
    # Avoid ending up with both ``-m old`` and ``--model new`` in one command.
    model_flag = "-m" if "-m" in cmd and "--model" not in cmd else "--model"
    _set_arg(cmd, model_flag, model_path)
    _merge_args(cmd, args)
    return _merge_extra_args(cmd, extra)


def _replace_flags(cmd: list, flags: Dict[str, Optional[str]], extra: str) -> list:
    """Return a copy of *cmd* with the given flags replaced.

    Every flag named in *flags*, plus any device flag, is removed first; then
    each flag whose value is neither ``None`` nor ``""`` is set again, and
    finally *extra* is merged in.
    """
    result = list(cmd)
    for flag in (*flags, *_DEVICE_FLAGS):
        _remove_flag(result, flag)
    for flag, value in flags.items():
        if value is not None and value != "":
            _set_arg(result, flag, value)
    return _merge_extra_args(result, extra)


async def get_app_config(cfg: TrueNASConfig) -> Dict[str, Any]:
    """Fetch the app configuration via ``app.config``.

    Raises:
        RuntimeError: If the middleware returns anything other than a dict.
    """
    config = await _rpc_call(cfg, "app.config", [cfg.app_name])
    if not isinstance(config, dict):
        raise RuntimeError("app.config returned unsupported payload")
    return config


async def get_app_command(cfg: TrueNASConfig, service_name: str = "llamacpp") -> list:
    """Return the current command token list of the app's service."""
    config = await get_app_config(cfg)
    return _extract_command(config, service_name=service_name)


async def get_active_model_id(cfg: TrueNASConfig, service_name: str = "llamacpp") -> str:
    """Return the model file name from the current command, or ``""`` if none."""
    config = await get_app_config(cfg)
    cmd = _extract_command(config, service_name=service_name)
    return _model_id_from_command(cmd) or ""


async def get_app_logs(
    cfg: TrueNASConfig,
    tail_lines: int = 200,
    service_name: str = "llamacpp",
) -> str:
    """Best-effort retrieval of the app's recent log output.

    Tries ``app.container_logs`` and then ``app.logs``, each with several
    spellings of the tail option. The first string result wins; failures are
    logged at debug level and skipped.

    Returns:
        The log text, or ``""`` when no attempt produced a string.
    """
    tail_payloads = [
        {"tail": tail_lines},
        {"tail_lines": tail_lines},
        {"tail": str(tail_lines)},
    ]
    attempts = [
        ("app.container_logs", [cfg.app_name, service_name], "app.container_logs failed (%s): %s"),
        ("app.logs", [cfg.app_name], "app.logs failed (%s): %s"),
    ]
    for method, leading_params, failure_msg in attempts:
        for payload in tail_payloads:
            try:
                result = await _rpc_call(cfg, method, [*leading_params, payload])
            except Exception as exc:
                # Deliberately broad: each attempt is a probe, and any failure
                # just moves on to the next variant.
                log.debug(failure_msg, payload, exc)
                continue
            if isinstance(result, str):
                return result
    return ""


async def _apply_command_update(
    cfg: TrueNASConfig,
    service_name: str,
    transform: Callable[[Any], list],
) -> bool:
    """Fetch the app config, rewrite its command and push it with ``app.update``.

    Args:
        cfg: Connection settings.
        service_name: Compose service whose command is rewritten when the app
            uses a custom compose document.
        transform: Receives the current command (list, string or None) and
            returns the new command.

    Returns:
        True when the update went through the custom compose document, False
        when the top-level ``command`` value was updated.

    Raises:
        RuntimeError: If ``app.config`` returns a non-dict payload or the
            service is missing from the compose document.
    """
    config = await get_app_config(cfg)
    raw_compose = config.get("custom_compose_config") or config.get("custom_compose_config_string")
    if raw_compose:
        compose = _parse_compose(raw_compose)
        services = compose.get("services") or {}
        if service_name not in services:
            raise RuntimeError(f"service {service_name} not found in compose")
        svc = services[service_name]
        svc["command"] = transform(svc.get("command"))
        await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
        return True
    config["command"] = transform(config.get("command"))
    await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
    return False


async def update_app_command(
    cfg: TrueNASConfig,
    command: list,
    service_name: str = "llamacpp",
) -> None:
    """Replace the service command with *command* and push the update."""
    await _apply_command_update(cfg, service_name, lambda _current: command)


async def update_command_flags(
    cfg: TrueNASConfig,
    flags: Dict[str, Optional[str]],
    extra: str,
    service_name: str = "llamacpp",
) -> None:
    """Replace selected flags in the service command and push the update.

    See ``_replace_flags`` for how *flags* and *extra* are applied.
    """
    await _apply_command_update(
        cfg,
        service_name,
        lambda current: _replace_flags(_command_to_list(current), flags, extra),
    )


async def _authenticate(ws: Any, cfg: TrueNASConfig) -> None:
    """Authenticate an open websocket session.

    Tries ``auth.login_with_api_key`` first; if the reply has no truthy
    ``result`` and ``cfg.api_user`` is set, falls back to ``auth.login_ex``
    with the ``API_KEY_PLAIN`` mechanism.

    Raises:
        RuntimeError: If the key is rejected and no user is configured, or if
            the fallback does not report ``SUCCESS``.
    """
    await ws.send(
        json.dumps({"id": 1, "msg": "method", "method": "auth.login_with_api_key", "params": [cfg.api_key]})
    )
    auth_resp = json.loads(await ws.recv())
    if auth_resp.get("result"):
        return
    if not cfg.api_user:
        raise RuntimeError("API key rejected and TRUENAS_API_USER not set")
    await ws.send(
        json.dumps(
            {
                "id": 2,
                "msg": "method",
                "method": "auth.login_ex",
                "params": [
                    {
                        "mechanism": "API_KEY_PLAIN",
                        "username": cfg.api_user,
                        "api_key": cfg.api_key,
                    }
                ],
            }
        )
    )
    auth_ex = json.loads(await ws.recv())
    result = auth_ex.get("result")
    # An error reply may carry no dict result; treat that as a failed login
    # rather than crashing on ``None.get``.
    response_type = result.get("response_type") if isinstance(result, dict) else None
    if response_type != "SUCCESS":
        raise RuntimeError("API key authentication failed")


async def _rpc_call(cfg: TrueNASConfig, method: str, params: Optional[list] = None) -> Any:
    """Open a websocket, authenticate, call *method* and return its result.

    A fresh connection is made for every call.

    Args:
        cfg: Connection settings.
        method: Middleware method name, e.g. ``"app.config"``.
        params: Positional parameters; ``None`` is sent as an empty list.

    Returns:
        The ``result`` member of the matching reply.

    Raises:
        RuntimeError: On handshake failure, authentication failure, or when
            the reply carries an error.
    """
    connect_kwargs: Dict[str, Any] = {}
    if cfg.ws_url.startswith("wss://") and not cfg.verify_ssl:
        # NOTE(security): certificate and hostname verification are disabled
        # here; this is the default because ``verify_ssl`` defaults to False.
        ssl_ctx = ssl.create_default_context()
        ssl_ctx.check_hostname = False
        ssl_ctx.verify_mode = ssl.CERT_NONE
        connect_kwargs["ssl"] = ssl_ctx
    # In every other case the ``ssl`` keyword is omitted (not passed as None)
    # so the library applies its own default for the URL scheme.
    async with websockets.connect(cfg.ws_url, **connect_kwargs) as ws:
        await ws.send(json.dumps({"msg": "connect", "version": "1", "support": ["1"]}))
        connected = json.loads(await ws.recv())
        if connected.get("msg") != "connected":
            raise RuntimeError("failed to connect to TrueNAS websocket")

        await _authenticate(ws, cfg)

        req_id = 3
        await ws.send(json.dumps({"id": req_id, "msg": "method", "method": method, "params": params or []}))
        while True:
            raw = json.loads(await ws.recv())
            if raw.get("id") != req_id:
                # Unrelated message on the same socket; keep waiting.
                continue
            # NOTE(review): failures appear to arrive either as msg == "error"
            # or as a "result" message with an "error" member; accept both.
            if raw.get("msg") == "error" or raw.get("error"):
                raise RuntimeError(raw.get("error"))
            return raw.get("result")


async def switch_model(
    cfg: TrueNASConfig,
    model_path: str,
    args: Dict[str, str],
    extra: str,
    service_name: str = "llamacpp",
) -> None:
    """Point the service command at *model_path* and push the update.

    See ``_update_model_command`` for how *args* and *extra* are applied.

    Raises:
        RuntimeError: If ``app.config`` returns a non-dict payload or the
            service is missing from the compose document.
    """
    used_compose = await _apply_command_update(
        cfg,
        service_name,
        lambda current: _update_model_command(current, model_path, args, extra),
    )
    if used_compose:
        log.info("Requested model switch to %s via TrueNAS middleware (custom app)", model_path)
    else:
        log.info("Requested model switch to %s via TrueNAS middleware (catalog app)", model_path)