Initial commit

This commit is contained in:
Rushabh Gosar
2026-01-07 16:54:39 -08:00
commit 5d1a0ee72b
53 changed files with 9885 additions and 0 deletions

View File

@@ -0,0 +1,313 @@
import json
import logging
import shlex
import ssl
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional
import websockets
import yaml
log = logging.getLogger("truenas_middleware")
@dataclass
class TrueNASConfig:
ws_url: str
api_key: str
api_user: Optional[str]
app_name: str
verify_ssl: bool = False
def _parse_compose(raw: Any) -> Dict[str, Any]:
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
text = raw.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
return yaml.safe_load(text)
raise ValueError("Unsupported compose payload")
def _command_to_list(command: Any) -> list:
if isinstance(command, list):
return command
if isinstance(command, str):
return shlex.split(command)
return []
def _extract_command(config: Dict[str, Any], service_name: str = "llamacpp") -> list:
if config.get("custom_compose_config") or config.get("custom_compose_config_string"):
compose = _parse_compose(config.get("custom_compose_config") or config.get("custom_compose_config_string") or {})
services = compose.get("services") or {}
svc = services.get(service_name) or {}
return _command_to_list(svc.get("command"))
return _command_to_list(config.get("command"))
def _model_id_from_command(cmd: list) -> Optional[str]:
if "--model" in cmd:
idx = cmd.index("--model")
if idx + 1 < len(cmd):
return Path(cmd[idx + 1]).name
return None
def _set_arg(cmd: list, flag: str, value: Optional[str]) -> list:
if value is None:
return cmd
if flag in cmd:
idx = cmd.index(flag)
if idx + 1 < len(cmd):
cmd[idx + 1] = value
else:
cmd.append(value)
return cmd
cmd.extend([flag, value])
return cmd
def _merge_args(cmd: list, args: Dict[str, str]) -> list:
flag_map = {
"device": "--device",
"tensor_split": "--tensor-split",
"split_mode": "--split-mode",
"n_gpu_layers": "--n-gpu-layers",
"ctx_size": "--ctx-size",
"batch_size": "--batch-size",
"ubatch_size": "--ubatch-size",
"cache_type_k": "--cache-type-k",
"cache_type_v": "--cache-type-v",
"flash_attn": "--flash-attn",
}
for key, value in args.items():
flag = flag_map.get(key)
if flag:
if flag in cmd:
continue
_set_arg(cmd, flag, value)
return cmd
def _merge_extra_args(cmd: list, extra: str) -> list:
if not extra:
return cmd
extra_list = shlex.split(extra)
filtered: list[str] = []
skip_next = False
for item in extra_list:
if skip_next:
skip_next = False
continue
if item in {"--device", "-dev"}:
log.warning("Dropping --device from extra args to avoid llama.cpp device errors.")
skip_next = True
continue
filtered.append(item)
for flag in filtered:
if flag not in cmd:
cmd.append(flag)
return cmd
def _update_model_command(command: Any, model_path: str, args: Dict[str, str], extra: str) -> list:
cmd = _command_to_list(command)
if "--device" in cmd:
idx = cmd.index("--device")
del cmd[idx: idx + 2]
cmd = _set_arg(cmd, "--model", model_path)
cmd = _merge_args(cmd, args)
cmd = _merge_extra_args(cmd, extra)
return cmd
def _replace_flags(cmd: list, flags: Dict[str, Optional[str]], extra: str) -> list:
result = list(cmd)
for flag in flags.keys():
while flag in result:
idx = result.index(flag)
del result[idx: idx + 2]
if "--device" in result:
idx = result.index("--device")
del result[idx: idx + 2]
for flag, value in flags.items():
if value is not None and value != "":
result = _set_arg(result, flag, value)
result = _merge_extra_args(result, extra)
return result
async def get_app_config(cfg: TrueNASConfig) -> Dict[str, Any]:
config = await _rpc_call(cfg, "app.config", [cfg.app_name])
if not isinstance(config, dict):
raise RuntimeError("app.config returned unsupported payload")
return config
async def get_app_command(cfg: TrueNASConfig, service_name: str = "llamacpp") -> list:
config = await get_app_config(cfg)
return _extract_command(config, service_name=service_name)
async def get_active_model_id(cfg: TrueNASConfig, service_name: str = "llamacpp") -> str:
config = await get_app_config(cfg)
cmd = _extract_command(config, service_name=service_name)
return _model_id_from_command(cmd) or ""
async def get_app_logs(
cfg: TrueNASConfig,
tail_lines: int = 200,
service_name: str = "llamacpp",
) -> str:
tail_payloads = [
{"tail": tail_lines},
{"tail_lines": tail_lines},
{"tail": str(tail_lines)},
]
for payload in tail_payloads:
try:
result = await _rpc_call(cfg, "app.container_logs", [cfg.app_name, service_name, payload])
if isinstance(result, str):
return result
except Exception as exc:
log.debug("app.container_logs failed (%s): %s", payload, exc)
for payload in tail_payloads:
try:
result = await _rpc_call(cfg, "app.logs", [cfg.app_name, payload])
if isinstance(result, str):
return result
except Exception as exc:
log.debug("app.logs failed (%s): %s", payload, exc)
return ""
async def update_app_command(
cfg: TrueNASConfig,
command: list,
service_name: str = "llamacpp",
) -> None:
config = await _rpc_call(cfg, "app.config", [cfg.app_name])
if not isinstance(config, dict):
raise RuntimeError("app.config returned unsupported payload")
if config.get("custom_compose_config") or config.get("custom_compose_config_string"):
compose = _parse_compose(config.get("custom_compose_config") or config.get("custom_compose_config_string") or {})
services = compose.get("services") or {}
if service_name not in services:
raise RuntimeError(f"service {service_name} not found in compose")
svc = services[service_name]
svc["command"] = command
await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
return
config["command"] = command
await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
async def update_command_flags(
cfg: TrueNASConfig,
flags: Dict[str, Optional[str]],
extra: str,
service_name: str = "llamacpp",
) -> None:
config = await _rpc_call(cfg, "app.config", [cfg.app_name])
if not isinstance(config, dict):
raise RuntimeError("app.config returned unsupported payload")
if config.get("custom_compose_config") or config.get("custom_compose_config_string"):
compose = _parse_compose(config.get("custom_compose_config") or config.get("custom_compose_config_string") or {})
services = compose.get("services") or {}
if service_name not in services:
raise RuntimeError(f"service {service_name} not found in compose")
svc = services[service_name]
cmd = svc.get("command")
svc["command"] = _replace_flags(_command_to_list(cmd), flags, extra)
await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
return
cmd = _replace_flags(_command_to_list(config.get("command")), flags, extra)
config["command"] = cmd
await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
async def _rpc_call(cfg: TrueNASConfig, method: str, params: Optional[list] = None) -> Any:
ssl_ctx = None
if cfg.ws_url.startswith("wss://") and not cfg.verify_ssl:
ssl_ctx = ssl.create_default_context()
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE
async with websockets.connect(cfg.ws_url, ssl=ssl_ctx) as ws:
await ws.send(json.dumps({"msg": "connect", "version": "1", "support": ["1"]}))
connected = json.loads(await ws.recv())
if connected.get("msg") != "connected":
raise RuntimeError("failed to connect to TrueNAS websocket")
await ws.send(
json.dumps({"id": 1, "msg": "method", "method": "auth.login_with_api_key", "params": [cfg.api_key]})
)
auth_resp = json.loads(await ws.recv())
if not auth_resp.get("result"):
if not cfg.api_user:
raise RuntimeError("API key rejected and TRUENAS_API_USER not set")
await ws.send(
json.dumps(
{
"id": 2,
"msg": "method",
"method": "auth.login_ex",
"params": [
{
"mechanism": "API_KEY_PLAIN",
"username": cfg.api_user,
"api_key": cfg.api_key,
}
],
}
)
)
auth_ex = json.loads(await ws.recv())
if auth_ex.get("result", {}).get("response_type") != "SUCCESS":
raise RuntimeError("API key authentication failed")
req_id = 3
await ws.send(json.dumps({"id": req_id, "msg": "method", "method": method, "params": params or []}))
while True:
raw = json.loads(await ws.recv())
if raw.get("id") != req_id:
continue
if raw.get("msg") == "error":
raise RuntimeError(raw.get("error"))
return raw.get("result")
async def switch_model(
cfg: TrueNASConfig,
model_path: str,
args: Dict[str, str],
extra: str,
service_name: str = "llamacpp",
) -> None:
config = await _rpc_call(cfg, "app.config", [cfg.app_name])
if config.get("custom_compose_config") or config.get("custom_compose_config_string"):
compose = _parse_compose(config.get("custom_compose_config") or config.get("custom_compose_config_string") or {})
services = compose.get("services") or {}
if service_name not in services:
raise RuntimeError(f"service {service_name} not found in compose")
svc = services[service_name]
cmd = svc.get("command")
svc["command"] = _update_model_command(cmd, model_path, args, extra)
await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
log.info("Requested model switch to %s via TrueNAS middleware (custom app)", model_path)
return
if not isinstance(config, dict):
raise RuntimeError("app.config returned unsupported payload")
cmd = config.get("command")
config["command"] = _update_model_command(cmd, model_path, args, extra)
await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
log.info("Requested model switch to %s via TrueNAS middleware (catalog app)", model_path)