fix(routing): harden originating reply routing

This commit is contained in:
Peter Steinberger
2026-01-07 05:02:34 +00:00
parent 2a2e327cae
commit 3668388912
12 changed files with 356 additions and 51 deletions

View File

@@ -111,6 +111,7 @@
- Block streaming: preserve leading indentation in block replies (lists, indented fences).
- Docs: document systemd lingering and logged-in session requirements on macOS/Windows.
- Auto-reply: centralize tool/block/final dispatch across providers for consistent streaming + heartbeat/prefix handling. Thanks @MSch for PR #225.
- Routing: route replies back to the originating provider/chat when multiple providers share the same session. Thanks @jalehman for PR #328.
- Heartbeat: make HEARTBEAT_OK ack padding configurable across heartbeat and cron delivery. (#238) — thanks @jalehman
- Skills: emit MEDIA token after Nano Banana Pro image generation. Thanks @Iamadig for PR #271.
- WhatsApp: set sender E.164 for direct chats so owner commands work in DMs.

View File

@@ -1,6 +1,5 @@
import { describe, expect, it } from "vitest";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { describe, expect, it } from "vitest";
import { toToolDefinitions } from "./pi-tool-definition-adapter.js";

View File

@@ -38,7 +38,7 @@ export function toToolDefinitions(tools: AnyAgentTool[]): ToolDefinition[] {
: "";
if (name === "AbortError") throw err;
const message =
err instanceof Error ? err.stack ?? err.message : String(err);
err instanceof Error ? (err.stack ?? err.message) : String(err);
logError(`[tools] ${tool.name} failed: ${message}`);
return jsonResult({
status: "error",

View File

@@ -717,6 +717,8 @@ export async function getReplyFromConfig(
// Originating channel for reply routing.
originatingChannel: ctx.OriginatingChannel,
originatingTo: ctx.OriginatingTo,
originatingAccountId: ctx.AccountId,
originatingThreadId: ctx.MessageThreadId,
run: {
agentId,
agentDir,

View File

@@ -0,0 +1,91 @@
import { describe, expect, it, vi } from "vitest";
import type { ClawdbotConfig } from "../../config/config.js";
import type { MsgContext } from "../templating.js";
import type { GetReplyOptions, ReplyPayload } from "../types.js";
import type { ReplyDispatcher } from "./reply-dispatcher.js";
const mocks = vi.hoisted(() => ({
routeReply: vi.fn(async () => ({ ok: true, messageId: "mock" })),
}));
vi.mock("./route-reply.js", () => ({
isRoutableChannel: (channel: string | undefined) =>
Boolean(
channel &&
[
"telegram",
"slack",
"discord",
"signal",
"imessage",
"whatsapp",
].includes(channel),
),
routeReply: mocks.routeReply,
}));
const { dispatchReplyFromConfig } = await import("./dispatch-from-config.js");
function createDispatcher(): ReplyDispatcher {
return {
sendToolResult: vi.fn(() => true),
sendBlockReply: vi.fn(() => true),
sendFinalReply: vi.fn(() => true),
waitForIdle: vi.fn(async () => {}),
getQueuedCounts: vi.fn(() => ({ tool: 0, block: 0, final: 0 })),
};
}
describe("dispatchReplyFromConfig", () => {
it("does not route when Provider matches OriginatingChannel (even if Surface is missing)", async () => {
mocks.routeReply.mockClear();
const cfg = {} as ClawdbotConfig;
const dispatcher = createDispatcher();
const ctx: MsgContext = {
Provider: "slack",
OriginatingChannel: "slack",
OriginatingTo: "channel:C123",
};
const replyResolver = async (
_ctx: MsgContext,
_opts: GetReplyOptions | undefined,
_cfg: ClawdbotConfig,
) => ({ text: "hi" }) satisfies ReplyPayload;
await dispatchReplyFromConfig({ ctx, cfg, dispatcher, replyResolver });
expect(mocks.routeReply).not.toHaveBeenCalled();
expect(dispatcher.sendFinalReply).toHaveBeenCalledTimes(1);
});
it("routes when OriginatingChannel differs from Provider", async () => {
mocks.routeReply.mockClear();
const cfg = {} as ClawdbotConfig;
const dispatcher = createDispatcher();
const ctx: MsgContext = {
Provider: "slack",
AccountId: "acc-1",
MessageThreadId: 123,
OriginatingChannel: "telegram",
OriginatingTo: "telegram:999",
};
const replyResolver = async (
_ctx: MsgContext,
_opts: GetReplyOptions | undefined,
_cfg: ClawdbotConfig,
) => ({ text: "hi" }) satisfies ReplyPayload;
await dispatchReplyFromConfig({ ctx, cfg, dispatcher, replyResolver });
expect(dispatcher.sendFinalReply).not.toHaveBeenCalled();
expect(mocks.routeReply).toHaveBeenCalledWith(
expect.objectContaining({
channel: "telegram",
to: "telegram:999",
accountId: "acc-1",
threadId: 123,
}),
);
});
});

View File

@@ -27,7 +27,7 @@ export async function dispatchReplyFromConfig(params: {
// flow when the provider handles its own messages.
const originatingChannel = ctx.OriginatingChannel;
const originatingTo = ctx.OriginatingTo;
const currentSurface = ctx.Surface?.toLowerCase();
const currentSurface = (ctx.Surface ?? ctx.Provider)?.toLowerCase();
const shouldRouteToOriginating =
isRoutableChannel(originatingChannel) &&
originatingTo &&
@@ -47,6 +47,8 @@ export async function dispatchReplyFromConfig(params: {
payload,
channel: originatingChannel,
to: originatingTo,
accountId: ctx.AccountId,
threadId: ctx.MessageThreadId,
cfg,
});
if (!result.ok) {
@@ -89,6 +91,7 @@ export async function dispatchReplyFromConfig(params: {
: [];
let queuedFinal = false;
let routedFinalCount = 0;
for (const reply of replies) {
if (shouldRouteToOriginating && originatingChannel && originatingTo) {
// Route final reply to originating channel.
@@ -96,6 +99,8 @@ export async function dispatchReplyFromConfig(params: {
payload: reply,
channel: originatingChannel,
to: originatingTo,
accountId: ctx.AccountId,
threadId: ctx.MessageThreadId,
cfg,
});
if (!result.ok) {
@@ -103,13 +108,15 @@ export async function dispatchReplyFromConfig(params: {
`dispatch-from-config: route-reply (final) failed: ${result.error ?? "unknown error"}`,
);
}
// Mark as queued since we handled it ourselves.
queuedFinal = true;
queuedFinal = result.ok || queuedFinal;
if (result.ok) routedFinalCount += 1;
} else {
queuedFinal = dispatcher.sendFinalReply(reply) || queuedFinal;
}
}
await dispatcher.waitForIdle();
return { queuedFinal, counts: dispatcher.getQueuedCounts() };
const counts = dispatcher.getQueuedCounts();
counts.final += routedFinalCount;
return { queuedFinal, counts };
}

View File

@@ -79,6 +79,8 @@ export function createFollowupRunner(params: {
payload,
channel: originatingChannel,
to: originatingTo,
accountId: queued.originatingAccountId,
threadId: queued.originatingThreadId,
cfg: queued.run.config,
});
if (!result.ok) {

View File

@@ -0,0 +1,111 @@
import { describe, expect, it } from "vitest";
import type { ClawdbotConfig } from "../../config/config.js";
import type { FollowupRun, QueueSettings } from "./queue.js";
import { enqueueFollowupRun, scheduleFollowupDrain } from "./queue.js";
function createRun(params: {
prompt: string;
originatingChannel?: FollowupRun["originatingChannel"];
originatingTo?: string;
}): FollowupRun {
return {
prompt: params.prompt,
enqueuedAt: Date.now(),
originatingChannel: params.originatingChannel,
originatingTo: params.originatingTo,
run: {
agentId: "agent",
agentDir: "/tmp",
sessionId: "sess",
sessionFile: "/tmp/session.json",
workspaceDir: "/tmp",
config: {} as ClawdbotConfig,
provider: "openai",
model: "gpt-test",
timeoutMs: 10_000,
blockReplyBreak: "text_end",
},
};
}
describe("followup queue collect routing", () => {
it("does not collect when destinations differ", async () => {
const key = `test-collect-diff-to-${Date.now()}`;
const calls: FollowupRun[] = [];
const runFollowup = async (run: FollowupRun) => {
calls.push(run);
};
const settings: QueueSettings = {
mode: "collect",
debounceMs: 0,
cap: 50,
dropPolicy: "summarize",
};
enqueueFollowupRun(
key,
createRun({
prompt: "one",
originatingChannel: "slack",
originatingTo: "channel:A",
}),
settings,
);
enqueueFollowupRun(
key,
createRun({
prompt: "two",
originatingChannel: "slack",
originatingTo: "channel:B",
}),
settings,
);
scheduleFollowupDrain(key, runFollowup);
await expect.poll(() => calls.length).toBe(2);
expect(calls[0]?.prompt).toBe("one");
expect(calls[1]?.prompt).toBe("two");
});
it("collects when channel+destination match", async () => {
const key = `test-collect-same-to-${Date.now()}`;
const calls: FollowupRun[] = [];
const runFollowup = async (run: FollowupRun) => {
calls.push(run);
};
const settings: QueueSettings = {
mode: "collect",
debounceMs: 0,
cap: 50,
dropPolicy: "summarize",
};
enqueueFollowupRun(
key,
createRun({
prompt: "one",
originatingChannel: "slack",
originatingTo: "channel:A",
}),
settings,
);
enqueueFollowupRun(
key,
createRun({
prompt: "two",
originatingChannel: "slack",
originatingTo: "channel:A",
}),
settings,
);
scheduleFollowupDrain(key, runFollowup);
await expect.poll(() => calls.length).toBe(1);
expect(calls[0]?.prompt).toContain(
"[Queued messages while agent was busy]",
);
expect(calls[0]?.originatingChannel).toBe("slack");
expect(calls[0]?.originatingTo).toBe("channel:A");
});
});

View File

@@ -35,6 +35,10 @@ export type FollowupRun = {
* The chat/channel/user ID where the reply should be sent.
*/
originatingTo?: string;
/** Provider account id (multi-account). */
originatingAccountId?: string;
/** Telegram forum topic thread id. */
originatingThreadId?: number;
run: {
agentId: string;
agentDir: string;
@@ -396,23 +400,34 @@ function buildCollectPrompt(items: FollowupRun[], summary?: string): string {
* Also returns true for a mix of routable and non-routable channels.
*/
function hasCrossProviderItems(items: FollowupRun[]): boolean {
const routableChannels = new Set<string>();
let hasNonRoutable = false;
const keys = new Set<string>();
let hasUnkeyed = false;
for (const item of items) {
const channel = item.originatingChannel;
if (isRoutableChannel(channel)) {
routableChannels.add(channel);
} else if (channel) {
// Has a channel but it's not routable (whatsapp, webchat).
hasNonRoutable = true;
const to = item.originatingTo;
const accountId = item.originatingAccountId;
const threadId = item.originatingThreadId;
if (!channel && !to && !accountId && typeof threadId !== "number") {
hasUnkeyed = true;
continue;
}
if (!isRoutableChannel(channel) || !to) {
return true;
}
keys.add(
[
channel,
to,
accountId || "",
typeof threadId === "number" ? String(threadId) : "",
].join("|"),
);
}
// Cross-provider if: multiple routable channels, or mix of routable + non-routable.
return (
routableChannels.size > 1 || (routableChannels.size > 0 && hasNonRoutable)
);
if (keys.size === 0) return false;
if (hasUnkeyed) return true;
return keys.size > 1;
}
export function scheduleFollowupDrain(
key: string,
@@ -423,14 +438,23 @@ export function scheduleFollowupDrain(
queue.draining = true;
void (async () => {
try {
let forceIndividualCollect = false;
while (queue.items.length > 0 || queue.droppedCount > 0) {
await waitForQueueDebounce(queue);
if (queue.mode === "collect") {
if (forceIndividualCollect) {
const next = queue.items.shift();
if (!next) break;
await runFollowup(next);
continue;
}
// Check if messages span multiple providers.
// If so, process individually to preserve per-message routing.
const isCrossProvider = hasCrossProviderItems(queue.items);
if (isCrossProvider) {
forceIndividualCollect = true;
// Process one at a time to preserve per-message routing info.
const next = queue.items.shift();
if (!next) break;
@@ -451,6 +475,12 @@ export function scheduleFollowupDrain(
const originatingTo = items.find(
(i) => i.originatingTo,
)?.originatingTo;
const originatingAccountId = items.find(
(i) => i.originatingAccountId,
)?.originatingAccountId;
const originatingThreadId = items.find(
(i) => typeof i.originatingThreadId === "number",
)?.originatingThreadId;
const prompt = buildCollectPrompt(items, summary);
await runFollowup({
@@ -459,6 +489,8 @@ export function scheduleFollowupDrain(
enqueuedAt: Date.now(),
originatingChannel,
originatingTo,
originatingAccountId,
originatingThreadId,
});
continue;
}

View File

@@ -13,6 +13,7 @@ import { sendMessageIMessage } from "../../imessage/send.js";
import { sendMessageSignal } from "../../signal/send.js";
import { sendMessageSlack } from "../../slack/send.js";
import { sendMessageTelegram } from "../../telegram/send.js";
import { sendMessageWhatsApp } from "../../web/outbound.js";
import type { OriginatingChannelType } from "../templating.js";
import type { ReplyPayload } from "../types.js";
@@ -23,6 +24,10 @@ export type RouteReplyParams = {
channel: OriginatingChannelType;
/** The destination chat/channel/user ID. */
to: string;
/** Provider account id (multi-account). */
accountId?: string;
/** Telegram message thread id (forum topics). */
threadId?: number;
/** Config for provider-specific settings. */
cfg: ClawdbotConfig;
};
@@ -47,29 +52,48 @@ export type RouteReplyResult = {
export async function routeReply(
params: RouteReplyParams,
): Promise<RouteReplyResult> {
const { payload, channel, to } = params;
const { payload, channel, to, accountId, threadId } = params;
const text = payload.text ?? "";
const mediaUrl = payload.mediaUrl ?? payload.mediaUrls?.[0];
const mediaUrls = (payload.mediaUrls?.filter(Boolean) ?? []).length
? (payload.mediaUrls?.filter(Boolean) as string[])
: payload.mediaUrl
? [payload.mediaUrl]
: [];
const replyToId = payload.replyToId;
// Skip empty replies.
if (!text.trim() && !mediaUrl) {
if (!text.trim() && mediaUrls.length === 0) {
return { ok: true };
}
try {
const sendOne = async (params: {
text: string;
mediaUrl?: string;
}): Promise<RouteReplyResult> => {
const { text, mediaUrl } = params;
switch (channel) {
case "telegram": {
const result = await sendMessageTelegram(to, text, { mediaUrl });
const result = await sendMessageTelegram(to, text, {
mediaUrl,
messageThreadId: threadId,
});
return { ok: true, messageId: result.messageId };
}
case "slack": {
const result = await sendMessageSlack(to, text, { mediaUrl });
const result = await sendMessageSlack(to, text, {
mediaUrl,
threadTs: replyToId,
});
return { ok: true, messageId: result.messageId };
}
case "discord": {
const result = await sendMessageDiscord(to, text, { mediaUrl });
const result = await sendMessageDiscord(to, text, {
mediaUrl,
replyTo: replyToId,
});
return { ok: true, messageId: result.messageId };
}
@@ -84,17 +108,15 @@ export async function routeReply(
}
case "whatsapp": {
// WhatsApp doesn't have a standalone send function in this codebase.
// Falls through to unknown channel handling.
return {
ok: false,
error: `WhatsApp routing not yet implemented`,
};
const result = await sendMessageWhatsApp(to, text, {
verbose: false,
mediaUrl,
accountId,
});
return { ok: true, messageId: result.messageId };
}
case "webchat": {
// Webchat is typically handled differently (real-time WebSocket).
// Falls through to unknown channel handling.
return {
ok: false,
error: `Webchat routing not supported for queued replies`,
@@ -102,14 +124,26 @@ export async function routeReply(
}
default: {
// Exhaustive check for unknown channel types.
const _exhaustive: never = channel;
return {
ok: false,
error: `Unknown channel: ${String(_exhaustive)}`,
};
return { ok: false, error: `Unknown channel: ${String(_exhaustive)}` };
}
}
};
try {
if (mediaUrls.length === 0) {
return await sendOne({ text });
}
let last: RouteReplyResult | undefined;
for (let i = 0; i < mediaUrls.length; i++) {
const mediaUrl = mediaUrls[i];
const caption = i === 0 ? text : "";
last = await sendOne({ text: caption, mediaUrl });
if (!last.ok) return last;
}
return last ?? { ok: true };
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
return {
@@ -122,14 +156,25 @@ export async function routeReply(
/**
* Checks if a channel type is routable via routeReply.
*
* Some channels (webchat, whatsapp) require special handling and
* cannot be routed through this generic interface.
* Some channels (webchat) require special handling and cannot be routed through
* this generic interface.
*/
export function isRoutableChannel(
channel: OriginatingChannelType | undefined,
): channel is "telegram" | "slack" | "discord" | "signal" | "imessage" {
): channel is
| "telegram"
| "slack"
| "discord"
| "signal"
| "imessage"
| "whatsapp" {
if (!channel) return false;
return ["telegram", "slack", "discord", "signal", "imessage"].includes(
channel,
);
return [
"telegram",
"slack",
"discord",
"signal",
"imessage",
"whatsapp",
].includes(channel);
}

View File

@@ -10,6 +10,7 @@ type TelegramSendOpts = {
verbose?: boolean;
mediaUrl?: string;
maxBytes?: number;
messageThreadId?: number;
api?: Bot["api"];
};
@@ -88,6 +89,10 @@ export async function sendMessageTelegram(
const bot = opts.api ? null : new Bot(token);
const api = opts.api ?? bot?.api;
const mediaUrl = opts.mediaUrl?.trim();
const threadParams =
typeof opts.messageThreadId === "number"
? { message_thread_id: Math.trunc(opts.messageThreadId) }
: undefined;
const sleep = (ms: number) =>
new Promise((resolve) => setTimeout(resolve, ms));
@@ -150,35 +155,35 @@ export async function sendMessageTelegram(
| Awaited<ReturnType<typeof api.sendDocument>>;
if (isGif) {
result = await sendWithRetry(
() => api.sendAnimation(chatId, file, { caption }),
() => api.sendAnimation(chatId, file, { caption, ...threadParams }),
"animation",
).catch((err) => {
throw wrapChatNotFound(err);
});
} else if (kind === "image") {
result = await sendWithRetry(
() => api.sendPhoto(chatId, file, { caption }),
() => api.sendPhoto(chatId, file, { caption, ...threadParams }),
"photo",
).catch((err) => {
throw wrapChatNotFound(err);
});
} else if (kind === "video") {
result = await sendWithRetry(
() => api.sendVideo(chatId, file, { caption }),
() => api.sendVideo(chatId, file, { caption, ...threadParams }),
"video",
).catch((err) => {
throw wrapChatNotFound(err);
});
} else if (kind === "audio") {
result = await sendWithRetry(
() => api.sendAudio(chatId, file, { caption }),
() => api.sendAudio(chatId, file, { caption, ...threadParams }),
"audio",
).catch((err) => {
throw wrapChatNotFound(err);
});
} else {
result = await sendWithRetry(
() => api.sendDocument(chatId, file, { caption }),
() => api.sendDocument(chatId, file, { caption, ...threadParams }),
"document",
).catch((err) => {
throw wrapChatNotFound(err);
@@ -192,7 +197,11 @@ export async function sendMessageTelegram(
throw new Error("Message must be non-empty for Telegram sends");
}
const res = await sendWithRetry(
() => api.sendMessage(chatId, text, { parse_mode: "Markdown" }),
() =>
api.sendMessage(chatId, text, {
parse_mode: "Markdown",
...threadParams,
}),
"message",
).catch(async (err) => {
// Telegram rejects malformed Markdown (e.g., unbalanced '_' or '*').
@@ -205,7 +214,10 @@ export async function sendMessageTelegram(
);
}
return await sendWithRetry(
() => api.sendMessage(chatId, text),
() =>
threadParams
? api.sendMessage(chatId, text, threadParams)
: api.sendMessage(chatId, text),
"message-plain",
).catch((err2) => {
throw wrapChatNotFound(err2);

View File

@@ -1252,6 +1252,9 @@ export async function monitorWebProvider(
WasMentioned: msg.wasMentioned,
...(msg.location ? toLocationContext(msg.location) : {}),
Provider: "whatsapp",
Surface: "whatsapp",
OriginatingChannel: "whatsapp",
OriginatingTo: msg.to,
},
cfg,
dispatcher,