Initial commit
This commit is contained in:
313
llamaCpp.Wrapper.app/truenas_middleware.py
Normal file
313
llamaCpp.Wrapper.app/truenas_middleware.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import json
|
||||
import logging
|
||||
import shlex
|
||||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import websockets
|
||||
import yaml
|
||||
|
||||
|
||||
log = logging.getLogger("truenas_middleware")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrueNASConfig:
|
||||
ws_url: str
|
||||
api_key: str
|
||||
api_user: Optional[str]
|
||||
app_name: str
|
||||
verify_ssl: bool = False
|
||||
|
||||
|
||||
def _parse_compose(raw: Any) -> Dict[str, Any]:
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
text = raw.strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return yaml.safe_load(text)
|
||||
raise ValueError("Unsupported compose payload")
|
||||
|
||||
|
||||
def _command_to_list(command: Any) -> list:
|
||||
if isinstance(command, list):
|
||||
return command
|
||||
if isinstance(command, str):
|
||||
return shlex.split(command)
|
||||
return []
|
||||
|
||||
|
||||
def _extract_command(config: Dict[str, Any], service_name: str = "llamacpp") -> list:
|
||||
if config.get("custom_compose_config") or config.get("custom_compose_config_string"):
|
||||
compose = _parse_compose(config.get("custom_compose_config") or config.get("custom_compose_config_string") or {})
|
||||
services = compose.get("services") or {}
|
||||
svc = services.get(service_name) or {}
|
||||
return _command_to_list(svc.get("command"))
|
||||
return _command_to_list(config.get("command"))
|
||||
|
||||
|
||||
def _model_id_from_command(cmd: list) -> Optional[str]:
|
||||
if "--model" in cmd:
|
||||
idx = cmd.index("--model")
|
||||
if idx + 1 < len(cmd):
|
||||
return Path(cmd[idx + 1]).name
|
||||
return None
|
||||
|
||||
|
||||
def _set_arg(cmd: list, flag: str, value: Optional[str]) -> list:
|
||||
if value is None:
|
||||
return cmd
|
||||
if flag in cmd:
|
||||
idx = cmd.index(flag)
|
||||
if idx + 1 < len(cmd):
|
||||
cmd[idx + 1] = value
|
||||
else:
|
||||
cmd.append(value)
|
||||
return cmd
|
||||
cmd.extend([flag, value])
|
||||
return cmd
|
||||
|
||||
|
||||
def _merge_args(cmd: list, args: Dict[str, str]) -> list:
|
||||
flag_map = {
|
||||
"device": "--device",
|
||||
"tensor_split": "--tensor-split",
|
||||
"split_mode": "--split-mode",
|
||||
"n_gpu_layers": "--n-gpu-layers",
|
||||
"ctx_size": "--ctx-size",
|
||||
"batch_size": "--batch-size",
|
||||
"ubatch_size": "--ubatch-size",
|
||||
"cache_type_k": "--cache-type-k",
|
||||
"cache_type_v": "--cache-type-v",
|
||||
"flash_attn": "--flash-attn",
|
||||
}
|
||||
for key, value in args.items():
|
||||
flag = flag_map.get(key)
|
||||
if flag:
|
||||
if flag in cmd:
|
||||
continue
|
||||
_set_arg(cmd, flag, value)
|
||||
return cmd
|
||||
|
||||
|
||||
def _merge_extra_args(cmd: list, extra: str) -> list:
|
||||
if not extra:
|
||||
return cmd
|
||||
extra_list = shlex.split(extra)
|
||||
filtered: list[str] = []
|
||||
skip_next = False
|
||||
for item in extra_list:
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
if item in {"--device", "-dev"}:
|
||||
log.warning("Dropping --device from extra args to avoid llama.cpp device errors.")
|
||||
skip_next = True
|
||||
continue
|
||||
filtered.append(item)
|
||||
for flag in filtered:
|
||||
if flag not in cmd:
|
||||
cmd.append(flag)
|
||||
return cmd
|
||||
|
||||
|
||||
def _update_model_command(command: Any, model_path: str, args: Dict[str, str], extra: str) -> list:
|
||||
cmd = _command_to_list(command)
|
||||
if "--device" in cmd:
|
||||
idx = cmd.index("--device")
|
||||
del cmd[idx: idx + 2]
|
||||
cmd = _set_arg(cmd, "--model", model_path)
|
||||
cmd = _merge_args(cmd, args)
|
||||
cmd = _merge_extra_args(cmd, extra)
|
||||
return cmd
|
||||
|
||||
|
||||
def _replace_flags(cmd: list, flags: Dict[str, Optional[str]], extra: str) -> list:
|
||||
result = list(cmd)
|
||||
for flag in flags.keys():
|
||||
while flag in result:
|
||||
idx = result.index(flag)
|
||||
del result[idx: idx + 2]
|
||||
if "--device" in result:
|
||||
idx = result.index("--device")
|
||||
del result[idx: idx + 2]
|
||||
for flag, value in flags.items():
|
||||
if value is not None and value != "":
|
||||
result = _set_arg(result, flag, value)
|
||||
result = _merge_extra_args(result, extra)
|
||||
return result
|
||||
|
||||
|
||||
async def get_app_config(cfg: TrueNASConfig) -> Dict[str, Any]:
|
||||
config = await _rpc_call(cfg, "app.config", [cfg.app_name])
|
||||
if not isinstance(config, dict):
|
||||
raise RuntimeError("app.config returned unsupported payload")
|
||||
return config
|
||||
|
||||
|
||||
async def get_app_command(cfg: TrueNASConfig, service_name: str = "llamacpp") -> list:
|
||||
config = await get_app_config(cfg)
|
||||
return _extract_command(config, service_name=service_name)
|
||||
|
||||
|
||||
async def get_active_model_id(cfg: TrueNASConfig, service_name: str = "llamacpp") -> str:
|
||||
config = await get_app_config(cfg)
|
||||
cmd = _extract_command(config, service_name=service_name)
|
||||
return _model_id_from_command(cmd) or ""
|
||||
|
||||
|
||||
async def get_app_logs(
|
||||
cfg: TrueNASConfig,
|
||||
tail_lines: int = 200,
|
||||
service_name: str = "llamacpp",
|
||||
) -> str:
|
||||
tail_payloads = [
|
||||
{"tail": tail_lines},
|
||||
{"tail_lines": tail_lines},
|
||||
{"tail": str(tail_lines)},
|
||||
]
|
||||
for payload in tail_payloads:
|
||||
try:
|
||||
result = await _rpc_call(cfg, "app.container_logs", [cfg.app_name, service_name, payload])
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
except Exception as exc:
|
||||
log.debug("app.container_logs failed (%s): %s", payload, exc)
|
||||
for payload in tail_payloads:
|
||||
try:
|
||||
result = await _rpc_call(cfg, "app.logs", [cfg.app_name, payload])
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
except Exception as exc:
|
||||
log.debug("app.logs failed (%s): %s", payload, exc)
|
||||
return ""
|
||||
|
||||
|
||||
async def update_app_command(
|
||||
cfg: TrueNASConfig,
|
||||
command: list,
|
||||
service_name: str = "llamacpp",
|
||||
) -> None:
|
||||
config = await _rpc_call(cfg, "app.config", [cfg.app_name])
|
||||
if not isinstance(config, dict):
|
||||
raise RuntimeError("app.config returned unsupported payload")
|
||||
if config.get("custom_compose_config") or config.get("custom_compose_config_string"):
|
||||
compose = _parse_compose(config.get("custom_compose_config") or config.get("custom_compose_config_string") or {})
|
||||
services = compose.get("services") or {}
|
||||
if service_name not in services:
|
||||
raise RuntimeError(f"service {service_name} not found in compose")
|
||||
svc = services[service_name]
|
||||
svc["command"] = command
|
||||
await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
|
||||
return
|
||||
config["command"] = command
|
||||
await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
|
||||
|
||||
|
||||
async def update_command_flags(
|
||||
cfg: TrueNASConfig,
|
||||
flags: Dict[str, Optional[str]],
|
||||
extra: str,
|
||||
service_name: str = "llamacpp",
|
||||
) -> None:
|
||||
config = await _rpc_call(cfg, "app.config", [cfg.app_name])
|
||||
if not isinstance(config, dict):
|
||||
raise RuntimeError("app.config returned unsupported payload")
|
||||
if config.get("custom_compose_config") or config.get("custom_compose_config_string"):
|
||||
compose = _parse_compose(config.get("custom_compose_config") or config.get("custom_compose_config_string") or {})
|
||||
services = compose.get("services") or {}
|
||||
if service_name not in services:
|
||||
raise RuntimeError(f"service {service_name} not found in compose")
|
||||
svc = services[service_name]
|
||||
cmd = svc.get("command")
|
||||
svc["command"] = _replace_flags(_command_to_list(cmd), flags, extra)
|
||||
await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
|
||||
return
|
||||
cmd = _replace_flags(_command_to_list(config.get("command")), flags, extra)
|
||||
config["command"] = cmd
|
||||
await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
|
||||
|
||||
|
||||
async def _rpc_call(cfg: TrueNASConfig, method: str, params: Optional[list] = None) -> Any:
|
||||
ssl_ctx = None
|
||||
if cfg.ws_url.startswith("wss://") and not cfg.verify_ssl:
|
||||
ssl_ctx = ssl.create_default_context()
|
||||
ssl_ctx.check_hostname = False
|
||||
ssl_ctx.verify_mode = ssl.CERT_NONE
|
||||
|
||||
async with websockets.connect(cfg.ws_url, ssl=ssl_ctx) as ws:
|
||||
await ws.send(json.dumps({"msg": "connect", "version": "1", "support": ["1"]}))
|
||||
connected = json.loads(await ws.recv())
|
||||
if connected.get("msg") != "connected":
|
||||
raise RuntimeError("failed to connect to TrueNAS websocket")
|
||||
|
||||
await ws.send(
|
||||
json.dumps({"id": 1, "msg": "method", "method": "auth.login_with_api_key", "params": [cfg.api_key]})
|
||||
)
|
||||
auth_resp = json.loads(await ws.recv())
|
||||
if not auth_resp.get("result"):
|
||||
if not cfg.api_user:
|
||||
raise RuntimeError("API key rejected and TRUENAS_API_USER not set")
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"id": 2,
|
||||
"msg": "method",
|
||||
"method": "auth.login_ex",
|
||||
"params": [
|
||||
{
|
||||
"mechanism": "API_KEY_PLAIN",
|
||||
"username": cfg.api_user,
|
||||
"api_key": cfg.api_key,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
auth_ex = json.loads(await ws.recv())
|
||||
if auth_ex.get("result", {}).get("response_type") != "SUCCESS":
|
||||
raise RuntimeError("API key authentication failed")
|
||||
|
||||
req_id = 3
|
||||
await ws.send(json.dumps({"id": req_id, "msg": "method", "method": method, "params": params or []}))
|
||||
while True:
|
||||
raw = json.loads(await ws.recv())
|
||||
if raw.get("id") != req_id:
|
||||
continue
|
||||
if raw.get("msg") == "error":
|
||||
raise RuntimeError(raw.get("error"))
|
||||
return raw.get("result")
|
||||
|
||||
|
||||
async def switch_model(
|
||||
cfg: TrueNASConfig,
|
||||
model_path: str,
|
||||
args: Dict[str, str],
|
||||
extra: str,
|
||||
service_name: str = "llamacpp",
|
||||
) -> None:
|
||||
config = await _rpc_call(cfg, "app.config", [cfg.app_name])
|
||||
if config.get("custom_compose_config") or config.get("custom_compose_config_string"):
|
||||
compose = _parse_compose(config.get("custom_compose_config") or config.get("custom_compose_config_string") or {})
|
||||
services = compose.get("services") or {}
|
||||
if service_name not in services:
|
||||
raise RuntimeError(f"service {service_name} not found in compose")
|
||||
svc = services[service_name]
|
||||
cmd = svc.get("command")
|
||||
svc["command"] = _update_model_command(cmd, model_path, args, extra)
|
||||
await _rpc_call(cfg, "app.update", [cfg.app_name, {"custom_compose_config": compose}])
|
||||
log.info("Requested model switch to %s via TrueNAS middleware (custom app)", model_path)
|
||||
return
|
||||
|
||||
if not isinstance(config, dict):
|
||||
raise RuntimeError("app.config returned unsupported payload")
|
||||
|
||||
cmd = config.get("command")
|
||||
config["command"] = _update_model_command(cmd, model_path, args, extra)
|
||||
await _rpc_call(cfg, "app.update", [cfg.app_name, {"values": config}])
|
||||
log.info("Requested model switch to %s via TrueNAS middleware (catalog app)", model_path)
|
||||
Reference in New Issue
Block a user