fix: enforce ws3 roles + node allowlist

This commit is contained in:
Peter Steinberger
2026-01-20 09:23:56 +00:00
parent 32a668e4d9
commit 9dbc1435a6
27 changed files with 3096 additions and 40 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -2713,6 +2713,29 @@ macOS app behavior:
} }
``` ```
### `gateway.nodes` (Node command allowlist)
The Gateway enforces a per-platform command allowlist for `node.invoke`. Nodes must both
**declare** a command and have it **allowed** by the Gateway to run it.
Use this section to extend or deny commands:
```json5
{
gateway: {
nodes: {
allowCommands: ["custom.vendor.command"], // extra commands beyond defaults
denyCommands: ["sms.send"] // block a command even if declared
}
}
}
```
Notes:
- `allowCommands` extends the built-in per-platform defaults.
- `denyCommands` always wins (even if the node claims the command).
- `node.invoke` rejects commands that are not declared by the node.
### `gateway.reload` (Config hot reload) ### `gateway.reload` (Config hot reload)
The Gateway watches `~/.clawdbot/clawdbot.json` (or `CLAWDBOT_CONFIG_PATH`) and applies changes automatically. The Gateway watches `~/.clawdbot/clawdbot.json` (or `CLAWDBOT_CONFIG_PATH`) and applies changes automatically.

View File

@@ -123,6 +123,11 @@ Nodes declare capability claims at connect time:
The Gateway treats these as **claims** and enforces server-side allowlists. The Gateway treats these as **claims** and enforces server-side allowlists.
### Node helper methods
- Nodes may call `skills.bins` to fetch the current list of skill executables
for auto-allow checks.
## Versioning ## Versioning
- `PROTOCOL_VERSION` lives in `src/gateway/protocol/schema.ts`. - `PROTOCOL_VERSION` lives in `src/gateway/protocol/schema.ts`.
@@ -144,6 +149,7 @@ The Gateway treats these as **claims** and enforces server-side allowlists.
- Gateways issue tokens per device + role. - Gateways issue tokens per device + role.
- Pairing approvals are required for new device IDs unless local auto-approval - Pairing approvals are required for new device IDs unless local auto-approval
is enabled. is enabled.
- All WS clients must include `device` identity during `connect` (operator + node).
## TLS + pinning ## TLS + pinning

View File

@@ -6,6 +6,12 @@ import path from "node:path";
import { describe, expect, it, vi } from "vitest"; import { describe, expect, it, vi } from "vitest";
import { WebSocket } from "ws"; import { WebSocket } from "ws";
import {
loadOrCreateDeviceIdentity,
publicKeyRawBase64UrlFromPem,
signDevicePayload,
} from "../infra/device-identity.js";
import { buildDeviceAuthPayload } from "../gateway/device-auth.js";
import { PROTOCOL_VERSION } from "../gateway/protocol/index.js"; import { PROTOCOL_VERSION } from "../gateway/protocol/index.js";
import { rawDataToString } from "../infra/ws.js"; import { rawDataToString } from "../infra/ws.js";
import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js"; import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js";
@@ -58,6 +64,23 @@ async function onceMessage<T = unknown>(
async function connectReq(params: { url: string; token?: string }) { async function connectReq(params: { url: string; token?: string }) {
const ws = new WebSocket(params.url); const ws = new WebSocket(params.url);
await new Promise<void>((resolve) => ws.once("open", resolve)); await new Promise<void>((resolve) => ws.once("open", resolve));
const identity = loadOrCreateDeviceIdentity();
const signedAtMs = Date.now();
const payload = buildDeviceAuthPayload({
deviceId: identity.deviceId,
clientId: GATEWAY_CLIENT_NAMES.TEST,
clientMode: GATEWAY_CLIENT_MODES.TEST,
role: "operator",
scopes: [],
signedAtMs,
token: params.token ?? null,
});
const device = {
id: identity.deviceId,
publicKey: publicKeyRawBase64UrlFromPem(identity.publicKeyPem),
signature: signDevicePayload(identity.privateKeyPem, payload),
signedAt: signedAtMs,
};
ws.send( ws.send(
JSON.stringify({ JSON.stringify({
type: "req", type: "req",
@@ -75,6 +98,7 @@ async function connectReq(params: { url: string; token?: string }) {
}, },
caps: [], caps: [],
auth: params.token ? { token: params.token } : undefined, auth: params.token ? { token: params.token } : undefined,
device,
}, },
}), }),
); );

View File

@@ -6,6 +6,12 @@ import path from "node:path";
import { describe, expect, it, vi } from "vitest"; import { describe, expect, it, vi } from "vitest";
import { WebSocket } from "ws"; import { WebSocket } from "ws";
import {
loadOrCreateDeviceIdentity,
publicKeyRawBase64UrlFromPem,
signDevicePayload,
} from "../infra/device-identity.js";
import { buildDeviceAuthPayload } from "../gateway/device-auth.js";
import { PROTOCOL_VERSION } from "../gateway/protocol/index.js"; import { PROTOCOL_VERSION } from "../gateway/protocol/index.js";
import { getFreePort as getFreeTestPort } from "../gateway/test-helpers.js"; import { getFreePort as getFreeTestPort } from "../gateway/test-helpers.js";
import { rawDataToString } from "../infra/ws.js"; import { rawDataToString } from "../infra/ws.js";
@@ -64,6 +70,23 @@ async function onceMessage<T = unknown>(
async function connectReq(params: { url: string; token?: string }) { async function connectReq(params: { url: string; token?: string }) {
const ws = new WebSocket(params.url); const ws = new WebSocket(params.url);
await new Promise<void>((resolve) => ws.once("open", resolve)); await new Promise<void>((resolve) => ws.once("open", resolve));
const identity = loadOrCreateDeviceIdentity();
const signedAtMs = Date.now();
const payload = buildDeviceAuthPayload({
deviceId: identity.deviceId,
clientId: GATEWAY_CLIENT_NAMES.TEST,
clientMode: GATEWAY_CLIENT_MODES.TEST,
role: "operator",
scopes: [],
signedAtMs,
token: params.token ?? null,
});
const device = {
id: identity.deviceId,
publicKey: publicKeyRawBase64UrlFromPem(identity.publicKeyPem),
signature: signDevicePayload(identity.privateKeyPem, payload),
signedAt: signedAtMs,
};
ws.send( ws.send(
JSON.stringify({ JSON.stringify({
type: "req", type: "req",
@@ -81,6 +104,7 @@ async function connectReq(params: { url: string; token?: string }) {
}, },
caps: [], caps: [],
auth: params.token ? { token: params.token } : undefined, auth: params.token ? { token: params.token } : undefined,
device,
}, },
}), }),
); );

View File

@@ -169,6 +169,8 @@ const FIELD_LABELS: Record<string, string> = {
"gateway.http.endpoints.chatCompletions.enabled": "OpenAI Chat Completions Endpoint", "gateway.http.endpoints.chatCompletions.enabled": "OpenAI Chat Completions Endpoint",
"gateway.reload.mode": "Config Reload Mode", "gateway.reload.mode": "Config Reload Mode",
"gateway.reload.debounceMs": "Config Reload Debounce (ms)", "gateway.reload.debounceMs": "Config Reload Debounce (ms)",
"gateway.nodes.allowCommands": "Gateway Node Allowlist (Extra Commands)",
"gateway.nodes.denyCommands": "Gateway Node Denylist",
"skills.load.watch": "Watch Skills", "skills.load.watch": "Watch Skills",
"skills.load.watchDebounceMs": "Skills Watch Debounce (ms)", "skills.load.watchDebounceMs": "Skills Watch Debounce (ms)",
"agents.defaults.workspace": "Workspace", "agents.defaults.workspace": "Workspace",
@@ -318,6 +320,10 @@ const FIELD_HELP: Record<string, string> = {
"Enable the OpenAI-compatible `POST /v1/chat/completions` endpoint (default: false).", "Enable the OpenAI-compatible `POST /v1/chat/completions` endpoint (default: false).",
"gateway.reload.mode": 'Hot reload strategy for config changes ("hybrid" recommended).', "gateway.reload.mode": 'Hot reload strategy for config changes ("hybrid" recommended).',
"gateway.reload.debounceMs": "Debounce window (ms) before applying config changes.", "gateway.reload.debounceMs": "Debounce window (ms) before applying config changes.",
"gateway.nodes.allowCommands":
"Extra node.invoke commands to allow beyond the gateway defaults (array of command strings).",
"gateway.nodes.denyCommands":
"Commands to block even if present in node claims or default allowlist.",
"tools.exec.applyPatch.enabled": "tools.exec.applyPatch.enabled":
"Experimental. Enables apply_patch for OpenAI models when allowed by tool policy.", "Experimental. Enables apply_patch for OpenAI models when allowed by tool policy.",
"tools.exec.applyPatch.allowModels": "tools.exec.applyPatch.allowModels":

View File

@@ -170,6 +170,13 @@ export type GatewayHttpConfig = {
endpoints?: GatewayHttpEndpointsConfig; endpoints?: GatewayHttpEndpointsConfig;
}; };
export type GatewayNodesConfig = {
/** Additional node.invoke commands to allow on the gateway. */
allowCommands?: string[];
/** Commands to deny even if they appear in the defaults or node claims. */
denyCommands?: string[];
};
export type GatewayConfig = { export type GatewayConfig = {
/** Single multiplexed port for Gateway WS + HTTP (default: 18789). */ /** Single multiplexed port for Gateway WS + HTTP (default: 18789). */
port?: number; port?: number;
@@ -196,4 +203,5 @@ export type GatewayConfig = {
reload?: GatewayReloadConfig; reload?: GatewayReloadConfig;
tls?: GatewayTlsConfig; tls?: GatewayTlsConfig;
http?: GatewayHttpConfig; http?: GatewayHttpConfig;
nodes?: GatewayNodesConfig;
}; };

View File

@@ -341,6 +341,13 @@ export const ClawdbotSchema = z
}) })
.strict() .strict()
.optional(), .optional(),
nodes: z
.object({
allowCommands: z.array(z.string()).optional(),
denyCommands: z.array(z.string()).optional(),
})
.strict()
.optional(),
}) })
.strict() .strict()
.optional(), .optional(),

View File

@@ -3,7 +3,11 @@ import { WebSocket, type ClientOptions, type CertMeta } from "ws";
import { rawDataToString } from "../infra/ws.js"; import { rawDataToString } from "../infra/ws.js";
import { logDebug, logError } from "../logger.js"; import { logDebug, logError } from "../logger.js";
import type { DeviceIdentity } from "../infra/device-identity.js"; import type { DeviceIdentity } from "../infra/device-identity.js";
import { publicKeyRawBase64UrlFromPem, signDevicePayload } from "../infra/device-identity.js"; import {
loadOrCreateDeviceIdentity,
publicKeyRawBase64UrlFromPem,
signDevicePayload,
} from "../infra/device-identity.js";
import { import {
GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_MODES,
GATEWAY_CLIENT_NAMES, GATEWAY_CLIENT_NAMES,
@@ -78,7 +82,10 @@ export class GatewayClient {
private tickTimer: NodeJS.Timeout | null = null; private tickTimer: NodeJS.Timeout | null = null;
constructor(opts: GatewayClientOptions) { constructor(opts: GatewayClientOptions) {
this.opts = opts; this.opts = {
...opts,
deviceIdentity: opts.deviceIdentity ?? loadOrCreateDeviceIdentity(),
};
} }
start() { start() {

View File

@@ -6,9 +6,15 @@ import path from "node:path";
import { describe, expect, it } from "vitest"; import { describe, expect, it } from "vitest";
import { WebSocket } from "ws"; import { WebSocket } from "ws";
import {
loadOrCreateDeviceIdentity,
publicKeyRawBase64UrlFromPem,
signDevicePayload,
} from "../infra/device-identity.js";
import { rawDataToString } from "../infra/ws.js"; import { rawDataToString } from "../infra/ws.js";
import { getDeterministicFreePortBlock } from "../test-utils/ports.js"; import { getDeterministicFreePortBlock } from "../test-utils/ports.js";
import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js"; import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js";
import { buildDeviceAuthPayload } from "./device-auth.js";
import { PROTOCOL_VERSION } from "./protocol/index.js"; import { PROTOCOL_VERSION } from "./protocol/index.js";
async function getFreeGatewayPort(): Promise<number> { async function getFreeGatewayPort(): Promise<number> {
@@ -43,6 +49,23 @@ async function onceMessage<T = unknown>(
async function connectReq(params: { url: string; token?: string }) { async function connectReq(params: { url: string; token?: string }) {
const ws = new WebSocket(params.url); const ws = new WebSocket(params.url);
await new Promise<void>((resolve) => ws.once("open", resolve)); await new Promise<void>((resolve) => ws.once("open", resolve));
const identity = loadOrCreateDeviceIdentity();
const signedAtMs = Date.now();
const payload = buildDeviceAuthPayload({
deviceId: identity.deviceId,
clientId: GATEWAY_CLIENT_NAMES.TEST,
clientMode: GATEWAY_CLIENT_MODES.TEST,
role: "operator",
scopes: [],
signedAtMs,
token: params.token ?? null,
});
const device = {
id: identity.deviceId,
publicKey: publicKeyRawBase64UrlFromPem(identity.publicKeyPem),
signature: signDevicePayload(identity.privateKeyPem, payload),
signedAt: signedAtMs,
};
ws.send( ws.send(
JSON.stringify({ JSON.stringify({
type: "req", type: "req",
@@ -60,6 +83,7 @@ async function connectReq(params: { url: string; token?: string }) {
}, },
caps: [], caps: [],
auth: params.token ? { token: params.token } : undefined, auth: params.token ? { token: params.token } : undefined,
device,
}, },
}), }),
); );

View File

@@ -0,0 +1,110 @@
import type { ClawdbotConfig } from "../config/config.js";
import type { NodeSession } from "./node-registry.js";
const CANVAS_COMMANDS = [
"canvas.present",
"canvas.hide",
"canvas.navigate",
"canvas.eval",
"canvas.snapshot",
"canvas.a2ui.push",
"canvas.a2ui.pushJSONL",
"canvas.a2ui.reset",
];
const CAMERA_COMMANDS = ["camera.list", "camera.snap", "camera.clip"];
const SCREEN_COMMANDS = ["screen.record"];
const LOCATION_COMMANDS = ["location.get"];
const SMS_COMMANDS = ["sms.send"];
const SYSTEM_COMMANDS = [
"system.run",
"system.which",
"system.notify",
"system.execApprovals.get",
"system.execApprovals.set",
];
const PLATFORM_DEFAULTS: Record<string, string[]> = {
ios: [...CANVAS_COMMANDS, ...CAMERA_COMMANDS, ...SCREEN_COMMANDS, ...LOCATION_COMMANDS],
android: [
...CANVAS_COMMANDS,
...CAMERA_COMMANDS,
...SCREEN_COMMANDS,
...LOCATION_COMMANDS,
...SMS_COMMANDS,
],
macos: [
...CANVAS_COMMANDS,
...CAMERA_COMMANDS,
...SCREEN_COMMANDS,
...LOCATION_COMMANDS,
...SYSTEM_COMMANDS,
],
linux: [...SYSTEM_COMMANDS],
windows: [...SYSTEM_COMMANDS],
unknown: [
...CANVAS_COMMANDS,
...CAMERA_COMMANDS,
...SCREEN_COMMANDS,
...LOCATION_COMMANDS,
...SMS_COMMANDS,
...SYSTEM_COMMANDS,
],
};
function normalizePlatformId(platform?: string, deviceFamily?: string): string {
const raw = (platform ?? "").trim().toLowerCase();
if (raw.startsWith("ios")) return "ios";
if (raw.startsWith("android")) return "android";
if (raw.startsWith("mac")) return "macos";
if (raw.startsWith("darwin")) return "macos";
if (raw.startsWith("win")) return "windows";
if (raw.startsWith("linux")) return "linux";
const family = (deviceFamily ?? "").trim().toLowerCase();
if (family.includes("iphone") || family.includes("ipad") || family.includes("ios")) return "ios";
if (family.includes("android")) return "android";
if (family.includes("mac")) return "macos";
if (family.includes("windows")) return "windows";
if (family.includes("linux")) return "linux";
return "unknown";
}
export function resolveNodeCommandAllowlist(
cfg: ClawdbotConfig,
node?: Pick<NodeSession, "platform" | "deviceFamily">,
): Set<string> {
const platformId = normalizePlatformId(node?.platform, node?.deviceFamily);
const base = PLATFORM_DEFAULTS[platformId] ?? PLATFORM_DEFAULTS.unknown;
const extra = cfg.gateway?.nodes?.allowCommands ?? [];
const deny = new Set(cfg.gateway?.nodes?.denyCommands ?? []);
const allow = new Set([...base, ...extra].map((cmd) => cmd.trim()).filter(Boolean));
for (const blocked of deny) {
const trimmed = blocked.trim();
if (trimmed) allow.delete(trimmed);
}
return allow;
}
export function isNodeCommandAllowed(params: {
command: string;
declaredCommands?: string[];
allowlist: Set<string>;
}): { ok: true } | { ok: false; reason: string } {
const command = params.command.trim();
if (!command) return { ok: false, reason: "command required" };
if (!params.allowlist.has(command)) {
return { ok: false, reason: "command not allowlisted" };
}
if (Array.isArray(params.declaredCommands) && params.declaredCommands.length > 0) {
if (!params.declaredCommands.includes(command)) {
return { ok: false, reason: "command not declared by node" };
}
} else {
return { ok: false, reason: "node did not declare commands" };
}
return { ok: true };
}

View File

@@ -155,6 +155,7 @@ export class NodeRegistry {
}): boolean { }): boolean {
const pending = this.pendingInvokes.get(params.id); const pending = this.pendingInvokes.get(params.id);
if (!pending) return false; if (!pending) return false;
if (pending.nodeId !== params.nodeId) return false;
clearTimeout(pending.timer); clearTimeout(pending.timer);
this.pendingInvokes.delete(params.id); this.pendingInvokes.delete(params.id);
pending.resolve({ pending.resolve({

View File

@@ -138,6 +138,10 @@ import {
SessionsResolveParamsSchema, SessionsResolveParamsSchema,
type ShutdownEvent, type ShutdownEvent,
ShutdownEventSchema, ShutdownEventSchema,
type SkillsBinsParams,
SkillsBinsParamsSchema,
type SkillsBinsResult,
SkillsBinsResultSchema,
type SkillsInstallParams, type SkillsInstallParams,
SkillsInstallParamsSchema, SkillsInstallParamsSchema,
type SkillsStatusParams, type SkillsStatusParams,
@@ -247,6 +251,7 @@ export const validateChannelsLogoutParams = ajv.compile<ChannelsLogoutParams>(
); );
export const validateModelsListParams = ajv.compile<ModelsListParams>(ModelsListParamsSchema); export const validateModelsListParams = ajv.compile<ModelsListParams>(ModelsListParamsSchema);
export const validateSkillsStatusParams = ajv.compile<SkillsStatusParams>(SkillsStatusParamsSchema); export const validateSkillsStatusParams = ajv.compile<SkillsStatusParams>(SkillsStatusParamsSchema);
export const validateSkillsBinsParams = ajv.compile<SkillsBinsParams>(SkillsBinsParamsSchema);
export const validateSkillsInstallParams = export const validateSkillsInstallParams =
ajv.compile<SkillsInstallParams>(SkillsInstallParamsSchema); ajv.compile<SkillsInstallParams>(SkillsInstallParamsSchema);
export const validateSkillsUpdateParams = ajv.compile<SkillsUpdateParams>(SkillsUpdateParamsSchema); export const validateSkillsUpdateParams = ajv.compile<SkillsUpdateParams>(SkillsUpdateParamsSchema);
@@ -424,6 +429,8 @@ export type {
AgentsListParams, AgentsListParams,
AgentsListResult, AgentsListResult,
SkillsStatusParams, SkillsStatusParams,
SkillsBinsParams,
SkillsBinsResult,
SkillsInstallParams, SkillsInstallParams,
SkillsUpdateParams, SkillsUpdateParams,
NodePairRejectParams, NodePairRejectParams,

View File

@@ -44,6 +44,15 @@ export const ModelsListResultSchema = Type.Object(
export const SkillsStatusParamsSchema = Type.Object({}, { additionalProperties: false }); export const SkillsStatusParamsSchema = Type.Object({}, { additionalProperties: false });
export const SkillsBinsParamsSchema = Type.Object({}, { additionalProperties: false });
export const SkillsBinsResultSchema = Type.Object(
{
bins: Type.Array(NonEmptyString),
},
{ additionalProperties: false },
);
export const SkillsInstallParamsSchema = Type.Object( export const SkillsInstallParamsSchema = Type.Object(
{ {
name: NonEmptyString, name: NonEmptyString,

View File

@@ -39,16 +39,14 @@ export const ConnectParamsSchema = Type.Object(
permissions: Type.Optional(Type.Record(NonEmptyString, Type.Boolean())), permissions: Type.Optional(Type.Record(NonEmptyString, Type.Boolean())),
role: Type.Optional(NonEmptyString), role: Type.Optional(NonEmptyString),
scopes: Type.Optional(Type.Array(NonEmptyString)), scopes: Type.Optional(Type.Array(NonEmptyString)),
device: Type.Optional( device: Type.Object(
Type.Object( {
{ id: NonEmptyString,
id: NonEmptyString, publicKey: NonEmptyString,
publicKey: NonEmptyString, signature: NonEmptyString,
signature: NonEmptyString, signedAt: Type.Integer({ minimum: 0 }),
signedAt: Type.Integer({ minimum: 0 }), },
}, { additionalProperties: false },
{ additionalProperties: false },
),
), ),
auth: Type.Optional( auth: Type.Optional(
Type.Object( Type.Object(

View File

@@ -15,6 +15,8 @@ import {
ModelChoiceSchema, ModelChoiceSchema,
ModelsListParamsSchema, ModelsListParamsSchema,
ModelsListResultSchema, ModelsListResultSchema,
SkillsBinsParamsSchema,
SkillsBinsResultSchema,
SkillsInstallParamsSchema, SkillsInstallParamsSchema,
SkillsStatusParamsSchema, SkillsStatusParamsSchema,
SkillsUpdateParamsSchema, SkillsUpdateParamsSchema,
@@ -179,6 +181,8 @@ export const ProtocolSchemas: Record<string, TSchema> = {
ModelsListParams: ModelsListParamsSchema, ModelsListParams: ModelsListParamsSchema,
ModelsListResult: ModelsListResultSchema, ModelsListResult: ModelsListResultSchema,
SkillsStatusParams: SkillsStatusParamsSchema, SkillsStatusParams: SkillsStatusParamsSchema,
SkillsBinsParams: SkillsBinsParamsSchema,
SkillsBinsResult: SkillsBinsResultSchema,
SkillsInstallParams: SkillsInstallParamsSchema, SkillsInstallParams: SkillsInstallParamsSchema,
SkillsUpdateParams: SkillsUpdateParamsSchema, SkillsUpdateParams: SkillsUpdateParamsSchema,
CronJob: CronJobSchema, CronJob: CronJobSchema,

View File

@@ -13,6 +13,8 @@ import type {
ModelChoiceSchema, ModelChoiceSchema,
ModelsListParamsSchema, ModelsListParamsSchema,
ModelsListResultSchema, ModelsListResultSchema,
SkillsBinsParamsSchema,
SkillsBinsResultSchema,
SkillsInstallParamsSchema, SkillsInstallParamsSchema,
SkillsStatusParamsSchema, SkillsStatusParamsSchema,
SkillsUpdateParamsSchema, SkillsUpdateParamsSchema,
@@ -168,6 +170,8 @@ export type ModelChoice = Static<typeof ModelChoiceSchema>;
export type ModelsListParams = Static<typeof ModelsListParamsSchema>; export type ModelsListParams = Static<typeof ModelsListParamsSchema>;
export type ModelsListResult = Static<typeof ModelsListResultSchema>; export type ModelsListResult = Static<typeof ModelsListResultSchema>;
export type SkillsStatusParams = Static<typeof SkillsStatusParamsSchema>; export type SkillsStatusParams = Static<typeof SkillsStatusParamsSchema>;
export type SkillsBinsParams = Static<typeof SkillsBinsParamsSchema>;
export type SkillsBinsResult = Static<typeof SkillsBinsResultSchema>;
export type SkillsInstallParams = Static<typeof SkillsInstallParamsSchema>; export type SkillsInstallParams = Static<typeof SkillsInstallParamsSchema>;
export type SkillsUpdateParams = Static<typeof SkillsUpdateParamsSchema>; export type SkillsUpdateParams = Static<typeof SkillsUpdateParamsSchema>;
export type CronJob = Static<typeof CronJobSchema>; export type CronJob = Static<typeof CronJobSchema>;

View File

@@ -27,6 +27,7 @@ const BASE_METHODS = [
"models.list", "models.list",
"agents.list", "agents.list",
"skills.status", "skills.status",
"skills.bins",
"skills.install", "skills.install",
"skills.update", "skills.update",
"update.run", "update.run",

View File

@@ -25,11 +25,13 @@ import { webHandlers } from "./server-methods/web.js";
import { wizardHandlers } from "./server-methods/wizard.js"; import { wizardHandlers } from "./server-methods/wizard.js";
const ADMIN_SCOPE = "operator.admin"; const ADMIN_SCOPE = "operator.admin";
const READ_SCOPE = "operator.read";
const WRITE_SCOPE = "operator.write";
const APPROVALS_SCOPE = "operator.approvals"; const APPROVALS_SCOPE = "operator.approvals";
const PAIRING_SCOPE = "operator.pairing"; const PAIRING_SCOPE = "operator.pairing";
const APPROVAL_METHODS = new Set(["exec.approval.request", "exec.approval.resolve"]); const APPROVAL_METHODS = new Set(["exec.approval.request", "exec.approval.resolve"]);
const NODE_ROLE_METHODS = new Set(["node.invoke.result", "node.event"]); const NODE_ROLE_METHODS = new Set(["node.invoke.result", "node.event", "skills.bins"]);
const PAIRING_METHODS = new Set([ const PAIRING_METHODS = new Set([
"node.pair.request", "node.pair.request",
"node.pair.list", "node.pair.list",
@@ -39,15 +41,51 @@ const PAIRING_METHODS = new Set([
"device.pair.list", "device.pair.list",
"device.pair.approve", "device.pair.approve",
"device.pair.reject", "device.pair.reject",
"node.rename",
]); ]);
const ADMIN_METHOD_PREFIXES = ["exec.approvals."]; const ADMIN_METHOD_PREFIXES = ["exec.approvals."];
const READ_METHODS = new Set([
"health",
"logs.tail",
"channels.status",
"status",
"usage.status",
"usage.cost",
"models.list",
"agents.list",
"skills.status",
"voicewake.get",
"sessions.list",
"cron.list",
"cron.status",
"cron.runs",
"system-presence",
"last-heartbeat",
"node.list",
"node.describe",
"chat.history",
]);
const WRITE_METHODS = new Set([
"send",
"agent",
"agent.wait",
"wake",
"talk.mode",
"voicewake.set",
"node.invoke",
"chat.send",
"chat.abort",
]);
function authorizeGatewayMethod(method: string, client: GatewayRequestOptions["client"]) { function authorizeGatewayMethod(method: string, client: GatewayRequestOptions["client"]) {
if (!client?.connect) return null; if (!client?.connect) return null;
const role = client.connect.role ?? "operator"; const role = client.connect.role ?? "operator";
const scopes = client.connect.scopes ?? []; const scopes = client.connect.scopes ?? [];
if (NODE_ROLE_METHODS.has(method)) {
if (role === "node") return null;
return errorShape(ErrorCodes.INVALID_REQUEST, `unauthorized role: ${role}`);
}
if (role === "node") { if (role === "node") {
if (NODE_ROLE_METHODS.has(method)) return null;
return errorShape(ErrorCodes.INVALID_REQUEST, `unauthorized role: ${role}`); return errorShape(ErrorCodes.INVALID_REQUEST, `unauthorized role: ${role}`);
} }
if (role !== "operator") { if (role !== "operator") {
@@ -60,10 +98,38 @@ function authorizeGatewayMethod(method: string, client: GatewayRequestOptions["c
if (PAIRING_METHODS.has(method) && !scopes.includes(PAIRING_SCOPE)) { if (PAIRING_METHODS.has(method) && !scopes.includes(PAIRING_SCOPE)) {
return errorShape(ErrorCodes.INVALID_REQUEST, "missing scope: operator.pairing"); return errorShape(ErrorCodes.INVALID_REQUEST, "missing scope: operator.pairing");
} }
if (READ_METHODS.has(method) && !(scopes.includes(READ_SCOPE) || scopes.includes(WRITE_SCOPE))) {
return errorShape(ErrorCodes.INVALID_REQUEST, "missing scope: operator.read");
}
if (WRITE_METHODS.has(method) && !scopes.includes(WRITE_SCOPE)) {
return errorShape(ErrorCodes.INVALID_REQUEST, "missing scope: operator.write");
}
if (APPROVAL_METHODS.has(method)) return null;
if (PAIRING_METHODS.has(method)) return null;
if (READ_METHODS.has(method)) return null;
if (WRITE_METHODS.has(method)) return null;
if (ADMIN_METHOD_PREFIXES.some((prefix) => method.startsWith(prefix))) { if (ADMIN_METHOD_PREFIXES.some((prefix) => method.startsWith(prefix))) {
return errorShape(ErrorCodes.INVALID_REQUEST, "missing scope: operator.admin"); return errorShape(ErrorCodes.INVALID_REQUEST, "missing scope: operator.admin");
} }
return null; if (
method.startsWith("config.") ||
method.startsWith("wizard.") ||
method.startsWith("update.") ||
method === "channels.logout" ||
method === "skills.install" ||
method === "skills.update" ||
method === "cron.add" ||
method === "cron.update" ||
method === "cron.remove" ||
method === "cron.run" ||
method === "sessions.patch" ||
method === "sessions.reset" ||
method === "sessions.delete" ||
method === "sessions.compact"
) {
return errorShape(ErrorCodes.INVALID_REQUEST, "missing scope: operator.admin");
}
return errorShape(ErrorCodes.INVALID_REQUEST, "missing scope: operator.admin");
} }
export const coreGatewayHandlers: GatewayRequestHandlers = { export const coreGatewayHandlers: GatewayRequestHandlers = {

View File

@@ -28,6 +28,11 @@ import {
safeParseJson, safeParseJson,
uniqueSortedStrings, uniqueSortedStrings,
} from "./nodes.helpers.js"; } from "./nodes.helpers.js";
import { loadConfig } from "../../config/config.js";
import {
isNodeCommandAllowed,
resolveNodeCommandAllowlist,
} from "../node-command-policy.js";
import type { GatewayRequestHandlers } from "./types.js"; import type { GatewayRequestHandlers } from "./types.js";
function isNodeEntry(entry: { role?: string; roles?: string[] }) { function isNodeEntry(entry: { role?: string; roles?: string[] }) {
@@ -353,6 +358,34 @@ export const nodeHandlers: GatewayRequestHandlers = {
} }
await respondUnavailableOnThrow(respond, async () => { await respondUnavailableOnThrow(respond, async () => {
const nodeSession = context.nodeRegistry.get(nodeId);
if (!nodeSession) {
respond(
false,
undefined,
errorShape(ErrorCodes.UNAVAILABLE, "node not connected", {
details: { code: "NOT_CONNECTED" },
}),
);
return;
}
const cfg = loadConfig();
const allowlist = resolveNodeCommandAllowlist(cfg, nodeSession);
const allowed = isNodeCommandAllowed({
command,
declaredCommands: nodeSession.commands,
allowlist,
});
if (!allowed.ok) {
respond(
false,
undefined,
errorShape(ErrorCodes.INVALID_REQUEST, "node command not allowed", {
details: { reason: allowed.reason, command },
}),
);
return;
}
const res = await context.nodeRegistry.invoke({ const res = await context.nodeRegistry.invoke({
nodeId, nodeId,
command, command,
@@ -384,7 +417,7 @@ export const nodeHandlers: GatewayRequestHandlers = {
); );
}); });
}, },
"node.invoke.result": async ({ params, respond, context }) => { "node.invoke.result": async ({ params, respond, context, client }) => {
if (!validateNodeInvokeResultParams(params)) { if (!validateNodeInvokeResultParams(params)) {
respondInvalidParams({ respondInvalidParams({
respond, respond,
@@ -401,6 +434,11 @@ export const nodeHandlers: GatewayRequestHandlers = {
payloadJSON?: string | null; payloadJSON?: string | null;
error?: { code?: string; message?: string } | null; error?: { code?: string; message?: string } | null;
}; };
const callerNodeId = client?.connect?.device?.id ?? client?.connect?.client?.id;
if (callerNodeId && callerNodeId !== p.nodeId) {
respond(false, undefined, errorShape(ErrorCodes.INVALID_REQUEST, "nodeId mismatch"));
return;
}
const ok = context.nodeRegistry.handleInvokeResult({ const ok = context.nodeRegistry.handleInvokeResult({
id: p.id, id: p.id,
nodeId: p.nodeId, nodeId: p.nodeId,
@@ -415,7 +453,7 @@ export const nodeHandlers: GatewayRequestHandlers = {
} }
respond(true, { ok: true }, undefined); respond(true, { ok: true }, undefined);
}, },
"node.event": async ({ params, respond, context }) => { "node.event": async ({ params, respond, context, client }) => {
if (!validateNodeEventParams(params)) { if (!validateNodeEventParams(params)) {
respondInvalidParams({ respondInvalidParams({
respond, respond,
@@ -433,6 +471,7 @@ export const nodeHandlers: GatewayRequestHandlers = {
: null; : null;
await respondUnavailableOnThrow(respond, async () => { await respondUnavailableOnThrow(respond, async () => {
const { handleNodeEvent } = await import("../server-node-events.js"); const { handleNodeEvent } = await import("../server-node-events.js");
const nodeId = client?.connect?.device?.id ?? client?.connect?.client?.id ?? "node";
const nodeContext = { const nodeContext = {
deps: context.deps, deps: context.deps,
broadcast: context.broadcast, broadcast: context.broadcast,
@@ -453,7 +492,7 @@ export const nodeHandlers: GatewayRequestHandlers = {
loadGatewayModelCatalog: context.loadGatewayModelCatalog, loadGatewayModelCatalog: context.loadGatewayModelCatalog,
logGateway: { warn: context.logGateway.warn }, logGateway: { warn: context.logGateway.warn },
}; };
await handleNodeEvent(nodeContext, "node", { await handleNodeEvent(nodeContext, nodeId, {
event: p.event, event: p.event,
payloadJSON, payloadJSON,
}); });

View File

@@ -1,6 +1,7 @@
import { resolveAgentWorkspaceDir, resolveDefaultAgentId } from "../../agents/agent-scope.js"; import { resolveAgentWorkspaceDir, resolveDefaultAgentId } from "../../agents/agent-scope.js";
import { installSkill } from "../../agents/skills-install.js"; import { installSkill } from "../../agents/skills-install.js";
import { buildWorkspaceSkillStatus } from "../../agents/skills-status.js"; import { buildWorkspaceSkillStatus } from "../../agents/skills-status.js";
import { loadWorkspaceSkillEntries, type SkillEntry } from "../../agents/skills.js";
import type { ClawdbotConfig } from "../../config/config.js"; import type { ClawdbotConfig } from "../../config/config.js";
import { loadConfig, writeConfigFile } from "../../config/config.js"; import { loadConfig, writeConfigFile } from "../../config/config.js";
import { getRemoteSkillEligibility } from "../../infra/skills-remote.js"; import { getRemoteSkillEligibility } from "../../infra/skills-remote.js";
@@ -8,12 +9,52 @@ import {
ErrorCodes, ErrorCodes,
errorShape, errorShape,
formatValidationErrors, formatValidationErrors,
validateSkillsBinsParams,
validateSkillsInstallParams, validateSkillsInstallParams,
validateSkillsStatusParams, validateSkillsStatusParams,
validateSkillsUpdateParams, validateSkillsUpdateParams,
} from "../protocol/index.js"; } from "../protocol/index.js";
import type { GatewayRequestHandlers } from "./types.js"; import type { GatewayRequestHandlers } from "./types.js";
function listWorkspaceDirs(cfg: ClawdbotConfig): string[] {
const dirs = new Set<string>();
const list = cfg.agents?.list;
if (Array.isArray(list)) {
for (const entry of list) {
if (entry && typeof entry === "object" && typeof entry.id === "string") {
dirs.add(resolveAgentWorkspaceDir(cfg, entry.id));
}
}
}
dirs.add(resolveAgentWorkspaceDir(cfg, resolveDefaultAgentId(cfg)));
return [...dirs];
}
function collectSkillBins(entries: SkillEntry[]): string[] {
const bins = new Set<string>();
for (const entry of entries) {
const required = entry.clawdbot?.requires?.bins ?? [];
const anyBins = entry.clawdbot?.requires?.anyBins ?? [];
const install = entry.clawdbot?.install ?? [];
for (const bin of required) {
const trimmed = bin.trim();
if (trimmed) bins.add(trimmed);
}
for (const bin of anyBins) {
const trimmed = bin.trim();
if (trimmed) bins.add(trimmed);
}
for (const spec of install) {
const specBins = spec?.bins ?? [];
for (const bin of specBins) {
const trimmed = String(bin).trim();
if (trimmed) bins.add(trimmed);
}
}
}
return [...bins].sort();
}
export const skillsHandlers: GatewayRequestHandlers = { export const skillsHandlers: GatewayRequestHandlers = {
"skills.status": ({ params, respond }) => { "skills.status": ({ params, respond }) => {
if (!validateSkillsStatusParams(params)) { if (!validateSkillsStatusParams(params)) {
@@ -35,6 +76,27 @@ export const skillsHandlers: GatewayRequestHandlers = {
}); });
respond(true, report, undefined); respond(true, report, undefined);
}, },
"skills.bins": ({ params, respond }) => {
if (!validateSkillsBinsParams(params)) {
respond(
false,
undefined,
errorShape(
ErrorCodes.INVALID_REQUEST,
`invalid skills.bins params: ${formatValidationErrors(validateSkillsBinsParams.errors)}`,
),
);
return;
}
const cfg = loadConfig();
const workspaceDirs = listWorkspaceDirs(cfg);
const bins = new Set<string>();
for (const workspaceDir of workspaceDirs) {
const entries = loadWorkspaceSkillEntries(workspaceDir, { config: cfg });
for (const bin of collectSkillBins(entries)) bins.add(bin);
}
respond(true, { bins: [...bins].sort() }, undefined);
},
"skills.install": async ({ params, respond }) => { "skills.install": async ({ params, respond }) => {
if (!validateSkillsInstallParams(params)) { if (!validateSkillsInstallParams(params)) {
respond( respond(

View File

@@ -150,6 +150,12 @@ describe("gateway server auth/connect", () => {
platform: "web", platform: "web",
mode: "webchat", mode: "webchat",
}, },
device: {
id: 123,
publicKey: "bad",
signature: "bad",
signedAt: "bad",
},
}, },
}), }),
); );

View File

@@ -0,0 +1,147 @@
import { describe, expect, test } from "vitest";
import { WebSocket } from "ws";
import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js";
import {
connectOk,
installGatewayTestHooks,
onceMessage,
rpcReq,
startServerWithClient,
} from "./test-helpers.js";
installGatewayTestHooks();
describe("gateway node command allowlist", () => {
test("rejects commands outside platform allowlist", async () => {
const { server, ws, port } = await startServerWithClient();
await connectOk(ws);
const nodeWs = new WebSocket(`ws://127.0.0.1:${port}`);
await new Promise<void>((resolve) => nodeWs.once("open", resolve));
await connectOk(nodeWs, {
role: "node",
client: {
id: GATEWAY_CLIENT_NAMES.NODE_HOST,
version: "1.0.0",
platform: "ios",
mode: GATEWAY_CLIENT_MODES.NODE,
},
commands: ["system.run"],
});
const listRes = await rpcReq<{ nodes?: Array<{ nodeId: string }> }>(ws, "node.list", {});
const nodeId = listRes.payload?.nodes?.[0]?.nodeId ?? "";
expect(nodeId).toBeTruthy();
const res = await rpcReq(ws, "node.invoke", {
nodeId,
command: "system.run",
params: { command: "echo hi" },
idempotencyKey: "allowlist-1",
});
expect(res.ok).toBe(false);
expect(res.error?.message).toContain("node command not allowed");
nodeWs.close();
ws.close();
await server.close();
});
test("rejects commands not declared by node", async () => {
const { server, ws, port } = await startServerWithClient();
await connectOk(ws);
const nodeWs = new WebSocket(`ws://127.0.0.1:${port}`);
await new Promise<void>((resolve) => nodeWs.once("open", resolve));
await connectOk(nodeWs, {
role: "node",
client: {
id: "node-empty",
version: "1.0.0",
platform: "ios",
mode: GATEWAY_CLIENT_MODES.NODE,
},
commands: [],
});
const listRes = await rpcReq<{ nodes?: Array<{ nodeId: string }> }>(ws, "node.list", {});
const nodeId = listRes.payload?.nodes?.find((entry) => entry.nodeId)?.nodeId ?? "";
expect(nodeId).toBeTruthy();
const res = await rpcReq(ws, "node.invoke", {
nodeId,
command: "canvas.snapshot",
params: {},
idempotencyKey: "allowlist-2",
});
expect(res.ok).toBe(false);
expect(res.error?.message).toContain("node command not allowed");
nodeWs.close();
ws.close();
await server.close();
});
test("allows declared command within allowlist", async () => {
const { server, ws, port } = await startServerWithClient();
await connectOk(ws);
const nodeWs = new WebSocket(`ws://127.0.0.1:${port}`);
await new Promise<void>((resolve) => nodeWs.once("open", resolve));
await connectOk(nodeWs, {
role: "node",
client: {
id: "node-allowed",
version: "1.0.0",
platform: "ios",
mode: GATEWAY_CLIENT_MODES.NODE,
},
commands: ["canvas.snapshot"],
});
const listRes = await rpcReq<{ nodes?: Array<{ nodeId: string }> }>(ws, "node.list", {});
const nodeId = listRes.payload?.nodes?.[0]?.nodeId ?? "";
expect(nodeId).toBeTruthy();
const invokeReqP = onceMessage<{ type: "event"; event: string; payload?: unknown }>(
nodeWs,
(o) => o.type === "event" && o.event === "node.invoke.request",
);
const invokeResP = rpcReq(ws, "node.invoke", {
nodeId,
command: "canvas.snapshot",
params: { format: "png" },
idempotencyKey: "allowlist-3",
});
const invokeReq = await invokeReqP;
const payload = invokeReq.payload as { id?: string; nodeId?: string };
const requestId = payload?.id ?? "";
const nodeIdFromReq = payload?.nodeId ?? "node-allowed";
nodeWs.send(
JSON.stringify({
type: "req",
id: "node-result",
method: "node.invoke.result",
params: {
id: requestId,
nodeId: nodeIdFromReq,
ok: true,
payloadJSON: JSON.stringify({ ok: true }),
},
}),
);
await onceMessage(nodeWs, (o) => o.type === "res" && o.id === "node-result");
const invokeRes = await invokeResP;
expect(invokeRes.ok).toBe(true);
nodeWs.close();
ws.close();
await server.close();
});
});

View File

@@ -0,0 +1,61 @@
import { describe, expect, test } from "vitest";
import { WebSocket } from "ws";
import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js";
import {
connectOk,
installGatewayTestHooks,
rpcReq,
startServerWithClient,
} from "./test-helpers.js";
installGatewayTestHooks();
describe("gateway role enforcement", () => {
test("operator cannot send node events or invoke results", async () => {
const { server, ws } = await startServerWithClient();
await connectOk(ws);
const eventRes = await rpcReq(ws, "node.event", { event: "test", payload: { ok: true } });
expect(eventRes.ok).toBe(false);
expect(eventRes.error?.message ?? "").toContain("unauthorized role");
const invokeRes = await rpcReq(ws, "node.invoke.result", {
id: "invoke-1",
nodeId: "node-1",
ok: true,
});
expect(invokeRes.ok).toBe(false);
expect(invokeRes.error?.message ?? "").toContain("unauthorized role");
ws.close();
await server.close();
});
test("node can fetch skills bins but not control plane methods", async () => {
const { server, port } = await startServerWithClient();
const nodeWs = new WebSocket(`ws://127.0.0.1:${port}`);
await new Promise<void>((resolve) => nodeWs.once("open", resolve));
await connectOk(nodeWs, {
role: "node",
client: {
id: GATEWAY_CLIENT_NAMES.NODE_HOST,
version: "1.0.0",
platform: "ios",
mode: GATEWAY_CLIENT_MODES.NODE,
},
commands: [],
});
const binsRes = await rpcReq<{ bins?: unknown[] }>(nodeWs, "skills.bins", {});
expect(binsRes.ok).toBe(true);
expect(Array.isArray(binsRes.payload?.bins)).toBe(true);
const statusRes = await rpcReq(nodeWs, "status", {});
expect(statusRes.ok).toBe(false);
expect(statusRes.error?.message ?? "").toContain("unauthorized role");
nodeWs.close();
await server.close();
});
});

View File

@@ -24,6 +24,7 @@ import { authorizeGatewayConnect } from "../../auth.js";
import { loadConfig } from "../../../config/config.js"; import { loadConfig } from "../../../config/config.js";
import { buildDeviceAuthPayload } from "../../device-auth.js"; import { buildDeviceAuthPayload } from "../../device-auth.js";
import { isLoopbackAddress } from "../../net.js"; import { isLoopbackAddress } from "../../net.js";
import { resolveNodeCommandAllowlist } from "../../node-command-policy.js";
import { import {
type ConnectParams, type ConnectParams,
ErrorCodes, ErrorCodes,
@@ -253,17 +254,55 @@ export function attachGatewayWsMessageHandler(params: {
} }
const authMethod = authResult.method ?? "none"; const authMethod = authResult.method ?? "none";
const role = connectParams.role ?? "operator"; const roleRaw = connectParams.role ?? "operator";
const scopes = Array.isArray(connectParams.scopes) const role = roleRaw === "operator" || roleRaw === "node" ? roleRaw : null;
? connectParams.scopes if (!role) {
: role === "operator" setHandshakeState("failed");
? ["operator.admin"] setCloseCause("invalid-role", {
: []; role: roleRaw,
client: connectParams.client.id,
clientDisplayName: connectParams.client.displayName,
mode: connectParams.client.mode,
version: connectParams.client.version,
});
send({
type: "res",
id: frame.id,
ok: false,
error: errorShape(ErrorCodes.INVALID_REQUEST, "invalid role"),
});
close(1008, "invalid role");
return;
}
const requestedScopes = Array.isArray(connectParams.scopes) ? connectParams.scopes : [];
const scopes =
requestedScopes.length > 0
? requestedScopes
: role === "operator"
? ["operator.admin"]
: [];
connectParams.role = role; connectParams.role = role;
connectParams.scopes = scopes; connectParams.scopes = scopes;
const device = connectParams.device; const device = connectParams.device;
let devicePublicKey: string | null = null; let devicePublicKey: string | null = null;
if (!device) {
setHandshakeState("failed");
setCloseCause("device-required", {
client: connectParams.client.id,
clientDisplayName: connectParams.client.displayName,
mode: connectParams.client.mode,
version: connectParams.client.version,
});
send({
type: "res",
id: frame.id,
ok: false,
error: errorShape(ErrorCodes.NOT_PAIRED, "device identity required"),
});
close(1008, "device identity required");
return;
}
if (device) { if (device) {
const derivedId = deriveDeviceIdFromPublicKey(device.publicKey); const derivedId = deriveDeviceIdFromPublicKey(device.publicKey);
if (!derivedId || derivedId !== device.id) { if (!derivedId || derivedId !== device.id) {
@@ -307,7 +346,7 @@ export function attachGatewayWsMessageHandler(params: {
clientId: connectParams.client.id, clientId: connectParams.client.id,
clientMode: connectParams.client.mode, clientMode: connectParams.client.mode,
role, role,
scopes, scopes: requestedScopes,
signedAtMs: signedAt, signedAtMs: signedAt,
token: connectParams.auth?.token ?? null, token: connectParams.auth?.token ?? null,
}); });
@@ -347,9 +386,7 @@ export function attachGatewayWsMessageHandler(params: {
} }
if (device && devicePublicKey) { if (device && devicePublicKey) {
const paired = await getPairedDevice(device.id); const requirePairing = async (reason: string, paired?: { deviceId: string }) => {
const isPaired = paired?.publicKey === devicePublicKey;
if (!isPaired) {
const pairing = await requestDevicePairing({ const pairing = await requestDevicePairing({
deviceId: device.id, deviceId: device.id,
publicKey: devicePublicKey, publicKey: devicePublicKey,
@@ -360,7 +397,7 @@ export function attachGatewayWsMessageHandler(params: {
role, role,
scopes, scopes,
remoteIp: remoteAddr, remoteIp: remoteAddr,
silent: isLoopbackAddress(remoteAddr) && authMethod !== "none", silent: isLoopbackAddress(remoteAddr),
}); });
const context = buildRequestContext(); const context = buildRequestContext();
if (pairing.request.silent === true) { if (pairing.request.silent === true) {
@@ -385,6 +422,7 @@ export function attachGatewayWsMessageHandler(params: {
setCloseCause("pairing-required", { setCloseCause("pairing-required", {
deviceId: device.id, deviceId: device.id,
requestId: pairing.request.requestId, requestId: pairing.request.requestId,
reason,
}); });
send({ send({
type: "res", type: "res",
@@ -395,9 +433,47 @@ export function attachGatewayWsMessageHandler(params: {
}), }),
}); });
close(1008, "pairing required"); close(1008, "pairing required");
return; return false;
} }
return true;
};
const paired = await getPairedDevice(device.id);
const isPaired = paired?.publicKey === devicePublicKey;
if (!isPaired) {
const ok = await requirePairing("not-paired");
if (!ok) return;
} else { } else {
const allowedRoles = new Set(
Array.isArray(paired.roles)
? paired.roles
: paired.role
? [paired.role]
: [],
);
if (allowedRoles.size === 0) {
const ok = await requirePairing("role-upgrade", paired);
if (!ok) return;
} else if (!allowedRoles.has(role)) {
const ok = await requirePairing("role-upgrade", paired);
if (!ok) return;
}
const pairedScopes = Array.isArray(paired.scopes) ? paired.scopes : [];
if (scopes.length > 0) {
if (pairedScopes.length === 0) {
const ok = await requirePairing("scope-upgrade", paired);
if (!ok) return;
} else {
const allowedScopes = new Set(pairedScopes);
const missingScope = scopes.find((scope) => !allowedScopes.has(scope));
if (missingScope) {
const ok = await requirePairing("scope-upgrade", paired);
if (!ok) return;
}
}
}
await updatePairedDeviceMetadata(device.id, { await updatePairedDeviceMetadata(device.id, {
displayName: connectParams.client.displayName, displayName: connectParams.client.displayName,
platform: connectParams.client.platform, platform: connectParams.client.platform,
@@ -410,10 +486,25 @@ export function attachGatewayWsMessageHandler(params: {
} }
} }
if (role === "node") {
const cfg = loadConfig();
const allowlist = resolveNodeCommandAllowlist(cfg, {
platform: connectParams.client.platform,
deviceFamily: connectParams.client.deviceFamily,
});
const declared = Array.isArray(connectParams.commands) ? connectParams.commands : [];
const filtered = declared
.map((cmd) => cmd.trim())
.filter((cmd) => cmd.length > 0 && allowlist.has(cmd));
connectParams.commands = filtered;
}
const shouldTrackPresence = !isGatewayCliClient(connectParams.client); const shouldTrackPresence = !isGatewayCliClient(connectParams.client);
const clientId = connectParams.client.id; const clientId = connectParams.client.id;
const instanceId = connectParams.client.instanceId; const instanceId = connectParams.client.instanceId;
const presenceKey = shouldTrackPresence ? (instanceId ?? connId) : undefined; const presenceKey = shouldTrackPresence
? (connectParams.device?.id ?? instanceId ?? connId)
: undefined;
logWs("in", "connect", { logWs("in", "connect", {
connId, connId,
@@ -441,7 +532,7 @@ export function attachGatewayWsMessageHandler(params: {
deviceFamily: connectParams.client.deviceFamily, deviceFamily: connectParams.client.deviceFamily,
modelIdentifier: connectParams.client.modelIdentifier, modelIdentifier: connectParams.client.modelIdentifier,
mode: connectParams.client.mode, mode: connectParams.client.mode,
instanceId, instanceId: connectParams.device?.id ?? instanceId,
reason: "connect", reason: "connect",
}); });
incrementPresenceVersion(); incrementPresenceVersion();

View File

@@ -8,6 +8,11 @@ import { WebSocket } from "ws";
import { resolveMainSessionKeyFromConfig, type SessionEntry } from "../config/sessions.js"; import { resolveMainSessionKeyFromConfig, type SessionEntry } from "../config/sessions.js";
import { resetAgentRunContextForTest } from "../infra/agent-events.js"; import { resetAgentRunContextForTest } from "../infra/agent-events.js";
import {
loadOrCreateDeviceIdentity,
publicKeyRawBase64UrlFromPem,
signDevicePayload,
} from "../infra/device-identity.js";
import { drainSystemEvents, peekSystemEvents } from "../infra/system-events.js"; import { drainSystemEvents, peekSystemEvents } from "../infra/system-events.js";
import { rawDataToString } from "../infra/ws.js"; import { rawDataToString } from "../infra/ws.js";
import { resetLogger, setLoggerOverride } from "../logging.js"; import { resetLogger, setLoggerOverride } from "../logging.js";
@@ -16,6 +21,7 @@ import { getDeterministicFreePortBlock } from "../test-utils/ports.js";
import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js"; import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js";
import { PROTOCOL_VERSION } from "./protocol/index.js"; import { PROTOCOL_VERSION } from "./protocol/index.js";
import { buildDeviceAuthPayload } from "./device-auth.js";
import type { GatewayServerOptions } from "./server.js"; import type { GatewayServerOptions } from "./server.js";
import { import {
agentCommand, agentCommand,
@@ -268,10 +274,44 @@ export async function connectReq(
caps?: string[]; caps?: string[];
commands?: string[]; commands?: string[];
permissions?: Record<string, boolean>; permissions?: Record<string, boolean>;
device?: {
id: string;
publicKey: string;
signature: string;
signedAt: number;
};
}, },
): Promise<ConnectResponse> { ): Promise<ConnectResponse> {
const { randomUUID } = await import("node:crypto"); const { randomUUID } = await import("node:crypto");
const id = randomUUID(); const id = randomUUID();
const client = opts?.client ?? {
id: GATEWAY_CLIENT_NAMES.TEST,
version: "1.0.0",
platform: "test",
mode: GATEWAY_CLIENT_MODES.TEST,
};
const role = opts?.role ?? "operator";
const requestedScopes = Array.isArray(opts?.scopes) ? opts?.scopes : [];
const device = (() => {
if (opts?.device) return opts.device;
const identity = loadOrCreateDeviceIdentity();
const signedAtMs = Date.now();
const payload = buildDeviceAuthPayload({
deviceId: identity.deviceId,
clientId: client.id,
clientMode: client.mode,
role,
scopes: requestedScopes,
signedAtMs,
token: opts?.token ?? null,
});
return {
id: identity.deviceId,
publicKey: publicKeyRawBase64UrlFromPem(identity.publicKeyPem),
signature: signDevicePayload(identity.privateKeyPem, payload),
signedAt: signedAtMs,
};
})();
ws.send( ws.send(
JSON.stringify({ JSON.stringify({
type: "req", type: "req",
@@ -280,16 +320,11 @@ export async function connectReq(
params: { params: {
minProtocol: opts?.minProtocol ?? PROTOCOL_VERSION, minProtocol: opts?.minProtocol ?? PROTOCOL_VERSION,
maxProtocol: opts?.maxProtocol ?? PROTOCOL_VERSION, maxProtocol: opts?.maxProtocol ?? PROTOCOL_VERSION,
client: opts?.client ?? { client,
id: GATEWAY_CLIENT_NAMES.TEST,
version: "1.0.0",
platform: "test",
mode: GATEWAY_CLIENT_MODES.TEST,
},
caps: opts?.caps ?? [], caps: opts?.caps ?? [],
commands: opts?.commands ?? [], commands: opts?.commands ?? [],
permissions: opts?.permissions ?? undefined, permissions: opts?.permissions ?? undefined,
role: opts?.role, role,
scopes: opts?.scopes, scopes: opts?.scopes,
auth: auth:
opts?.token || opts?.password opts?.token || opts?.password
@@ -298,6 +333,7 @@ export async function connectReq(
password: opts?.password, password: opts?.password,
} }
: undefined, : undefined,
device,
}, },
}), }),
); );

View File

@@ -154,6 +154,19 @@ function mergeRoles(...items: Array<string | string[] | undefined>): string[] |
return [...roles]; return [...roles];
} }
function mergeScopes(...items: Array<string[] | undefined>): string[] | undefined {
const scopes = new Set<string>();
for (const item of items) {
if (!item) continue;
for (const scope of item) {
const trimmed = scope.trim();
if (trimmed) scopes.add(trimmed);
}
}
if (scopes.size === 0) return undefined;
return [...scopes];
}
export async function listDevicePairing(baseDir?: string): Promise<DevicePairingList> { export async function listDevicePairing(baseDir?: string): Promise<DevicePairingList> {
const state = await loadState(baseDir); const state = await loadState(baseDir);
const pending = Object.values(state.pendingById).sort((a, b) => b.ts - a.ts); const pending = Object.values(state.pendingById).sort((a, b) => b.ts - a.ts);
@@ -223,6 +236,7 @@ export async function approveDevicePairing(
const now = Date.now(); const now = Date.now();
const existing = state.pairedByDeviceId[pending.deviceId]; const existing = state.pairedByDeviceId[pending.deviceId];
const roles = mergeRoles(existing?.roles, existing?.role, pending.roles, pending.role); const roles = mergeRoles(existing?.roles, existing?.role, pending.roles, pending.role);
const scopes = mergeScopes(existing?.scopes, pending.scopes);
const device: PairedDevice = { const device: PairedDevice = {
deviceId: pending.deviceId, deviceId: pending.deviceId,
publicKey: pending.publicKey, publicKey: pending.publicKey,
@@ -232,7 +246,7 @@ export async function approveDevicePairing(
clientMode: pending.clientMode, clientMode: pending.clientMode,
role: pending.role, role: pending.role,
roles, roles,
scopes: pending.scopes, scopes,
remoteIp: pending.remoteIp, remoteIp: pending.remoteIp,
createdAtMs: existing?.createdAtMs ?? now, createdAtMs: existing?.createdAtMs ?? now,
approvedAtMs: now, approvedAtMs: now,
@@ -268,6 +282,7 @@ export async function updatePairedDeviceMetadata(
const existing = state.pairedByDeviceId[normalizeDeviceId(deviceId)]; const existing = state.pairedByDeviceId[normalizeDeviceId(deviceId)];
if (!existing) return; if (!existing) return;
const roles = mergeRoles(existing.roles, existing.role, patch.role); const roles = mergeRoles(existing.roles, existing.role, patch.role);
const scopes = mergeScopes(existing.scopes, patch.scopes);
state.pairedByDeviceId[deviceId] = { state.pairedByDeviceId[deviceId] = {
...existing, ...existing,
...patch, ...patch,
@@ -276,6 +291,7 @@ export async function updatePairedDeviceMetadata(
approvedAtMs: existing.approvedAtMs, approvedAtMs: existing.approvedAtMs,
role: patch.role ?? existing.role, role: patch.role ?? existing.role,
roles, roles,
scopes,
}; };
await persistState(state, baseDir); await persistState(state, baseDir);
}); });