diff --git a/src/agents/schema/typebox.ts b/src/agents/schema/typebox.ts index 982f0fe77..428ed73d1 100644 --- a/src/agents/schema/typebox.ts +++ b/src/agents/schema/typebox.ts @@ -1,4 +1,8 @@ import { Type } from "@sinclair/typebox"; +import { + CHANNEL_TARGET_DESCRIPTION, + CHANNEL_TARGETS_DESCRIPTION, +} from "../../infra/outbound/channel-target.js"; type StringEnumOptions = { description?: string; @@ -25,3 +29,13 @@ export function optionalStringEnum( ) { return Type.Optional(stringEnum(values, options)); } + +export function channelTargetSchema(options?: { description?: string }) { + return Type.String({ + description: options?.description ?? CHANNEL_TARGET_DESCRIPTION, + }); +} + +export function channelTargetsSchema(options?: { description?: string }) { + return Type.Array(channelTargetSchema({ description: options?.description ?? CHANNEL_TARGETS_DESCRIPTION })); +} diff --git a/src/agents/tools/message-tool.ts b/src/agents/tools/message-tool.ts index 94a48d466..ac9020284 100644 --- a/src/agents/tools/message-tool.ts +++ b/src/agents/tools/message-tool.ts @@ -17,7 +17,7 @@ import { GATEWAY_CLIENT_IDS, GATEWAY_CLIENT_MODES } from "../../gateway/protocol import { runMessageAction } from "../../infra/outbound/message-action-runner.js"; import { resolveSessionAgentId } from "../agent-scope.js"; import { normalizeAccountId } from "../../routing/session-key.js"; -import { stringEnum } from "../schema/typebox.js"; +import { channelTargetSchema, channelTargetsSchema, stringEnum } from "../schema/typebox.js"; import type { AnyAgentTool } from "./common.js"; import { jsonResult, readNumberParam, readStringParam } from "./common.js"; @@ -25,8 +25,8 @@ const AllMessageActions = CHANNEL_MESSAGE_ACTION_NAMES; const MessageToolCommonSchema = { channel: Type.Optional(Type.String()), - to: Type.Optional(Type.String()), - targets: Type.Optional(Type.Array(Type.String())), + to: Type.Optional(channelTargetSchema()), + targets: Type.Optional(channelTargetsSchema()), message: Type.Optional(Type.String()), media: Type.Optional(Type.String()), buttons: Type.Optional( @@ -59,8 +59,8 @@ const MessageToolCommonSchema = { pollOption: Type.Optional(Type.Array(Type.String())), pollDurationHours: Type.Optional(Type.Number()), pollMulti: Type.Optional(Type.Boolean()), - channelId: Type.Optional(Type.String()), - channelIds: Type.Optional(Type.Array(Type.String())), + channelId: Type.Optional(channelTargetSchema()), + channelIds: Type.Optional(channelTargetsSchema()), guildId: Type.Optional(Type.String()), userId: Type.Optional(Type.String()), authorId: Type.Optional(Type.String()), diff --git a/src/cli/program/message/helpers.ts b/src/cli/program/message/helpers.ts index 66bca02e8..c5249159d 100644 --- a/src/cli/program/message/helpers.ts +++ b/src/cli/program/message/helpers.ts @@ -1,6 +1,7 @@ import type { Command } from "commander"; import { messageCommand } from "../../../commands/message.js"; import { danger, setVerbose } from "../../../globals.js"; +import { CHANNEL_TARGET_DESCRIPTION } from "../../../infra/outbound/channel-target.js"; import { defaultRuntime } from "../../../runtime.js"; import { createDefaultDeps } from "../../deps.js"; @@ -26,12 +27,12 @@ export function createMessageCliHelpers( const withMessageTarget = (command: Command) => command.option( "-t, --to ", - "Recipient/channel: E.164 for WhatsApp/Signal, Telegram chat id/@username, Discord/Slack channel/user, or iMessage handle/chat_id", + CHANNEL_TARGET_DESCRIPTION, ); const withRequiredMessageTarget = (command: Command) => command.requiredOption( "-t, --to ", - "Recipient/channel: E.164 for WhatsApp/Signal, Telegram chat id/@username, Discord/Slack channel/user, or iMessage handle/chat_id", + CHANNEL_TARGET_DESCRIPTION, ); const runMessageAction = async (action: string, opts: Record) => { diff --git a/src/cli/program/message/register.broadcast.ts b/src/cli/program/message/register.broadcast.ts index ed7ac4c7c..f4fde0660 100644 --- a/src/cli/program/message/register.broadcast.ts +++ b/src/cli/program/message/register.broadcast.ts @@ -1,4 +1,5 @@ import type { Command } from "commander"; +import { CHANNEL_TARGETS_DESCRIPTION } from "../../../infra/outbound/channel-target.js"; import type { MessageCliHelpers } from "./helpers.js"; export function registerMessageBroadcastCommand(message: Command, helpers: MessageCliHelpers) { @@ -8,7 +9,7 @@ export function registerMessageBroadcastCommand(message: Command, helpers: Messa ) .requiredOption( "--targets ", - "Targets to broadcast to (repeatable, accepts names or ids)", + CHANNEL_TARGETS_DESCRIPTION, ) .option("--message ", "Message to send") .option("--media ", "Media URL") diff --git a/src/infra/outbound/channel-adapters.ts b/src/infra/outbound/channel-adapters.ts new file mode 100644 index 000000000..b66d1edbf --- /dev/null +++ b/src/infra/outbound/channel-adapters.ts @@ -0,0 +1,24 @@ +import type { ChannelId } from "../../channels/plugins/types.js"; + +export type ChannelMessageAdapter = { + supportsEmbeds: boolean; + buildCrossContextEmbeds?: (originLabel: string) => unknown[]; +}; + +const DEFAULT_ADAPTER: ChannelMessageAdapter = { + supportsEmbeds: false, +}; + +const DISCORD_ADAPTER: ChannelMessageAdapter = { + supportsEmbeds: true, + buildCrossContextEmbeds: (originLabel: string) => [ + { + description: `From ${originLabel}`, + }, + ], +}; + +export function getChannelMessageAdapter(channel: ChannelId): ChannelMessageAdapter { + if (channel === "discord") return DISCORD_ADAPTER; + return DEFAULT_ADAPTER; +} diff --git a/src/infra/outbound/channel-target.ts b/src/infra/outbound/channel-target.ts new file mode 100644 index 000000000..d60813267 --- /dev/null +++ b/src/infra/outbound/channel-target.ts @@ -0,0 +1,9 @@ +export const CHANNEL_TARGET_DESCRIPTION = + "Recipient/channel: E.164 for WhatsApp/Signal, Telegram chat id/@username, Discord/Slack channel/user, or iMessage handle/chat_id"; + +export const CHANNEL_TARGETS_DESCRIPTION = + "Recipient/channel targets (same format as --to); accepts ids or names when the directory is available."; + +export function normalizeChannelTargetInput(raw: string): string { + return raw.trim(); +} diff --git a/src/infra/outbound/directory-cache.ts b/src/infra/outbound/directory-cache.ts new file mode 100644 index 000000000..197d9789f --- /dev/null +++ b/src/infra/outbound/directory-cache.ts @@ -0,0 +1,53 @@ +import type { ChannelDirectoryEntryKind, ChannelId } from "../../channels/plugins/types.js"; +import type { ClawdbotConfig } from "../../config/config.js"; + +type CacheEntry = { + value: T; + fetchedAt: number; +}; + +export type DirectoryCacheKey = { + channel: ChannelId; + accountId?: string | null; + kind: ChannelDirectoryEntryKind; + source: "cache" | "live"; +}; + +export function buildDirectoryCacheKey(key: DirectoryCacheKey): string { + return `${key.channel}:${key.accountId ?? "default"}:${key.kind}:${key.source}`; +} + +export class DirectoryCache { + private readonly cache = new Map>(); + private lastConfigRef: ClawdbotConfig | null = null; + + constructor(private readonly ttlMs: number) {} + + get(key: string, cfg: ClawdbotConfig): T | undefined { + this.resetIfConfigChanged(cfg); + const entry = this.cache.get(key); + if (!entry) return undefined; + if (Date.now() - entry.fetchedAt > this.ttlMs) { + this.cache.delete(key); + return undefined; + } + return entry.value; + } + + set(key: string, value: T, cfg: ClawdbotConfig): void { + this.resetIfConfigChanged(cfg); + this.cache.set(key, { value, fetchedAt: Date.now() }); + } + + clear(cfg?: ClawdbotConfig): void { + this.cache.clear(); + if (cfg) this.lastConfigRef = cfg; + } + + private resetIfConfigChanged(cfg: ClawdbotConfig): void { + if (this.lastConfigRef && this.lastConfigRef !== cfg) { + this.cache.clear(); + } + this.lastConfigRef = cfg; + } +} diff --git a/src/infra/outbound/message-action-runner.ts b/src/infra/outbound/message-action-runner.ts index ba4e52113..c15860dd1 100644 --- a/src/infra/outbound/message-action-runner.ts +++ b/src/infra/outbound/message-action-runner.ts @@ -1,5 +1,4 @@ import type { AgentToolResult } from "@mariozechner/pi-agent-core"; -import { normalizeTargetForProvider } from "../../agents/pi-embedded-messaging.js"; import { readNumberParam, readStringArrayParam, @@ -18,7 +17,13 @@ import { listConfiguredMessageChannels, resolveMessageChannelSelection } from ". import type { OutboundSendDeps } from "./deliver.js"; import type { MessagePollResult, MessageSendResult } from "./message.js"; import { sendMessage, sendPoll } from "./message.js"; -import { lookupDirectoryDisplay, resolveMessagingTarget } from "./target-resolver.js"; +import { + applyCrossContextDecoration, + buildCrossContextDecoration, + enforceCrossContextPolicy, + shouldApplyCrossContextMarker, +} from "./outbound-policy.js"; +import { resolveMessagingTarget } from "./target-resolver.js"; export type MessageActionRunnerGateway = { url?: string; @@ -139,72 +144,6 @@ function parseButtonsParam(params: Record): void { } } -const CONTEXT_GUARDED_ACTIONS = new Set([ - "send", - "poll", - "thread-create", - "thread-reply", - "sticker", -]); - -function resolveContextGuardTarget( - action: ChannelMessageActionName, - params: Record, -): string | undefined { - if (!CONTEXT_GUARDED_ACTIONS.has(action)) return undefined; - - if (action === "thread-reply" || action === "thread-create") { - return readStringParam(params, "channelId") ?? readStringParam(params, "to"); - } - - return readStringParam(params, "to") ?? readStringParam(params, "channelId"); -} - -function enforceContextIsolation(params: { - channel: ChannelId; - action: ChannelMessageActionName; - params: Record; - toolContext?: ChannelThreadingToolContext; - cfg: ClawdbotConfig; -}): void { - const currentTarget = params.toolContext?.currentChannelId?.trim(); - if (!currentTarget) return; - if (!CONTEXT_GUARDED_ACTIONS.has(params.action)) return; - - if (params.cfg.tools?.message?.allowCrossContextSend) return; - - const currentProvider = params.toolContext?.currentChannelProvider; - const allowWithinProvider = params.cfg.tools?.message?.crossContext?.allowWithinProvider !== false; - const allowAcrossProviders = - params.cfg.tools?.message?.crossContext?.allowAcrossProviders === true; - - if (currentProvider && currentProvider !== params.channel) { - if (!allowAcrossProviders) { - throw new Error( - `Cross-context messaging denied: action=${params.action} target provider "${params.channel}" while bound to "${currentProvider}".`, - ); - } - return; - } - - if (allowWithinProvider) return; - - const target = resolveContextGuardTarget(params.action, params.params); - if (!target) return; - - const normalizedTarget = - normalizeTargetForProvider(params.channel, target) ?? target.toLowerCase(); - const normalizedCurrent = - normalizeTargetForProvider(params.channel, currentTarget) ?? currentTarget.toLowerCase(); - - if (!normalizedTarget || !normalizedCurrent) return; - if (normalizedTarget === normalizedCurrent) return; - - throw new Error( - `Cross-context messaging denied: action=${params.action} target="${target}" while bound to "${currentTarget}" (channel=${params.channel}).`, - ); -} - async function resolveChannel(cfg: ClawdbotConfig, params: Record) { const channelHint = readStringParam(params, "channel"); const selection = await resolveMessageChannelSelection({ @@ -214,57 +153,6 @@ async function resolveChannel(cfg: ClawdbotConfig, params: Record([ + "send", + "poll", + "thread-create", + "thread-reply", + "sticker", +]); + +const CONTEXT_MARKER_ACTIONS = new Set([ + "send", + "poll", + "thread-reply", + "sticker", +]); + +function resolveContextGuardTarget( + action: ChannelMessageActionName, + params: Record, +): string | undefined { + if (!CONTEXT_GUARDED_ACTIONS.has(action)) return undefined; + + if (action === "thread-reply" || action === "thread-create") { + if (typeof params.channelId === "string") return params.channelId; + if (typeof params.to === "string") return params.to; + return undefined; + } + + if (typeof params.to === "string") return params.to; + if (typeof params.channelId === "string") return params.channelId; + return undefined; +} + +function normalizeTarget(channel: ChannelId, raw: string): string | undefined { + return normalizeTargetForProvider(channel, raw) ?? raw.trim().toLowerCase(); +} + +function isCrossContextTarget(params: { + channel: ChannelId; + target: string; + toolContext?: ChannelThreadingToolContext; +}): boolean { + const currentTarget = params.toolContext?.currentChannelId?.trim(); + if (!currentTarget) return false; + const normalizedTarget = normalizeTarget(params.channel, params.target); + const normalizedCurrent = normalizeTarget(params.channel, currentTarget); + if (!normalizedTarget || !normalizedCurrent) return false; + return normalizedTarget !== normalizedCurrent; +} + +export function enforceCrossContextPolicy(params: { + channel: ChannelId; + action: ChannelMessageActionName; + args: Record; + toolContext?: ChannelThreadingToolContext; + cfg: ClawdbotConfig; +}): void { + const currentTarget = params.toolContext?.currentChannelId?.trim(); + if (!currentTarget) return; + if (!CONTEXT_GUARDED_ACTIONS.has(params.action)) return; + + if (params.cfg.tools?.message?.allowCrossContextSend) return; + + const currentProvider = params.toolContext?.currentChannelProvider; + const allowWithinProvider = params.cfg.tools?.message?.crossContext?.allowWithinProvider !== false; + const allowAcrossProviders = + params.cfg.tools?.message?.crossContext?.allowAcrossProviders === true; + + if (currentProvider && currentProvider !== params.channel) { + if (!allowAcrossProviders) { + throw new Error( + `Cross-context messaging denied: action=${params.action} target provider "${params.channel}" while bound to "${currentProvider}".`, + ); + } + return; + } + + if (allowWithinProvider) return; + + const target = resolveContextGuardTarget(params.action, params.args); + if (!target) return; + + if (!isCrossContextTarget({ channel: params.channel, target, toolContext: params.toolContext })) { + return; + } + + throw new Error( + `Cross-context messaging denied: action=${params.action} target="${target}" while bound to "${currentTarget}" (channel=${params.channel}).`, + ); +} + +export async function buildCrossContextDecoration(params: { + cfg: ClawdbotConfig; + channel: ChannelId; + target: string; + toolContext?: ChannelThreadingToolContext; + accountId?: string | null; +}): Promise { + if (!params.toolContext?.currentChannelId) return null; + if (!isCrossContextTarget(params)) return null; + + const markerConfig = params.cfg.tools?.message?.crossContext?.marker; + if (markerConfig?.enabled === false) return null; + + const currentName = + (await lookupDirectoryDisplay({ + cfg: params.cfg, + channel: params.channel, + targetId: params.toolContext.currentChannelId, + accountId: params.accountId ?? undefined, + })) ?? params.toolContext.currentChannelId; + const originLabel = currentName.startsWith("#") ? currentName : `#${currentName}`; + const prefixTemplate = markerConfig?.prefix ?? "[from {channel}] "; + const suffixTemplate = markerConfig?.suffix ?? ""; + const prefix = prefixTemplate.replaceAll("{channel}", originLabel); + const suffix = suffixTemplate.replaceAll("{channel}", originLabel); + + const adapter = getChannelMessageAdapter(params.channel); + const embeds = adapter.supportsEmbeds + ? adapter.buildCrossContextEmbeds?.(originLabel) ?? undefined + : undefined; + + return { prefix, suffix, embeds }; +} + +export function shouldApplyCrossContextMarker(action: ChannelMessageActionName): boolean { + return CONTEXT_MARKER_ACTIONS.has(action); +} + +export function applyCrossContextDecoration(params: { + message: string; + decoration: CrossContextDecoration; + preferEmbeds: boolean; +}): { message: string; embeds?: unknown[]; usedEmbeds: boolean } { + const useEmbeds = params.preferEmbeds && params.decoration.embeds?.length; + if (useEmbeds) { + return { message: params.message, embeds: params.decoration.embeds, usedEmbeds: true }; + } + const message = `${params.decoration.prefix}${params.message}${params.decoration.suffix}`; + return { message, usedEmbeds: false }; +} diff --git a/src/infra/outbound/target-resolver.ts b/src/infra/outbound/target-resolver.ts index d77895355..11e6a4cd7 100644 --- a/src/infra/outbound/target-resolver.ts +++ b/src/infra/outbound/target-resolver.ts @@ -7,6 +7,8 @@ import type { } from "../../channels/plugins/types.js"; import type { ClawdbotConfig } from "../../config/config.js"; import { defaultRuntime, type RuntimeEnv } from "../../runtime.js"; +import { normalizeChannelTargetInput } from "./channel-target.js"; +import { buildDirectoryCacheKey, DirectoryCache } from "./directory-cache.js"; export type TargetResolveKind = ChannelDirectoryEntryKind | "channel"; @@ -21,30 +23,8 @@ export type ResolveMessagingTargetResult = | { ok: true; target: ResolvedMessagingTarget } | { ok: false; error: Error; candidates?: ChannelDirectoryEntry[] }; -type DirectoryCacheEntry = { - entries: ChannelDirectoryEntry[]; - fetchedAt: number; -}; - const CACHE_TTL_MS = 30 * 60 * 1000; -const directoryCache = new Map(); -let lastConfigRef: ClawdbotConfig | null = null; - -function resetCacheIfConfigChanged(cfg: ClawdbotConfig): void { - if (lastConfigRef && lastConfigRef !== cfg) { - directoryCache.clear(); - } - lastConfigRef = cfg; -} - -function buildCacheKey(params: { - channel: ChannelId; - accountId?: string | null; - kind: ChannelDirectoryEntryKind; - source: "cache" | "live"; -}) { - return `${params.channel}:${params.accountId ?? "default"}:${params.kind}:${params.source}`; -} +const directoryCache = new DirectoryCache(CACHE_TTL_MS); function normalizeQuery(value: string): string { return value.trim().toLowerCase(); @@ -182,17 +162,14 @@ async function getDirectoryEntries(params: { runtime?: RuntimeEnv; preferLiveOnMiss?: boolean; }): Promise { - resetCacheIfConfigChanged(params.cfg); - const cacheKey = buildCacheKey({ + const cacheKey = buildDirectoryCacheKey({ channel: params.channel, accountId: params.accountId, kind: params.kind, source: "cache", }); - const cached = directoryCache.get(cacheKey); - if (cached && Date.now() - cached.fetchedAt < CACHE_TTL_MS) { - return cached.entries; - } + const cached = directoryCache.get(cacheKey, params.cfg); + if (cached) return cached; const entries = await listDirectoryEntries({ cfg: params.cfg, channel: params.channel, @@ -203,10 +180,10 @@ async function getDirectoryEntries(params: { source: "cache", }); if (entries.length > 0 || !params.preferLiveOnMiss) { - directoryCache.set(cacheKey, { entries, fetchedAt: Date.now() }); + directoryCache.set(cacheKey, entries, params.cfg); return entries; } - const liveKey = buildCacheKey({ + const liveKey = buildDirectoryCacheKey({ channel: params.channel, accountId: params.accountId, kind: params.kind, @@ -221,7 +198,7 @@ async function getDirectoryEntries(params: { runtime: params.runtime, source: "live", }); - directoryCache.set(liveKey, { entries: liveEntries, fetchedAt: Date.now() }); + directoryCache.set(liveKey, liveEntries, params.cfg); return liveEntries; } @@ -233,7 +210,7 @@ export async function resolveMessagingTarget(params: { preferredKind?: TargetResolveKind; runtime?: RuntimeEnv; }): Promise { - const raw = params.input.trim(); + const raw = normalizeChannelTargetInput(params.input); if (!raw) { return { ok: false, error: new Error("Target is required") }; }