# Listing metadata from the original paste: 117 lines, 4.0 KiB, Python.
import argparse
|
|
import asyncio
|
|
import json
|
|
import ssl
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import websockets
|
|
|
|
|
|
async def _await_reply(ws: Any, req_id: int) -> Dict[str, Any]:
    """Read frames from ``ws`` until one whose ``id`` equals ``req_id`` arrives, and return it decoded."""
    while True:
        frame = json.loads(await ws.recv())
        # Any frame that is not the reply to our request is skipped.
        if isinstance(frame, dict) and frame.get("id") == req_id:
            return frame


async def _rpc_call(
    ws_url: str,
    api_key: str,
    method: str,
    params: Optional[list] = None,
    verify_ssl: bool = False,
) -> Any:
    """Invoke a single TrueNAS middleware method over a fresh websocket session.

    Opens a connection to ``ws_url``, performs the ``connect`` handshake,
    authenticates with ``auth.login_with_api_key``, sends one ``method``
    call and returns the ``result`` field of its reply.

    Args:
        ws_url: Websocket endpoint, ``ws://`` or ``wss://``.
        api_key: API key passed to ``auth.login_with_api_key``.
        method: Middleware method name, e.g. ``"app.query"``.
        params: Positional parameters for the method; ``None`` sends ``[]``.
        verify_ssl: For ``wss://`` URLs, verify the server certificate and
            hostname when True; disable both checks when False.

    Returns:
        The ``result`` value of the reply whose ``id`` matches the request.

    Raises:
        RuntimeError: If the handshake is rejected, authentication fails, or
            the method reply carries an error.
    """
    connect_kwargs: Dict[str, Any] = {}
    # URL schemes are case-insensitive.
    if ws_url.lower().startswith("wss://"):
        # Always hand websockets an explicit context for wss://. Passing
        # ssl=None here either disables TLS or is rejected, depending on the
        # websockets version.
        ssl_ctx = ssl.create_default_context()
        if not verify_ssl:
            # NOTE(review): disables certificate and hostname checks entirely.
            # Only acceptable for self-signed certs on a trusted network.
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
        connect_kwargs["ssl"] = ssl_ctx

    async with websockets.connect(ws_url, **connect_kwargs) as ws:
        await ws.send(json.dumps({"msg": "connect", "version": "1", "support": ["1"]}))
        connected = json.loads(await ws.recv())
        if connected.get("msg") != "connected":
            raise RuntimeError("failed to connect to TrueNAS websocket")

        auth_id = 1
        await ws.send(
            json.dumps(
                {"id": auth_id, "msg": "method", "method": "auth.login_with_api_key", "params": [api_key]}
            )
        )
        auth_resp = await _await_reply(ws, auth_id)
        if not auth_resp.get("result"):
            raise RuntimeError("API key authentication failed")

        req_id = 2
        await ws.send(json.dumps({"id": req_id, "msg": "method", "method": method, "params": params or []}))
        reply = await _await_reply(ws, req_id)
        # Errors may arrive with msg == "error" or as an "error" field on an
        # otherwise normal reply. Check both so failures are never mistaken
        # for a None result.
        if reply.get("msg") == "error" or reply.get("error") is not None:
            raise RuntimeError(reply.get("error"))
        return reply.get("result")
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the deploy script."""
    parser = argparse.ArgumentParser(
        description="Create or update a TrueNAS custom app running the wrapper image.",
    )
    parser.add_argument("--ws-url", required=True, help="TrueNAS middleware websocket URL (ws:// or wss://).")
    parser.add_argument(
        "--api-key",
        required=True,
        help="TrueNAS API key; also injected into the container as TRUENAS_API_KEY.",
    )
    parser.add_argument("--api-user", help="Optional value forwarded to the container as TRUENAS_API_USER.")
    parser.add_argument("--app-name", required=True, help="Name of the custom app to create or update.")
    parser.add_argument("--image", required=True, help="Container image for the 'wrapper' service.")
    parser.add_argument("--model-host-path", required=True, help="Host path mounted at /models in the container.")
    parser.add_argument(
        "--llamacpp-base-url",
        required=True,
        help="Value passed to the container as LLAMACPP_BASE_URL.",
    )
    parser.add_argument("--network", required=True, help="Name of an existing external Docker network to join.")
    parser.add_argument(
        "--api-port",
        type=int,
        default=9091,
        help="Port published host:container 1:1 and exported as PORT_A (default: %(default)s).",
    )
    parser.add_argument(
        "--ui-port",
        type=int,
        default=9092,
        help="Port published host:container 1:1 and exported as PORT_B (default: %(default)s).",
    )
    parser.add_argument(
        "--llamacpp-app-name",
        default="llamacpp",
        help="Value passed to the container as TRUENAS_APP_NAME (default: %(default)s).",
    )
    parser.add_argument(
        "--verify-ssl",
        action="store_true",
        help="Verify TLS certificates for wss:// URLs, both here and inside the container.",
    )
    return parser


def _build_environment(args: argparse.Namespace) -> Dict[str, str]:
    """Return the environment variables for the wrapper container."""
    env = {
        "PORT_A": str(args.api_port),
        "PORT_B": str(args.ui_port),
        "LLAMACPP_BASE_URL": args.llamacpp_base_url,
        "MODEL_DIR": "/models",
        "TRUENAS_WS_URL": args.ws_url,
        # NOTE(review): the API key is stored in plain text in the app's compose
        # config and is visible to anyone who can read that config.
        "TRUENAS_API_KEY": args.api_key,
        "TRUENAS_APP_NAME": args.llamacpp_app_name,
        # Mirror this script's own TLS setting instead of always disabling it.
        "TRUENAS_VERIFY_SSL": "true" if args.verify_ssl else "false",
    }
    if args.api_user:
        env["TRUENAS_API_USER"] = args.api_user
    return env


def _build_compose(args: argparse.Namespace, env: Dict[str, str]) -> Dict[str, Any]:
    """Return the compose document (as a dict) for the single 'wrapper' service."""
    return {
        "services": {
            "wrapper": {
                "image": args.image,
                "restart": "unless-stopped",
                "ports": [
                    f"{args.api_port}:{args.api_port}",
                    f"{args.ui_port}:{args.ui_port}",
                ],
                "environment": env,
                "volumes": [
                    f"{args.model_host_path}:/models",
                    # NOTE(review): exposes the host Docker daemon socket to the
                    # container; review this before deploying.
                    "/var/run/docker.sock:/var/run/docker.sock",
                ],
                "networks": ["llamacpp_net"],
            }
        },
        "networks": {
            # "llamacpp_net" is a compose-local alias for the pre-existing
            # Docker network named by --network.
            "llamacpp_net": {"external": True, "name": args.network}
        },
    }


async def main() -> None:
    """Create the custom app ``--app-name``, or update its compose config if it exists.

    Parses the command line and validates the two ports (each must be in
    1-65535 and they must differ). It then queries ``app.query`` for the app,
    calls ``app.update`` or ``app.create`` accordingly, and prints a JSON
    summary to stdout.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    for flag, port in (("--api-port", args.api_port), ("--ui-port", args.ui_port)):
        if not 1 <= port <= 65535:
            parser.error(f"{flag} must be between 1 and 65535, got {port}")
    if args.api_port == args.ui_port:
        # Both ports are published 1:1, so identical values would collide.
        parser.error("--api-port and --ui-port must differ")

    env = _build_environment(args)
    compose = _build_compose(args, env)

    existing = await _rpc_call(args.ws_url, args.api_key, "app.query", [[["id", "=", args.app_name]]], args.verify_ssl)
    if existing:
        result = await _rpc_call(
            args.ws_url,
            args.api_key,
            "app.update",
            [args.app_name, {"custom_compose_config": compose}],
            args.verify_ssl,
        )
        action = "updated"
    else:
        create_payload = {
            "custom_app": True,
            "app_name": args.app_name,
            "custom_compose_config": compose,
        }
        result = await _rpc_call(args.ws_url, args.api_key, "app.create", [create_payload], args.verify_ssl)
        action = "created"

    # NOTE(review): app.create / app.update may return a job id rather than
    # the final app state. Confirm against the middleware version in use.
    print(json.dumps({"action": action, "api_port": args.api_port, "ui_port": args.ui_port, "result": result}, indent=2))
|
|
|
|
|
|
if __name__ == "__main__":
    # Invoked as a script (not imported): drive the async entry point
    # to completion on a new event loop.
    asyncio.run(main())
|