From 3668388912e69ba9c3e8f4f941c3f072765b6346 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Wed, 7 Jan 2026 05:02:34 +0000 Subject: [PATCH] fix(routing): harden originating reply routing --- CHANGELOG.md | 1 + src/agents/pi-tool-definition-adapter.test.ts | 3 +- src/agents/pi-tool-definition-adapter.ts | 2 +- src/auto-reply/reply.ts | 2 + .../reply/dispatch-from-config.test.ts | 91 ++++++++++++++ src/auto-reply/reply/dispatch-from-config.ts | 15 ++- src/auto-reply/reply/followup-runner.ts | 2 + .../reply/queue.collect-routing.test.ts | 111 ++++++++++++++++++ src/auto-reply/reply/queue.ts | 54 +++++++-- src/auto-reply/reply/route-reply.ts | 97 +++++++++++---- src/telegram/send.ts | 26 ++-- src/web/auto-reply.ts | 3 + 12 files changed, 356 insertions(+), 51 deletions(-) create mode 100644 src/auto-reply/reply/dispatch-from-config.test.ts create mode 100644 src/auto-reply/reply/queue.collect-routing.test.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 744737a4f..2fbcc32b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -111,6 +111,7 @@ - Block streaming: preserve leading indentation in block replies (lists, indented fences). - Docs: document systemd lingering and logged-in session requirements on macOS/Windows. - Auto-reply: centralize tool/block/final dispatch across providers for consistent streaming + heartbeat/prefix handling. Thanks @MSch for PR #225. +- Routing: route replies back to the originating provider/chat when multiple providers share the same session. Thanks @jalehman for PR #328. - Heartbeat: make HEARTBEAT_OK ack padding configurable across heartbeat and cron delivery. (#238) — thanks @jalehman - Skills: emit MEDIA token after Nano Banana Pro image generation. Thanks @Iamadig for PR #271. - WhatsApp: set sender E.164 for direct chats so owner commands work in DMs. diff --git a/src/agents/pi-tool-definition-adapter.test.ts b/src/agents/pi-tool-definition-adapter.test.ts index e05d579c0..27a101002 100644 --- a/src/agents/pi-tool-definition-adapter.test.ts +++ b/src/agents/pi-tool-definition-adapter.test.ts @@ -1,6 +1,5 @@ -import { describe, expect, it } from "vitest"; - import type { AgentTool } from "@mariozechner/pi-agent-core"; +import { describe, expect, it } from "vitest"; import { toToolDefinitions } from "./pi-tool-definition-adapter.js"; diff --git a/src/agents/pi-tool-definition-adapter.ts b/src/agents/pi-tool-definition-adapter.ts index 6df70bf6b..df8b64d8d 100644 --- a/src/agents/pi-tool-definition-adapter.ts +++ b/src/agents/pi-tool-definition-adapter.ts @@ -38,7 +38,7 @@ export function toToolDefinitions(tools: AnyAgentTool[]): ToolDefinition[] { : ""; if (name === "AbortError") throw err; const message = - err instanceof Error ? err.stack ?? err.message : String(err); + err instanceof Error ? (err.stack ?? err.message) : String(err); logError(`[tools] ${tool.name} failed: ${message}`); return jsonResult({ status: "error", diff --git a/src/auto-reply/reply.ts b/src/auto-reply/reply.ts index b1ea70ef9..40a26c8f7 100644 --- a/src/auto-reply/reply.ts +++ b/src/auto-reply/reply.ts @@ -717,6 +717,8 @@ export async function getReplyFromConfig( // Originating channel for reply routing. originatingChannel: ctx.OriginatingChannel, originatingTo: ctx.OriginatingTo, + originatingAccountId: ctx.AccountId, + originatingThreadId: ctx.MessageThreadId, run: { agentId, agentDir, diff --git a/src/auto-reply/reply/dispatch-from-config.test.ts b/src/auto-reply/reply/dispatch-from-config.test.ts new file mode 100644 index 000000000..69ea5102d --- /dev/null +++ b/src/auto-reply/reply/dispatch-from-config.test.ts @@ -0,0 +1,91 @@ +import { describe, expect, it, vi } from "vitest"; + +import type { ClawdbotConfig } from "../../config/config.js"; +import type { MsgContext } from "../templating.js"; +import type { GetReplyOptions, ReplyPayload } from "../types.js"; +import type { ReplyDispatcher } from "./reply-dispatcher.js"; + +const mocks = vi.hoisted(() => ({ + routeReply: vi.fn(async () => ({ ok: true, messageId: "mock" })), +})); + +vi.mock("./route-reply.js", () => ({ + isRoutableChannel: (channel: string | undefined) => + Boolean( + channel && + [ + "telegram", + "slack", + "discord", + "signal", + "imessage", + "whatsapp", + ].includes(channel), + ), + routeReply: mocks.routeReply, +})); + +const { dispatchReplyFromConfig } = await import("./dispatch-from-config.js"); + +function createDispatcher(): ReplyDispatcher { + return { + sendToolResult: vi.fn(() => true), + sendBlockReply: vi.fn(() => true), + sendFinalReply: vi.fn(() => true), + waitForIdle: vi.fn(async () => {}), + getQueuedCounts: vi.fn(() => ({ tool: 0, block: 0, final: 0 })), + }; +} + +describe("dispatchReplyFromConfig", () => { + it("does not route when Provider matches OriginatingChannel (even if Surface is missing)", async () => { + mocks.routeReply.mockClear(); + const cfg = {} as ClawdbotConfig; + const dispatcher = createDispatcher(); + const ctx: MsgContext = { + Provider: "slack", + OriginatingChannel: "slack", + OriginatingTo: "channel:C123", + }; + + const replyResolver = async ( + _ctx: MsgContext, + _opts: GetReplyOptions | undefined, + _cfg: ClawdbotConfig, + ) => ({ text: "hi" }) satisfies ReplyPayload; + await dispatchReplyFromConfig({ ctx, cfg, dispatcher, replyResolver }); + + expect(mocks.routeReply).not.toHaveBeenCalled(); + expect(dispatcher.sendFinalReply).toHaveBeenCalledTimes(1); + }); + + it("routes when OriginatingChannel differs from Provider", async () => { + mocks.routeReply.mockClear(); + const cfg = {} as ClawdbotConfig; + const dispatcher = createDispatcher(); + const ctx: MsgContext = { + Provider: "slack", + AccountId: "acc-1", + MessageThreadId: 123, + OriginatingChannel: "telegram", + OriginatingTo: "telegram:999", + }; + + const replyResolver = async ( + _ctx: MsgContext, + _opts: GetReplyOptions | undefined, + _cfg: ClawdbotConfig, + ) => ({ text: "hi" }) satisfies ReplyPayload; + await dispatchReplyFromConfig({ ctx, cfg, dispatcher, replyResolver }); + + expect(dispatcher.sendFinalReply).not.toHaveBeenCalled(); + expect(mocks.routeReply).toHaveBeenCalledWith( + expect.objectContaining({ + channel: "telegram", + to: "telegram:999", + accountId: "acc-1", + threadId: 123, + }), + ); + }); +}); diff --git a/src/auto-reply/reply/dispatch-from-config.ts b/src/auto-reply/reply/dispatch-from-config.ts index b9ec5a2c6..0e6654cfa 100644 --- a/src/auto-reply/reply/dispatch-from-config.ts +++ b/src/auto-reply/reply/dispatch-from-config.ts @@ -27,7 +27,7 @@ export async function dispatchReplyFromConfig(params: { // flow when the provider handles its own messages. const originatingChannel = ctx.OriginatingChannel; const originatingTo = ctx.OriginatingTo; - const currentSurface = ctx.Surface?.toLowerCase(); + const currentSurface = (ctx.Surface ?? ctx.Provider)?.toLowerCase(); const shouldRouteToOriginating = isRoutableChannel(originatingChannel) && originatingTo && @@ -47,6 +47,8 @@ export async function dispatchReplyFromConfig(params: { payload, channel: originatingChannel, to: originatingTo, + accountId: ctx.AccountId, + threadId: ctx.MessageThreadId, cfg, }); if (!result.ok) { @@ -89,6 +91,7 @@ export async function dispatchReplyFromConfig(params: { : []; let queuedFinal = false; + let routedFinalCount = 0; for (const reply of replies) { if (shouldRouteToOriginating && originatingChannel && originatingTo) { // Route final reply to originating channel. @@ -96,6 +99,8 @@ export async function dispatchReplyFromConfig(params: { payload: reply, channel: originatingChannel, to: originatingTo, + accountId: ctx.AccountId, + threadId: ctx.MessageThreadId, cfg, }); if (!result.ok) { @@ -103,13 +108,15 @@ export async function dispatchReplyFromConfig(params: { `dispatch-from-config: route-reply (final) failed: ${result.error ?? "unknown error"}`, ); } - // Mark as queued since we handled it ourselves. - queuedFinal = true; + queuedFinal = result.ok || queuedFinal; + if (result.ok) routedFinalCount += 1; } else { queuedFinal = dispatcher.sendFinalReply(reply) || queuedFinal; } } await dispatcher.waitForIdle(); - return { queuedFinal, counts: dispatcher.getQueuedCounts() }; + const counts = dispatcher.getQueuedCounts(); + counts.final += routedFinalCount; + return { queuedFinal, counts }; } diff --git a/src/auto-reply/reply/followup-runner.ts b/src/auto-reply/reply/followup-runner.ts index ea88ef983..6086b7f40 100644 --- a/src/auto-reply/reply/followup-runner.ts +++ b/src/auto-reply/reply/followup-runner.ts @@ -79,6 +79,8 @@ export function createFollowupRunner(params: { payload, channel: originatingChannel, to: originatingTo, + accountId: queued.originatingAccountId, + threadId: queued.originatingThreadId, cfg: queued.run.config, }); if (!result.ok) { diff --git a/src/auto-reply/reply/queue.collect-routing.test.ts b/src/auto-reply/reply/queue.collect-routing.test.ts new file mode 100644 index 000000000..089f38ae1 --- /dev/null +++ b/src/auto-reply/reply/queue.collect-routing.test.ts @@ -0,0 +1,111 @@ +import { describe, expect, it } from "vitest"; + +import type { ClawdbotConfig } from "../../config/config.js"; +import type { FollowupRun, QueueSettings } from "./queue.js"; +import { enqueueFollowupRun, scheduleFollowupDrain } from "./queue.js"; + +function createRun(params: { + prompt: string; + originatingChannel?: FollowupRun["originatingChannel"]; + originatingTo?: string; +}): FollowupRun { + return { + prompt: params.prompt, + enqueuedAt: Date.now(), + originatingChannel: params.originatingChannel, + originatingTo: params.originatingTo, + run: { + agentId: "agent", + agentDir: "/tmp", + sessionId: "sess", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp", + config: {} as ClawdbotConfig, + provider: "openai", + model: "gpt-test", + timeoutMs: 10_000, + blockReplyBreak: "text_end", + }, + }; +} + +describe("followup queue collect routing", () => { + it("does not collect when destinations differ", async () => { + const key = `test-collect-diff-to-${Date.now()}`; + const calls: FollowupRun[] = []; + const runFollowup = async (run: FollowupRun) => { + calls.push(run); + }; + const settings: QueueSettings = { + mode: "collect", + debounceMs: 0, + cap: 50, + dropPolicy: "summarize", + }; + + enqueueFollowupRun( + key, + createRun({ + prompt: "one", + originatingChannel: "slack", + originatingTo: "channel:A", + }), + settings, + ); + enqueueFollowupRun( + key, + createRun({ + prompt: "two", + originatingChannel: "slack", + originatingTo: "channel:B", + }), + settings, + ); + + scheduleFollowupDrain(key, runFollowup); + await expect.poll(() => calls.length).toBe(2); + expect(calls[0]?.prompt).toBe("one"); + expect(calls[1]?.prompt).toBe("two"); + }); + + it("collects when channel+destination match", async () => { + const key = `test-collect-same-to-${Date.now()}`; + const calls: FollowupRun[] = []; + const runFollowup = async (run: FollowupRun) => { + calls.push(run); + }; + const settings: QueueSettings = { + mode: "collect", + debounceMs: 0, + cap: 50, + dropPolicy: "summarize", + }; + + enqueueFollowupRun( + key, + createRun({ + prompt: "one", + originatingChannel: "slack", + originatingTo: "channel:A", + }), + settings, + ); + enqueueFollowupRun( + key, + createRun({ + prompt: "two", + originatingChannel: "slack", + originatingTo: "channel:A", + }), + settings, + ); + + scheduleFollowupDrain(key, runFollowup); + await expect.poll(() => calls.length).toBe(1); + expect(calls[0]?.prompt).toContain( + "[Queued messages while agent was busy]", + ); + expect(calls[0]?.originatingChannel).toBe("slack"); + expect(calls[0]?.originatingTo).toBe("channel:A"); + }); +}); diff --git a/src/auto-reply/reply/queue.ts b/src/auto-reply/reply/queue.ts index 51bc9a84d..a5208e674 100644 --- a/src/auto-reply/reply/queue.ts +++ b/src/auto-reply/reply/queue.ts @@ -35,6 +35,10 @@ export type FollowupRun = { * The chat/channel/user ID where the reply should be sent. */ originatingTo?: string; + /** Provider account id (multi-account). */ + originatingAccountId?: string; + /** Telegram forum topic thread id. */ + originatingThreadId?: number; run: { agentId: string; agentDir: string; @@ -396,23 +400,34 @@ function buildCollectPrompt(items: FollowupRun[], summary?: string): string { * Also returns true for a mix of routable and non-routable channels. */ function hasCrossProviderItems(items: FollowupRun[]): boolean { - const routableChannels = new Set(); - let hasNonRoutable = false; + const keys = new Set(); + let hasUnkeyed = false; for (const item of items) { const channel = item.originatingChannel; - if (isRoutableChannel(channel)) { - routableChannels.add(channel); - } else if (channel) { - // Has a channel but it's not routable (whatsapp, webchat). - hasNonRoutable = true; + const to = item.originatingTo; + const accountId = item.originatingAccountId; + const threadId = item.originatingThreadId; + if (!channel && !to && !accountId && typeof threadId !== "number") { + hasUnkeyed = true; + continue; } + if (!isRoutableChannel(channel) || !to) { + return true; + } + keys.add( + [ + channel, + to, + accountId || "", + typeof threadId === "number" ? String(threadId) : "", + ].join("|"), + ); } - // Cross-provider if: multiple routable channels, or mix of routable + non-routable. - return ( - routableChannels.size > 1 || (routableChannels.size > 0 && hasNonRoutable) - ); + if (keys.size === 0) return false; + if (hasUnkeyed) return true; + return keys.size > 1; } export function scheduleFollowupDrain( key: string, @@ -423,14 +438,23 @@ export function scheduleFollowupDrain( queue.draining = true; void (async () => { try { + let forceIndividualCollect = false; while (queue.items.length > 0 || queue.droppedCount > 0) { await waitForQueueDebounce(queue); if (queue.mode === "collect") { + if (forceIndividualCollect) { + const next = queue.items.shift(); + if (!next) break; + await runFollowup(next); + continue; + } + // Check if messages span multiple providers. // If so, process individually to preserve per-message routing. const isCrossProvider = hasCrossProviderItems(queue.items); if (isCrossProvider) { + forceIndividualCollect = true; // Process one at a time to preserve per-message routing info. const next = queue.items.shift(); if (!next) break; @@ -451,6 +475,12 @@ export function scheduleFollowupDrain( const originatingTo = items.find( (i) => i.originatingTo, )?.originatingTo; + const originatingAccountId = items.find( + (i) => i.originatingAccountId, + )?.originatingAccountId; + const originatingThreadId = items.find( + (i) => typeof i.originatingThreadId === "number", + )?.originatingThreadId; const prompt = buildCollectPrompt(items, summary); await runFollowup({ @@ -459,6 +489,8 @@ export function scheduleFollowupDrain( enqueuedAt: Date.now(), originatingChannel, originatingTo, + originatingAccountId, + originatingThreadId, }); continue; } diff --git a/src/auto-reply/reply/route-reply.ts b/src/auto-reply/reply/route-reply.ts index 55f251b43..39bb28ab8 100644 --- a/src/auto-reply/reply/route-reply.ts +++ b/src/auto-reply/reply/route-reply.ts @@ -13,6 +13,7 @@ import { sendMessageIMessage } from "../../imessage/send.js"; import { sendMessageSignal } from "../../signal/send.js"; import { sendMessageSlack } from "../../slack/send.js"; import { sendMessageTelegram } from "../../telegram/send.js"; +import { sendMessageWhatsApp } from "../../web/outbound.js"; import type { OriginatingChannelType } from "../templating.js"; import type { ReplyPayload } from "../types.js"; @@ -23,6 +24,10 @@ export type RouteReplyParams = { channel: OriginatingChannelType; /** The destination chat/channel/user ID. */ to: string; + /** Provider account id (multi-account). */ + accountId?: string; + /** Telegram message thread id (forum topics). */ + threadId?: number; /** Config for provider-specific settings. */ cfg: ClawdbotConfig; }; @@ -47,29 +52,48 @@ export type RouteReplyResult = { export async function routeReply( params: RouteReplyParams, ): Promise { - const { payload, channel, to } = params; + const { payload, channel, to, accountId, threadId } = params; + const text = payload.text ?? ""; - const mediaUrl = payload.mediaUrl ?? payload.mediaUrls?.[0]; + const mediaUrls = (payload.mediaUrls?.filter(Boolean) ?? []).length + ? (payload.mediaUrls?.filter(Boolean) as string[]) + : payload.mediaUrl + ? [payload.mediaUrl] + : []; + const replyToId = payload.replyToId; // Skip empty replies. - if (!text.trim() && !mediaUrl) { + if (!text.trim() && mediaUrls.length === 0) { return { ok: true }; } - try { + const sendOne = async (params: { + text: string; + mediaUrl?: string; + }): Promise => { + const { text, mediaUrl } = params; switch (channel) { case "telegram": { - const result = await sendMessageTelegram(to, text, { mediaUrl }); + const result = await sendMessageTelegram(to, text, { + mediaUrl, + messageThreadId: threadId, + }); return { ok: true, messageId: result.messageId }; } case "slack": { - const result = await sendMessageSlack(to, text, { mediaUrl }); + const result = await sendMessageSlack(to, text, { + mediaUrl, + threadTs: replyToId, + }); return { ok: true, messageId: result.messageId }; } case "discord": { - const result = await sendMessageDiscord(to, text, { mediaUrl }); + const result = await sendMessageDiscord(to, text, { + mediaUrl, + replyTo: replyToId, + }); return { ok: true, messageId: result.messageId }; } @@ -84,17 +108,15 @@ export async function routeReply( } case "whatsapp": { - // WhatsApp doesn't have a standalone send function in this codebase. - // Falls through to unknown channel handling. - return { - ok: false, - error: `WhatsApp routing not yet implemented`, - }; + const result = await sendMessageWhatsApp(to, text, { + verbose: false, + mediaUrl, + accountId, + }); + return { ok: true, messageId: result.messageId }; } case "webchat": { - // Webchat is typically handled differently (real-time WebSocket). - // Falls through to unknown channel handling. return { ok: false, error: `Webchat routing not supported for queued replies`, @@ -102,14 +124,26 @@ export async function routeReply( } default: { - // Exhaustive check for unknown channel types. const _exhaustive: never = channel; - return { - ok: false, - error: `Unknown channel: ${String(_exhaustive)}`, - }; + return { ok: false, error: `Unknown channel: ${String(_exhaustive)}` }; } } + }; + + try { + if (mediaUrls.length === 0) { + return await sendOne({ text }); + } + + let last: RouteReplyResult | undefined; + for (let i = 0; i < mediaUrls.length; i++) { + const mediaUrl = mediaUrls[i]; + const caption = i === 0 ? text : ""; + last = await sendOne({ text: caption, mediaUrl }); + if (!last.ok) return last; + } + + return last ?? { ok: true }; } catch (err) { const message = err instanceof Error ? err.message : String(err); return { @@ -122,14 +156,25 @@ export async function routeReply( /** * Checks if a channel type is routable via routeReply. * - * Some channels (webchat, whatsapp) require special handling and - * cannot be routed through this generic interface. + * Some channels (webchat) require special handling and cannot be routed through + * this generic interface. */ export function isRoutableChannel( channel: OriginatingChannelType | undefined, -): channel is "telegram" | "slack" | "discord" | "signal" | "imessage" { +): channel is + | "telegram" + | "slack" + | "discord" + | "signal" + | "imessage" + | "whatsapp" { if (!channel) return false; - return ["telegram", "slack", "discord", "signal", "imessage"].includes( - channel, - ); + return [ + "telegram", + "slack", + "discord", + "signal", + "imessage", + "whatsapp", + ].includes(channel); } diff --git a/src/telegram/send.ts b/src/telegram/send.ts index afaa2ceb9..3b90e2840 100644 --- a/src/telegram/send.ts +++ b/src/telegram/send.ts @@ -10,6 +10,7 @@ type TelegramSendOpts = { verbose?: boolean; mediaUrl?: string; maxBytes?: number; + messageThreadId?: number; api?: Bot["api"]; }; @@ -88,6 +89,10 @@ export async function sendMessageTelegram( const bot = opts.api ? null : new Bot(token); const api = opts.api ?? bot?.api; const mediaUrl = opts.mediaUrl?.trim(); + const threadParams = + typeof opts.messageThreadId === "number" + ? { message_thread_id: Math.trunc(opts.messageThreadId) } + : undefined; const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); @@ -150,35 +155,35 @@ export async function sendMessageTelegram( | Awaited>; if (isGif) { result = await sendWithRetry( - () => api.sendAnimation(chatId, file, { caption }), + () => api.sendAnimation(chatId, file, { caption, ...threadParams }), "animation", ).catch((err) => { throw wrapChatNotFound(err); }); } else if (kind === "image") { result = await sendWithRetry( - () => api.sendPhoto(chatId, file, { caption }), + () => api.sendPhoto(chatId, file, { caption, ...threadParams }), "photo", ).catch((err) => { throw wrapChatNotFound(err); }); } else if (kind === "video") { result = await sendWithRetry( - () => api.sendVideo(chatId, file, { caption }), + () => api.sendVideo(chatId, file, { caption, ...threadParams }), "video", ).catch((err) => { throw wrapChatNotFound(err); }); } else if (kind === "audio") { result = await sendWithRetry( - () => api.sendAudio(chatId, file, { caption }), + () => api.sendAudio(chatId, file, { caption, ...threadParams }), "audio", ).catch((err) => { throw wrapChatNotFound(err); }); } else { result = await sendWithRetry( - () => api.sendDocument(chatId, file, { caption }), + () => api.sendDocument(chatId, file, { caption, ...threadParams }), "document", ).catch((err) => { throw wrapChatNotFound(err); @@ -192,7 +197,11 @@ export async function sendMessageTelegram( throw new Error("Message must be non-empty for Telegram sends"); } const res = await sendWithRetry( - () => api.sendMessage(chatId, text, { parse_mode: "Markdown" }), + () => + api.sendMessage(chatId, text, { + parse_mode: "Markdown", + ...threadParams, + }), "message", ).catch(async (err) => { // Telegram rejects malformed Markdown (e.g., unbalanced '_' or '*'). @@ -205,7 +214,10 @@ export async function sendMessageTelegram( ); } return await sendWithRetry( - () => api.sendMessage(chatId, text), + () => + threadParams + ? api.sendMessage(chatId, text, threadParams) + : api.sendMessage(chatId, text), "message-plain", ).catch((err2) => { throw wrapChatNotFound(err2); diff --git a/src/web/auto-reply.ts b/src/web/auto-reply.ts index 2532923df..8f4d46b4a 100644 --- a/src/web/auto-reply.ts +++ b/src/web/auto-reply.ts @@ -1252,6 +1252,9 @@ export async function monitorWebProvider( WasMentioned: msg.wasMentioned, ...(msg.location ? toLocationContext(msg.location) : {}), Provider: "whatsapp", + Surface: "whatsapp", + OriginatingChannel: "whatsapp", + OriginatingTo: msg.to, }, cfg, dispatcher,