Files
codex_truenas_helper/llamaCpp.Wrapper.app/truenas_middleware.py
Rushabh Gosar 5d1a0ee72b Initial commit
2026-01-07 16:54:39 -08:00

314 lines
11 KiB
Python

import json
import logging
import shlex
import ssl
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

import websockets
import yaml
# Module-level logger shared by the helpers in this module.
log = logging.getLogger(name="truenas_middleware")
@dataclass
class TrueNASConfig:
    """Connection settings for talking to the TrueNAS middleware websocket.

    Attributes are documented inline below. ``api_key`` is excluded from the
    generated ``repr`` so that logging or printing a config object (or a
    traceback that shows it) does not leak the secret.
    """

    # Websocket endpoint; a "wss://" URL uses TLS.
    ws_url: str
    # API key sent to auth.login_with_api_key / auth.login_ex; hidden from repr.
    api_key: str = field(repr=False)
    # Username for the auth.login_ex fallback; None disables that fallback.
    api_user: Optional[str]
    # Name of the TrueNAS app whose config/command is read and updated.
    app_name: str
    # When False, certificate and hostname verification are disabled for wss:// URLs.
    verify_ssl: bool = False
def _parse_compose(raw: Any) -> Dict[str, Any]:
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
text = raw.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
return yaml.safe_load(text)
raise ValueError("Unsupported compose payload")
def _command_to_list(command: Any) -> list:
if isinstance(command, list):
return command
if isinstance(command, str):
return shlex.split(command)
return []
def _extract_command(config: Dict[str, Any], service_name: str = "llamacpp") -> list:
    """Return the command of *service_name* from an ``app.config`` payload.

    When the app carries a custom compose document (``custom_compose_config``
    or ``custom_compose_config_string``), the command is read from
    ``services.<service_name>.command`` in that document; otherwise the
    payload's top-level ``command`` value is used. A missing service or
    command yields an empty list.
    """
    raw_compose = config.get("custom_compose_config") or config.get("custom_compose_config_string")
    if not raw_compose:
        return _command_to_list(config.get("command"))
    services = _parse_compose(raw_compose).get("services") or {}
    service = services.get(service_name) or {}
    return _command_to_list(service.get("command"))
def _model_id_from_command(cmd: list) -> Optional[str]:
if "--model" in cmd:
idx = cmd.index("--model")
if idx + 1 < len(cmd):
return Path(cmd[idx + 1]).name
return None
def _set_arg(cmd: list, flag: str, value: Optional[str]) -> list:
if value is None:
return cmd
if flag in cmd:
idx = cmd.index(flag)
if idx + 1 < len(cmd):
cmd[idx + 1] = value
else:
cmd.append(value)
return cmd
cmd.extend([flag, value])
return cmd
def _merge_args(cmd: list, args: Dict[str, str]) -> list:
flag_map = {
"device": "--device",
"tensor_split": "--tensor-split",
"split_mode": "--split-mode",
"n_gpu_layers": "--n-gpu-layers",
"ctx_size": "--ctx-size",
"batch_size": "--batch-size",
"ubatch_size": "--ubatch-size",
"cache_type_k": "--cache-type-k",
"cache_type_v": "--cache-type-v",
"flash_attn": "--flash-attn",
}
for key, value in args.items():
flag = flag_map.get(key)
if flag:
if flag in cmd:
continue
_set_arg(cmd, flag, value)
return cmd
def _merge_extra_args(cmd: list, extra: str) -> list:
if not extra:
return cmd
extra_list = shlex.split(extra)
filtered: list[str] = []
skip_next = False
for item in extra_list:
if skip_next:
skip_next = False
continue
if item in {"--device", "-dev"}:
log.warning("Dropping --device from extra args to avoid llama.cpp device errors.")
skip_next = True
continue
filtered.append(item)
for flag in filtered:
if flag not in cmd:
cmd.append(flag)
return cmd
def _update_model_command(command: Any, model_path: str, args: Dict[str, str], extra: str) -> list:
    """Build the llama.cpp command that switches the server to *model_path*.

    Steps:
      1. normalise *command* to a list of tokens;
      2. remove every existing device option: ``--device X``, ``-dev X``,
         ``--device=X`` and ``-dev=X`` (not just the first occurrence);
      3. set or replace ``--model`` with *model_path*;
      4. merge the structured *args* via ``_merge_args``;
      5. merge the free-form *extra* string via ``_merge_extra_args``.

    Returns:
        A new list; the *command* object passed in is not modified.
    """
    tokens = iter(_command_to_list(command))
    cmd: list = []
    for token in tokens:
        if token in ("--device", "-dev"):
            next(tokens, None)  # also drop the device value that follows
            continue
        if isinstance(token, str) and token.startswith(("--device=", "-dev=")):
            continue
        cmd.append(token)
    cmd = _set_arg(cmd, "--model", model_path)
    cmd = _merge_args(cmd, args)
    return _merge_extra_args(cmd, extra)
def _replace_flags(cmd: list, flags: Dict[str, Optional[str]], extra: str) -> list:
    """Return a copy of *cmd* with the flags in *flags* overwritten.

    Every occurrence of each flag in *flags* is removed first, in both the
    ``--flag value`` and ``--flag=value`` spellings. Device options
    (``--device`` / ``-dev``) are always removed as well. Flags whose new value
    is None or "" therefore end up cleared; the others are re-added with their
    new value. Finally the free-form *extra* string is merged in.

    *cmd* itself is not modified.
    """
    removable = set(flags) | {"--device", "-dev"}
    result: list = []
    tokens = iter(cmd)
    for token in tokens:
        if isinstance(token, str):
            name = token.split("=", 1)[0]
            if name in removable:
                if name == token:
                    next(tokens, None)  # "--flag value": drop the value too
                continue
        result.append(token)
    for flag, value in flags.items():
        if value is not None and value != "":
            result = _set_arg(result, flag, value)
    return _merge_extra_args(result, extra)
async def get_app_config(cfg: TrueNASConfig) -> Dict[str, Any]:
    """Fetch the app's configuration via the ``app.config`` middleware method.

    Raises:
        RuntimeError: if the middleware returns anything other than a dict.
    """
    payload = await _rpc_call(cfg, "app.config", [cfg.app_name])
    if isinstance(payload, dict):
        return payload
    raise RuntimeError("app.config returned unsupported payload")
async def get_app_command(cfg: TrueNASConfig, service_name: str = "llamacpp") -> list:
    """Return the current command tokens of *service_name* in the configured app."""
    return _extract_command(await get_app_config(cfg), service_name=service_name)
async def get_active_model_id(cfg: TrueNASConfig, service_name: str = "llamacpp") -> str:
    """Return the model file name in the app's current command, or "" if there is none."""
    command = await get_app_command(cfg, service_name=service_name)
    model_id = _model_id_from_command(command)
    return model_id if model_id else ""
async def get_app_logs(
    cfg: TrueNASConfig,
    tail_lines: int = 200,
    service_name: str = "llamacpp",
) -> str:
    """Best-effort fetch of the app's recent log output.

    First tries ``app.container_logs`` (scoped to *service_name*), then
    ``app.logs``. Each is called with three spellings of the tail option, and
    the first call that returns a string wins. Failed calls are logged at
    debug level and skipped. Returns "" when no attempt yields a string.
    """
    tail_options = (
        {"tail": tail_lines},
        {"tail_lines": tail_lines},
        {"tail": str(tail_lines)},
    )
    attempts = (
        ("app.container_logs", [cfg.app_name, service_name], "app.container_logs failed (%s): %s"),
        ("app.logs", [cfg.app_name], "app.logs failed (%s): %s"),
    )
    for method, leading_params, failure_msg in attempts:
        for tail_option in tail_options:
            try:
                logs = await _rpc_call(cfg, method, [*leading_params, tail_option])
            except Exception as exc:
                log.debug(failure_msg, tail_option, exc)
                continue
            if isinstance(logs, str):
                return logs
    return ""
async def update_app_command(
    cfg: TrueNASConfig,
    command: list,
    service_name: str = "llamacpp",
) -> None:
    """Replace the command of *service_name* with *command* and push it via ``app.update``.

    For apps with a custom compose document, the service's ``command`` is
    edited in that document, which is sent back as ``custom_compose_config``.
    Otherwise ``command`` is set in the app config, which is sent as ``values``.

    Raises:
        RuntimeError: if ``app.config`` returns a non-dict payload, or the
            compose document has no service named *service_name*.
    """
    config = await get_app_config(cfg)
    raw_compose = config.get("custom_compose_config") or config.get("custom_compose_config_string")
    if not raw_compose:
        config["command"] = command
        await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
        return
    compose = _parse_compose(raw_compose)
    services = compose.get("services") or {}
    if service_name not in services:
        raise RuntimeError(f"service {service_name} not found in compose")
    services[service_name]["command"] = command
    await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
async def update_command_flags(
    cfg: TrueNASConfig,
    flags: Dict[str, Optional[str]],
    extra: str,
    service_name: str = "llamacpp",
) -> None:
    """Overwrite selected flags in the app's command and push it via ``app.update``.

    The new command is built by ``_replace_flags`` from the current command,
    *flags* and *extra*. It is written back into the compose document
    (custom-compose apps) or into the app config sent as ``values``.

    Raises:
        RuntimeError: if ``app.config`` returns a non-dict payload, or the
            compose document has no service named *service_name*.
    """
    config = await get_app_config(cfg)
    raw_compose = config.get("custom_compose_config") or config.get("custom_compose_config_string")
    if raw_compose:
        compose = _parse_compose(raw_compose)
        services = compose.get("services") or {}
        if service_name not in services:
            raise RuntimeError(f"service {service_name} not found in compose")
        service = services[service_name]
        service["command"] = _replace_flags(_command_to_list(service.get("command")), flags, extra)
        payload = {"custom_compose_config": compose}
    else:
        config["command"] = _replace_flags(_command_to_list(config.get("command")), flags, extra)
        payload = {"values": config}
    await _rpc_call(cfg, "app.update", [cfg.app_name, payload])
def _rpc_ssl_kwargs(cfg: TrueNASConfig) -> Dict[str, Any]:
    """Return extra ``websockets.connect`` keyword arguments for TLS.

    Only a ``wss://`` URL with ``verify_ssl`` disabled gets a custom context
    (no hostname check, no certificate verification). In every other case
    nothing is passed, so websockets applies its own defaults. An explicit
    ``ssl=None`` is not a valid way to ask for default TLS on a ``wss://`` URI.
    """
    if not cfg.ws_url.startswith("wss://") or cfg.verify_ssl:
        return {}
    ctx = ssl.create_default_context()
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return {"ssl": ctx}


async def _rpc_request(ws: Any, req_id: int, method: str, params: list) -> Dict[str, Any]:
    """Send one middleware method call on *ws* and return the reply carrying *req_id*.

    Messages with another (or no) id are discarded while waiting.
    """
    await ws.send(json.dumps({"id": req_id, "msg": "method", "method": method, "params": params}))
    while True:
        reply = json.loads(await ws.recv())
        if isinstance(reply, dict) and reply.get("id") == req_id:
            return reply


async def _rpc_authenticate(ws: Any, cfg: TrueNASConfig) -> None:
    """Log in with the API key, falling back to ``auth.login_ex`` with ``cfg.api_user``.

    Raises:
        RuntimeError: if the key is rejected and no user is configured, or the
            ``auth.login_ex`` fallback does not report ``SUCCESS``.
    """
    auth_resp = await _rpc_request(ws, 1, "auth.login_with_api_key", [cfg.api_key])
    if auth_resp.get("result"):
        return
    if not cfg.api_user:
        raise RuntimeError("API key rejected and TRUENAS_API_USER not set")
    auth_ex = await _rpc_request(
        ws,
        2,
        "auth.login_ex",
        [
            {
                "mechanism": "API_KEY_PLAIN",
                "username": cfg.api_user,
                "api_key": cfg.api_key,
            }
        ],
    )
    # "result" may be absent or null on failure; guard before reading response_type.
    result = auth_ex.get("result")
    if not isinstance(result, dict) or result.get("response_type") != "SUCCESS":
        raise RuntimeError("API key authentication failed")


async def _rpc_call(cfg: TrueNASConfig, method: str, params: Optional[list] = None) -> Any:
    """Call the TrueNAS middleware *method* with *params* and return its result.

    Each call opens a new websocket connection, performs the ``connect``
    handshake, authenticates, issues the method as request id 3 and returns
    the reply's ``result`` field.

    Raises:
        RuntimeError: on handshake or authentication failure, or when the
            middleware reports an error for the call (``msg == "error"`` or an
            ``error`` field on the reply).
    """
    async with websockets.connect(cfg.ws_url, **_rpc_ssl_kwargs(cfg)) as ws:
        await ws.send(json.dumps({"msg": "connect", "version": "1", "support": ["1"]}))
        connected = json.loads(await ws.recv())
        if connected.get("msg") != "connected":
            raise RuntimeError("failed to connect to TrueNAS websocket")
        await _rpc_authenticate(ws, cfg)
        reply = await _rpc_request(ws, 3, method, params or [])
        error = reply.get("error")
        # Errors may arrive as msg="error" or as a "result" message carrying an
        # "error" field; both must raise rather than return a None result.
        if reply.get("msg") == "error" or error:
            raise RuntimeError(error if error is not None else reply)
        return reply.get("result")
async def switch_model(
    cfg: TrueNASConfig,
    model_path: str,
    args: Dict[str, str],
    extra: str,
    service_name: str = "llamacpp",
) -> None:
    """Point the app's llama.cpp command at *model_path* and push the update.

    The command is rebuilt by ``_update_model_command`` from the current
    command, *model_path*, *args* and *extra*, then written back via
    ``app.update``. Custom-compose apps receive the edited compose document
    as ``custom_compose_config``; other apps receive their config as ``values``.

    Raises:
        RuntimeError: if ``app.config`` returns a non-dict payload, or the
            compose document has no service named *service_name*.
    """
    # get_app_config checks the payload type before any .get() call, so a
    # non-dict reply raises RuntimeError rather than AttributeError.
    config = await get_app_config(cfg)
    raw_compose = config.get("custom_compose_config") or config.get("custom_compose_config_string")
    if raw_compose:
        compose = _parse_compose(raw_compose)
        services = compose.get("services") or {}
        if service_name not in services:
            raise RuntimeError(f"service {service_name} not found in compose")
        svc = services[service_name]
        svc["command"] = _update_model_command(svc.get("command"), model_path, args, extra)
        await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
        log.info("Requested model switch to %s via TrueNAS middleware (custom app)", model_path)
        return
    config["command"] = _update_model_command(config.get("command"), model_path, args, extra)
    await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
    log.info("Requested model switch to %s via TrueNAS middleware (catalog app)", model_path)