fix: enforce message context isolation

This commit is contained in:
Peter Steinberger
2026-01-13 01:03:23 +00:00
parent 0edbdb1948
commit ffc465394e
6 changed files with 164 additions and 5 deletions

View File

@@ -9,6 +9,7 @@
- Update: run `clawdbot doctor --non-interactive` during updates to avoid TTY hangs. (#781 — thanks @ronyrus) - Update: run `clawdbot doctor --non-interactive` during updates to avoid TTY hangs. (#781 — thanks @ronyrus)
- Tools: allow Claude/Gemini tool param aliases (`file_path`, `old_string`, `new_string`) while enforcing required params at runtime. (#793 — thanks @hsrvc) - Tools: allow Claude/Gemini tool param aliases (`file_path`, `old_string`, `new_string`) while enforcing required params at runtime. (#793 — thanks @hsrvc)
- Gemini: downgrade tool-call history missing `thought_signature` to avoid INVALID_ARGUMENT errors. (#793 — thanks @hsrvc) - Gemini: downgrade tool-call history missing `thought_signature` to avoid INVALID_ARGUMENT errors. (#793 — thanks @hsrvc)
- Messaging: enforce context isolation for message tool sends across providers (normalized targets + tests). (#793 — thanks @hsrvc)
## 2026.1.12-3 ## 2026.1.12-3

View File

@@ -179,6 +179,7 @@ Core actions:
Notes: Notes:
- `send` routes WhatsApp via the Gateway; other providers go direct. - `send` routes WhatsApp via the Gateway; other providers go direct.
- `poll` uses the Gateway for WhatsApp and MS Teams; Discord polls go direct. - `poll` uses the Gateway for WhatsApp and MS Teams; Discord polls go direct.
- When a message tool call is bound to an active chat session, sends are constrained to that sessions target to avoid cross-context leaks.
### `cron` ### `cron`
Manage Gateway cron jobs and wakeups. Manage Gateway cron jobs and wakeups.

View File

@@ -2,6 +2,10 @@ import { Type } from "@sinclair/typebox";
import type { ClawdbotConfig } from "../../config/config.js"; import type { ClawdbotConfig } from "../../config/config.js";
import { loadConfig } from "../../config/config.js"; import { loadConfig } from "../../config/config.js";
import {
GATEWAY_CLIENT_IDS,
GATEWAY_CLIENT_MODES,
} from "../../gateway/protocol/client-info.js";
import { runMessageAction } from "../../infra/outbound/message-action-runner.js"; import { runMessageAction } from "../../infra/outbound/message-action-runner.js";
import { import {
listProviderMessageActions, listProviderMessageActions,
@@ -12,10 +16,6 @@ import {
type ProviderMessageActionName, type ProviderMessageActionName,
} from "../../providers/plugins/types.js"; } from "../../providers/plugins/types.js";
import { normalizeAccountId } from "../../routing/session-key.js"; import { normalizeAccountId } from "../../routing/session-key.js";
import {
GATEWAY_CLIENT_MODES,
GATEWAY_CLIENT_NAMES,
} from "../../utils/message-provider.js";
import type { AnyAgentTool } from "./common.js"; import type { AnyAgentTool } from "./common.js";
import { jsonResult, readNumberParam, readStringParam } from "./common.js"; import { jsonResult, readNumberParam, readStringParam } from "./common.js";
@@ -184,7 +184,7 @@ export function createMessageTool(options?: MessageToolOptions): AnyAgentTool {
url: readStringParam(params, "gatewayUrl", { trim: false }), url: readStringParam(params, "gatewayUrl", { trim: false }),
token: readStringParam(params, "gatewayToken", { trim: false }), token: readStringParam(params, "gatewayToken", { trim: false }),
timeoutMs: readNumberParam(params, "timeoutMs"), timeoutMs: readNumberParam(params, "timeoutMs"),
clientName: GATEWAY_CLIENT_NAMES.GATEWAY_CLIENT, clientName: GATEWAY_CLIENT_IDS.GATEWAY_CLIENT,
clientDisplayName: "agent", clientDisplayName: "agent",
mode: GATEWAY_CLIENT_MODES.BACKEND, mode: GATEWAY_CLIENT_MODES.BACKEND,
}; };

View File

@@ -0,0 +1,61 @@
import { describe, expect, it } from "vitest";
import type { ClawdbotConfig } from "../../config/config.js";
import { runMessageAction } from "./message-action-runner.js";
const slackConfig = {
slack: {
botToken: "xoxb-test",
appToken: "xapp-test",
},
} as ClawdbotConfig;
describe("runMessageAction context isolation", () => {
it("allows send when target matches current channel", async () => {
const result = await runMessageAction({
cfg: slackConfig,
action: "send",
params: {
provider: "slack",
to: "#C123",
message: "hi",
},
toolContext: { currentChannelId: "C123" },
dryRun: true,
});
expect(result.kind).toBe("send");
});
it("blocks send when target differs from current channel", async () => {
await expect(
runMessageAction({
cfg: slackConfig,
action: "send",
params: {
provider: "slack",
to: "channel:C999",
message: "hi",
},
toolContext: { currentChannelId: "C123" },
dryRun: true,
}),
).rejects.toThrow(/Cross-context messaging denied/);
});
it("blocks thread-reply when channelId differs from current channel", async () => {
await expect(
runMessageAction({
cfg: slackConfig,
action: "thread-reply",
params: {
provider: "slack",
channelId: "C999",
message: "hi",
},
toolContext: { currentChannelId: "C123" },
dryRun: true,
}),
).rejects.toThrow(/Cross-context messaging denied/);
});
});

View File

@@ -1,4 +1,5 @@
import type { AgentToolResult } from "@mariozechner/pi-agent-core"; import type { AgentToolResult } from "@mariozechner/pi-agent-core";
import { normalizeTargetForProvider } from "../../agents/pi-embedded-messaging.js";
import { import {
readNumberParam, readNumberParam,
readStringArrayParam, readStringArrayParam,
@@ -125,6 +126,56 @@ function parseButtonsParam(params: Record<string, unknown>): void {
} }
} }
const CONTEXT_GUARDED_ACTIONS = new Set<ProviderMessageActionName>([
"send",
"poll",
"thread-create",
"thread-reply",
"sticker",
]);
function resolveContextGuardTarget(
action: ProviderMessageActionName,
params: Record<string, unknown>,
): string | undefined {
if (!CONTEXT_GUARDED_ACTIONS.has(action)) return undefined;
if (action === "thread-reply" || action === "thread-create") {
return (
readStringParam(params, "channelId") ?? readStringParam(params, "to")
);
}
return readStringParam(params, "to") ?? readStringParam(params, "channelId");
}
function enforceContextIsolation(params: {
provider: ProviderId;
action: ProviderMessageActionName;
params: Record<string, unknown>;
toolContext?: ProviderThreadingToolContext;
}): void {
const currentTarget = params.toolContext?.currentChannelId?.trim();
if (!currentTarget) return;
if (!CONTEXT_GUARDED_ACTIONS.has(params.action)) return;
const target = resolveContextGuardTarget(params.action, params.params);
if (!target) return;
const normalizedTarget =
normalizeTargetForProvider(params.provider, target) ?? target.toLowerCase();
const normalizedCurrent =
normalizeTargetForProvider(params.provider, currentTarget) ??
currentTarget.toLowerCase();
if (!normalizedTarget || !normalizedCurrent) return;
if (normalizedTarget === normalizedCurrent) return;
throw new Error(
`Cross-context messaging denied: action=${params.action} target="${target}" while bound to "${currentTarget}" (provider=${params.provider}).`,
);
}
async function resolveProvider( async function resolveProvider(
cfg: ClawdbotConfig, cfg: ClawdbotConfig,
params: Record<string, unknown>, params: Record<string, unknown>,
@@ -150,6 +201,13 @@ export async function runMessageAction(
readStringParam(params, "accountId") ?? input.defaultAccountId; readStringParam(params, "accountId") ?? input.defaultAccountId;
const dryRun = Boolean(input.dryRun ?? readBooleanParam(params, "dryRun")); const dryRun = Boolean(input.dryRun ?? readBooleanParam(params, "dryRun"));
enforceContextIsolation({
provider,
action,
params,
toolContext: input.toolContext,
});
const gateway = input.gateway const gateway = input.gateway
? { ? {
url: input.gateway.url, url: input.gateway.url,

View File

@@ -102,6 +102,11 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
}, },
threading: { threading: {
resolveReplyToMode: ({ cfg }) => cfg.telegram?.replyToMode ?? "first", resolveReplyToMode: ({ cfg }) => cfg.telegram?.replyToMode ?? "first",
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
}, },
}, },
whatsapp: { whatsapp: {
@@ -142,6 +147,13 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
return [escaped, `@${escaped}`]; return [escaped, `@${escaped}`];
}, },
}, },
threading: {
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
}, },
discord: { discord: {
id: "discord", id: "discord",
@@ -175,6 +187,11 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
}, },
threading: { threading: {
resolveReplyToMode: ({ cfg }) => cfg.discord?.replyToMode ?? "off", resolveReplyToMode: ({ cfg }) => cfg.discord?.replyToMode ?? "off",
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
}, },
}, },
slack: { slack: {
@@ -246,6 +263,13 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
) )
.filter(Boolean), .filter(Boolean),
}, },
threading: {
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
}, },
imessage: { imessage: {
id: "imessage", id: "imessage",
@@ -266,6 +290,13 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
groups: { groups: {
resolveRequireMention: resolveIMessageGroupRequireMention, resolveRequireMention: resolveIMessageGroupRequireMention,
}, },
threading: {
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
}, },
msteams: { msteams: {
id: "msteams", id: "msteams",
@@ -280,6 +311,13 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
resolveAllowFrom: ({ cfg }) => cfg.msteams?.allowFrom ?? [], resolveAllowFrom: ({ cfg }) => cfg.msteams?.allowFrom ?? [],
formatAllowFrom: ({ allowFrom }) => formatLower(allowFrom), formatAllowFrom: ({ allowFrom }) => formatLower(allowFrom),
}, },
threading: {
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
}, },
}; };