fix: abort runs between tool calls
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
## Unreleased
|
## Unreleased
|
||||||
|
|
||||||
|
- Agent: fast abort on /stop and cancel tool calls between tool boundaries. (#617)
|
||||||
- Models/Auth: add OpenCode Zen (multi-model proxy) onboarding. (#623) — thanks @magimetal
|
- Models/Auth: add OpenCode Zen (multi-model proxy) onboarding. (#623) — thanks @magimetal
|
||||||
- WhatsApp: refactor vCard parsing helper and improve empty contact card summaries. (#624) — thanks @steipete
|
- WhatsApp: refactor vCard parsing helper and improve empty contact card summaries. (#624) — thanks @steipete
|
||||||
- WhatsApp: include phone numbers when multiple contacts are shared. (#625) — thanks @mahmoudashraf93
|
- WhatsApp: include phone numbers when multiple contacts are shared. (#625) — thanks @mahmoudashraf93
|
||||||
|
|||||||
@@ -853,6 +853,7 @@ export async function compactEmbeddedPiSession(params: {
|
|||||||
sessionKey: params.sessionKey ?? params.sessionId,
|
sessionKey: params.sessionKey ?? params.sessionId,
|
||||||
agentDir,
|
agentDir,
|
||||||
config: params.config,
|
config: params.config,
|
||||||
|
abortSignal: runAbortController.signal,
|
||||||
// No currentChannelId/currentThreadTs for compaction - not in message context
|
// No currentChannelId/currentThreadTs for compaction - not in message context
|
||||||
});
|
});
|
||||||
const machineName = await getMachineDisplayName();
|
const machineName = await getMachineDisplayName();
|
||||||
@@ -1045,6 +1046,7 @@ export async function runEmbeddedPiAgent(params: {
|
|||||||
const enqueueGlobal =
|
const enqueueGlobal =
|
||||||
params.enqueue ??
|
params.enqueue ??
|
||||||
((task, opts) => enqueueCommandInLane(globalLane, task, opts));
|
((task, opts) => enqueueCommandInLane(globalLane, task, opts));
|
||||||
|
const runAbortController = new AbortController();
|
||||||
return enqueueCommandInLane(sessionLane, () =>
|
return enqueueCommandInLane(sessionLane, () =>
|
||||||
enqueueGlobal(async () => {
|
enqueueGlobal(async () => {
|
||||||
const started = Date.now();
|
const started = Date.now();
|
||||||
@@ -1223,6 +1225,7 @@ export async function runEmbeddedPiAgent(params: {
|
|||||||
sessionKey: params.sessionKey ?? params.sessionId,
|
sessionKey: params.sessionKey ?? params.sessionId,
|
||||||
agentDir,
|
agentDir,
|
||||||
config: params.config,
|
config: params.config,
|
||||||
|
abortSignal: runAbortController.signal,
|
||||||
currentChannelId: params.currentChannelId,
|
currentChannelId: params.currentChannelId,
|
||||||
currentThreadTs: params.currentThreadTs,
|
currentThreadTs: params.currentThreadTs,
|
||||||
replyToMode: params.replyToMode,
|
replyToMode: params.replyToMode,
|
||||||
@@ -1326,6 +1329,7 @@ export async function runEmbeddedPiAgent(params: {
|
|||||||
const abortRun = (isTimeout = false) => {
|
const abortRun = (isTimeout = false) => {
|
||||||
aborted = true;
|
aborted = true;
|
||||||
if (isTimeout) timedOut = true;
|
if (isTimeout) timedOut = true;
|
||||||
|
runAbortController.abort();
|
||||||
void session.abort();
|
void session.abort();
|
||||||
};
|
};
|
||||||
let subscription: ReturnType<typeof subscribeEmbeddedPiSession>;
|
let subscription: ReturnType<typeof subscribeEmbeddedPiSession>;
|
||||||
|
|||||||
@@ -503,6 +503,48 @@ export const __testing = {
|
|||||||
cleanToolSchemaForGemini,
|
cleanToolSchemaForGemini,
|
||||||
} as const;
|
} as const;
|
||||||
|
|
||||||
|
function throwAbortError(): never {
|
||||||
|
const err = new Error("Aborted");
|
||||||
|
err.name = "AbortError";
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
|
|
||||||
|
function combineAbortSignals(
|
||||||
|
a?: AbortSignal,
|
||||||
|
b?: AbortSignal,
|
||||||
|
): AbortSignal | undefined {
|
||||||
|
if (!a && !b) return undefined;
|
||||||
|
if (a && !b) return a;
|
||||||
|
if (b && !a) return b;
|
||||||
|
if (a?.aborted) return a;
|
||||||
|
if (b?.aborted) return b;
|
||||||
|
if (typeof AbortSignal.any === "function") {
|
||||||
|
return AbortSignal.any([a as AbortSignal, b as AbortSignal]);
|
||||||
|
}
|
||||||
|
const controller = new AbortController();
|
||||||
|
const onAbort = () => controller.abort();
|
||||||
|
a?.addEventListener("abort", onAbort, { once: true });
|
||||||
|
b?.addEventListener("abort", onAbort, { once: true });
|
||||||
|
return controller.signal;
|
||||||
|
}
|
||||||
|
|
||||||
|
function wrapToolWithAbortSignal(
|
||||||
|
tool: AnyAgentTool,
|
||||||
|
abortSignal?: AbortSignal,
|
||||||
|
): AnyAgentTool {
|
||||||
|
if (!abortSignal) return tool;
|
||||||
|
const execute = tool.execute;
|
||||||
|
if (!execute) return tool;
|
||||||
|
return {
|
||||||
|
...tool,
|
||||||
|
execute: async (toolCallId, params, signal, onUpdate) => {
|
||||||
|
const combined = combineAbortSignals(signal, abortSignal);
|
||||||
|
if (combined?.aborted) throwAbortError();
|
||||||
|
return await execute(toolCallId, params, combined, onUpdate);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
export function createClawdbotCodingTools(options?: {
|
export function createClawdbotCodingTools(options?: {
|
||||||
bash?: BashToolDefaults & ProcessToolDefaults;
|
bash?: BashToolDefaults & ProcessToolDefaults;
|
||||||
messageProvider?: string;
|
messageProvider?: string;
|
||||||
@@ -511,6 +553,7 @@ export function createClawdbotCodingTools(options?: {
|
|||||||
sessionKey?: string;
|
sessionKey?: string;
|
||||||
agentDir?: string;
|
agentDir?: string;
|
||||||
config?: ClawdbotConfig;
|
config?: ClawdbotConfig;
|
||||||
|
abortSignal?: AbortSignal;
|
||||||
/** Current channel ID for auto-threading (Slack). */
|
/** Current channel ID for auto-threading (Slack). */
|
||||||
currentChannelId?: string;
|
currentChannelId?: string;
|
||||||
/** Current thread timestamp for auto-threading (Slack). */
|
/** Current thread timestamp for auto-threading (Slack). */
|
||||||
@@ -607,8 +650,11 @@ export function createClawdbotCodingTools(options?: {
|
|||||||
// Always normalize tool JSON Schemas before handing them to pi-agent/pi-ai.
|
// Always normalize tool JSON Schemas before handing them to pi-agent/pi-ai.
|
||||||
// Without this, some providers (notably OpenAI) will reject root-level union schemas.
|
// Without this, some providers (notably OpenAI) will reject root-level union schemas.
|
||||||
const normalized = subagentFiltered.map(normalizeToolParameters);
|
const normalized = subagentFiltered.map(normalizeToolParameters);
|
||||||
|
const withAbort = options?.abortSignal
|
||||||
|
? normalized.map((tool) => wrapToolWithAbortSignal(tool, options.abortSignal))
|
||||||
|
: normalized;
|
||||||
|
|
||||||
// Anthropic blocks specific lowercase tool names (bash, read, write, edit) with OAuth tokens.
|
// Anthropic blocks specific lowercase tool names (bash, read, write, edit) with OAuth tokens.
|
||||||
// Always use capitalized versions for compatibility with both OAuth and regular API keys.
|
// Always use capitalized versions for compatibility with both OAuth and regular API keys.
|
||||||
return renameBlockedToolsForOAuth(normalized);
|
return renameBlockedToolsForOAuth(withAbort);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,22 @@
|
|||||||
|
import { abortEmbeddedPiRun } from "../../agents/pi-embedded.js";
|
||||||
|
import type { ClawdbotConfig } from "../../config/config.js";
|
||||||
|
import {
|
||||||
|
loadSessionStore,
|
||||||
|
resolveStorePath,
|
||||||
|
saveSessionStore,
|
||||||
|
} from "../../config/sessions.js";
|
||||||
|
import {
|
||||||
|
parseAgentSessionKey,
|
||||||
|
resolveAgentIdFromSessionKey,
|
||||||
|
} from "../../routing/session-key.js";
|
||||||
|
import { resolveCommandAuthorization } from "../command-auth.js";
|
||||||
|
import {
|
||||||
|
normalizeCommandBody,
|
||||||
|
shouldHandleTextCommands,
|
||||||
|
} from "../commands-registry.js";
|
||||||
|
import type { MsgContext } from "../templating.js";
|
||||||
|
import { stripMentions, stripStructuralPrefixes } from "./mentions.js";
|
||||||
|
|
||||||
const ABORT_TRIGGERS = new Set(["stop", "esc", "abort", "wait", "exit"]);
|
const ABORT_TRIGGERS = new Set(["stop", "esc", "abort", "wait", "exit"]);
|
||||||
const ABORT_MEMORY = new Map<string, boolean>();
|
const ABORT_MEMORY = new Map<string, boolean>();
|
||||||
|
|
||||||
@@ -14,3 +33,82 @@ export function getAbortMemory(key: string): boolean | undefined {
|
|||||||
export function setAbortMemory(key: string, value: boolean): void {
|
export function setAbortMemory(key: string, value: boolean): void {
|
||||||
ABORT_MEMORY.set(key, value);
|
ABORT_MEMORY.set(key, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function resolveSessionEntryForKey(
|
||||||
|
store: Record<string, { sessionId: string; updatedAt: number }> | undefined,
|
||||||
|
sessionKey: string | undefined,
|
||||||
|
) {
|
||||||
|
if (!store || !sessionKey) return {};
|
||||||
|
const direct = store[sessionKey];
|
||||||
|
if (direct) return { entry: direct, key: sessionKey };
|
||||||
|
const parsed = parseAgentSessionKey(sessionKey);
|
||||||
|
const legacyKey = parsed?.rest;
|
||||||
|
if (legacyKey && store[legacyKey]) {
|
||||||
|
return { entry: store[legacyKey], key: legacyKey };
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveAbortTargetKey(ctx: MsgContext): string | undefined {
|
||||||
|
const target = ctx.CommandTargetSessionKey?.trim();
|
||||||
|
if (target) return target;
|
||||||
|
const sessionKey = ctx.SessionKey?.trim();
|
||||||
|
return sessionKey || undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function tryFastAbortFromMessage(params: {
|
||||||
|
ctx: MsgContext;
|
||||||
|
cfg: ClawdbotConfig;
|
||||||
|
}): Promise<{ handled: boolean; aborted: boolean }> {
|
||||||
|
const { ctx, cfg } = params;
|
||||||
|
const surface = (ctx.Surface ?? ctx.Provider ?? "").trim().toLowerCase();
|
||||||
|
const allowTextCommands = shouldHandleTextCommands({
|
||||||
|
cfg,
|
||||||
|
surface,
|
||||||
|
commandSource: ctx.CommandSource,
|
||||||
|
});
|
||||||
|
if (!allowTextCommands) return { handled: false, aborted: false };
|
||||||
|
|
||||||
|
const commandAuthorized = ctx.CommandAuthorized ?? true;
|
||||||
|
const auth = resolveCommandAuthorization({
|
||||||
|
ctx,
|
||||||
|
cfg,
|
||||||
|
commandAuthorized,
|
||||||
|
});
|
||||||
|
if (!auth.isAuthorizedSender) return { handled: false, aborted: false };
|
||||||
|
|
||||||
|
const targetKey = resolveAbortTargetKey(ctx);
|
||||||
|
const agentId = resolveAgentIdFromSessionKey(
|
||||||
|
targetKey ?? ctx.SessionKey ?? "",
|
||||||
|
);
|
||||||
|
const raw = stripStructuralPrefixes(ctx.Body ?? "");
|
||||||
|
const isGroup = ctx.ChatType?.trim().toLowerCase() === "group";
|
||||||
|
const stripped = isGroup ? stripMentions(raw, ctx, cfg, agentId) : raw;
|
||||||
|
const normalized = normalizeCommandBody(stripped);
|
||||||
|
const abortRequested = normalized === "/stop" || isAbortTrigger(stripped);
|
||||||
|
if (!abortRequested) return { handled: false, aborted: false };
|
||||||
|
|
||||||
|
const abortKey = targetKey ?? auth.from ?? auth.to;
|
||||||
|
|
||||||
|
if (targetKey) {
|
||||||
|
const storePath = resolveStorePath(cfg.session?.store, { agentId });
|
||||||
|
const store = loadSessionStore(storePath);
|
||||||
|
const { entry, key } = resolveSessionEntryForKey(store, targetKey);
|
||||||
|
const sessionId = entry?.sessionId;
|
||||||
|
const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false;
|
||||||
|
if (entry && key) {
|
||||||
|
entry.abortedLastRun = true;
|
||||||
|
entry.updatedAt = Date.now();
|
||||||
|
store[key] = entry;
|
||||||
|
await saveSessionStore(storePath, store);
|
||||||
|
} else if (abortKey) {
|
||||||
|
setAbortMemory(abortKey, true);
|
||||||
|
}
|
||||||
|
return { handled: true, aborted };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (abortKey) {
|
||||||
|
setAbortMemory(abortKey, true);
|
||||||
|
}
|
||||||
|
return { handled: true, aborted: false };
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import type { ReplyDispatcher } from "./reply-dispatcher.js";
|
|||||||
|
|
||||||
const mocks = vi.hoisted(() => ({
|
const mocks = vi.hoisted(() => ({
|
||||||
routeReply: vi.fn(async () => ({ ok: true, messageId: "mock" })),
|
routeReply: vi.fn(async () => ({ ok: true, messageId: "mock" })),
|
||||||
|
tryFastAbortFromMessage: vi.fn(async () => ({ handled: false, aborted: false })),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
vi.mock("./route-reply.js", () => ({
|
vi.mock("./route-reply.js", () => ({
|
||||||
@@ -25,6 +26,10 @@ vi.mock("./route-reply.js", () => ({
|
|||||||
routeReply: mocks.routeReply,
|
routeReply: mocks.routeReply,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
vi.mock("./abort.js", () => ({
|
||||||
|
tryFastAbortFromMessage: mocks.tryFastAbortFromMessage,
|
||||||
|
}));
|
||||||
|
|
||||||
const { dispatchReplyFromConfig } = await import("./dispatch-from-config.js");
|
const { dispatchReplyFromConfig } = await import("./dispatch-from-config.js");
|
||||||
|
|
||||||
function createDispatcher(): ReplyDispatcher {
|
function createDispatcher(): ReplyDispatcher {
|
||||||
@@ -39,6 +44,10 @@ function createDispatcher(): ReplyDispatcher {
|
|||||||
|
|
||||||
describe("dispatchReplyFromConfig", () => {
|
describe("dispatchReplyFromConfig", () => {
|
||||||
it("does not route when Provider matches OriginatingChannel (even if Surface is missing)", async () => {
|
it("does not route when Provider matches OriginatingChannel (even if Surface is missing)", async () => {
|
||||||
|
mocks.tryFastAbortFromMessage.mockResolvedValue({
|
||||||
|
handled: false,
|
||||||
|
aborted: false,
|
||||||
|
});
|
||||||
mocks.routeReply.mockClear();
|
mocks.routeReply.mockClear();
|
||||||
const cfg = {} as ClawdbotConfig;
|
const cfg = {} as ClawdbotConfig;
|
||||||
const dispatcher = createDispatcher();
|
const dispatcher = createDispatcher();
|
||||||
@@ -60,6 +69,10 @@ describe("dispatchReplyFromConfig", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it("routes when OriginatingChannel differs from Provider", async () => {
|
it("routes when OriginatingChannel differs from Provider", async () => {
|
||||||
|
mocks.tryFastAbortFromMessage.mockResolvedValue({
|
||||||
|
handled: false,
|
||||||
|
aborted: false,
|
||||||
|
});
|
||||||
mocks.routeReply.mockClear();
|
mocks.routeReply.mockClear();
|
||||||
const cfg = {} as ClawdbotConfig;
|
const cfg = {} as ClawdbotConfig;
|
||||||
const dispatcher = createDispatcher();
|
const dispatcher = createDispatcher();
|
||||||
@@ -88,4 +101,25 @@ describe("dispatchReplyFromConfig", () => {
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("fast-aborts without calling the reply resolver", async () => {
|
||||||
|
mocks.tryFastAbortFromMessage.mockResolvedValue({
|
||||||
|
handled: true,
|
||||||
|
aborted: true,
|
||||||
|
});
|
||||||
|
const cfg = {} as ClawdbotConfig;
|
||||||
|
const dispatcher = createDispatcher();
|
||||||
|
const ctx: MsgContext = {
|
||||||
|
Provider: "telegram",
|
||||||
|
Body: "/stop",
|
||||||
|
};
|
||||||
|
const replyResolver = vi.fn(async () => ({ text: "hi" }) as ReplyPayload);
|
||||||
|
|
||||||
|
await dispatchReplyFromConfig({ ctx, cfg, dispatcher, replyResolver });
|
||||||
|
|
||||||
|
expect(replyResolver).not.toHaveBeenCalled();
|
||||||
|
expect(dispatcher.sendFinalReply).toHaveBeenCalledWith({
|
||||||
|
text: "⚙️ Agent was aborted.",
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import { logVerbose } from "../../globals.js";
|
|||||||
import { getReplyFromConfig } from "../reply.js";
|
import { getReplyFromConfig } from "../reply.js";
|
||||||
import type { MsgContext } from "../templating.js";
|
import type { MsgContext } from "../templating.js";
|
||||||
import type { GetReplyOptions, ReplyPayload } from "../types.js";
|
import type { GetReplyOptions, ReplyPayload } from "../types.js";
|
||||||
|
import { tryFastAbortFromMessage } from "./abort.js";
|
||||||
import type { ReplyDispatcher, ReplyDispatchKind } from "./reply-dispatcher.js";
|
import type { ReplyDispatcher, ReplyDispatchKind } from "./reply-dispatcher.js";
|
||||||
import { isRoutableChannel, routeReply } from "./route-reply.js";
|
import { isRoutableChannel, routeReply } from "./route-reply.js";
|
||||||
|
|
||||||
@@ -66,6 +67,37 @@ export async function dispatchReplyFromConfig(params: {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const fastAbort = await tryFastAbortFromMessage({ ctx, cfg });
|
||||||
|
if (fastAbort.handled) {
|
||||||
|
const payload = { text: "⚙️ Agent was aborted." } satisfies ReplyPayload;
|
||||||
|
let queuedFinal = false;
|
||||||
|
let routedFinalCount = 0;
|
||||||
|
if (shouldRouteToOriginating && originatingChannel && originatingTo) {
|
||||||
|
const result = await routeReply({
|
||||||
|
payload,
|
||||||
|
channel: originatingChannel,
|
||||||
|
to: originatingTo,
|
||||||
|
sessionKey: ctx.SessionKey,
|
||||||
|
accountId: ctx.AccountId,
|
||||||
|
threadId: ctx.MessageThreadId,
|
||||||
|
cfg,
|
||||||
|
});
|
||||||
|
queuedFinal = result.ok;
|
||||||
|
if (result.ok) routedFinalCount += 1;
|
||||||
|
if (!result.ok) {
|
||||||
|
logVerbose(
|
||||||
|
`dispatch-from-config: route-reply (abort) failed: ${result.error ?? "unknown error"}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
queuedFinal = dispatcher.sendFinalReply(payload);
|
||||||
|
}
|
||||||
|
await dispatcher.waitForIdle();
|
||||||
|
const counts = dispatcher.getQueuedCounts();
|
||||||
|
counts.final += routedFinalCount;
|
||||||
|
return { queuedFinal, counts };
|
||||||
|
}
|
||||||
|
|
||||||
const replyResult = await (params.replyResolver ?? getReplyFromConfig)(
|
const replyResult = await (params.replyResolver ?? getReplyFromConfig)(
|
||||||
ctx,
|
ctx,
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user