feat(chat): Swift chat parity (abort/sessions/stream)
This commit is contained in:
@@ -3,6 +3,8 @@ import {
|
||||
type AgentEvent,
|
||||
AgentEventSchema,
|
||||
AgentParamsSchema,
|
||||
type ChatAbortParams,
|
||||
ChatAbortParamsSchema,
|
||||
type ChatEvent,
|
||||
ChatEventSchema,
|
||||
ChatHistoryParamsSchema,
|
||||
@@ -137,6 +139,9 @@ export const validateCronRunsParams =
|
||||
ajv.compile<CronRunsParams>(CronRunsParamsSchema);
|
||||
export const validateChatHistoryParams = ajv.compile(ChatHistoryParamsSchema);
|
||||
export const validateChatSendParams = ajv.compile(ChatSendParamsSchema);
|
||||
export const validateChatAbortParams = ajv.compile<ChatAbortParams>(
|
||||
ChatAbortParamsSchema,
|
||||
);
|
||||
export const validateChatEvent = ajv.compile(ChatEventSchema);
|
||||
|
||||
export function formatValidationErrors(
|
||||
|
||||
@@ -480,6 +480,14 @@ export const ChatSendParamsSchema = Type.Object(
|
||||
{ additionalProperties: false },
|
||||
);
|
||||
|
||||
export const ChatAbortParamsSchema = Type.Object(
|
||||
{
|
||||
sessionKey: NonEmptyString,
|
||||
runId: NonEmptyString,
|
||||
},
|
||||
{ additionalProperties: false },
|
||||
);
|
||||
|
||||
export const ChatEventSchema = Type.Object(
|
||||
{
|
||||
runId: NonEmptyString,
|
||||
@@ -488,6 +496,7 @@ export const ChatEventSchema = Type.Object(
|
||||
state: Type.Union([
|
||||
Type.Literal("delta"),
|
||||
Type.Literal("final"),
|
||||
Type.Literal("aborted"),
|
||||
Type.Literal("error"),
|
||||
]),
|
||||
message: Type.Optional(Type.Unknown()),
|
||||
@@ -533,6 +542,7 @@ export const ProtocolSchemas: Record<string, TSchema> = {
|
||||
CronRunLogEntry: CronRunLogEntrySchema,
|
||||
ChatHistoryParams: ChatHistoryParamsSchema,
|
||||
ChatSendParams: ChatSendParamsSchema,
|
||||
ChatAbortParams: ChatAbortParamsSchema,
|
||||
ChatEvent: ChatEventSchema,
|
||||
TickEvent: TickEventSchema,
|
||||
ShutdownEvent: ShutdownEventSchema,
|
||||
@@ -570,6 +580,7 @@ export type CronRemoveParams = Static<typeof CronRemoveParamsSchema>;
|
||||
export type CronRunParams = Static<typeof CronRunParamsSchema>;
|
||||
export type CronRunsParams = Static<typeof CronRunsParamsSchema>;
|
||||
export type CronRunLogEntry = Static<typeof CronRunLogEntrySchema>;
|
||||
export type ChatAbortParams = Static<typeof ChatAbortParamsSchema>;
|
||||
export type ChatEvent = Static<typeof ChatEventSchema>;
|
||||
export type TickEvent = Static<typeof TickEventSchema>;
|
||||
export type ShutdownEvent = Static<typeof ShutdownEventSchema>;
|
||||
|
||||
@@ -1970,6 +1970,87 @@ describe("gateway server", () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
test("chat.abort cancels an in-flight chat.send", async () => {
|
||||
const dir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdis-gw-"));
|
||||
testSessionStorePath = path.join(dir, "sessions.json");
|
||||
await fs.writeFile(
|
||||
testSessionStorePath,
|
||||
JSON.stringify(
|
||||
{
|
||||
main: {
|
||||
sessionId: "sess-main",
|
||||
updatedAt: Date.now(),
|
||||
},
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
"utf-8",
|
||||
);
|
||||
|
||||
const { server, ws } = await startServerWithClient();
|
||||
await connectOk(ws);
|
||||
|
||||
const spy = vi.mocked(agentCommand);
|
||||
spy.mockImplementationOnce(async (opts) => {
|
||||
const signal = (opts as { abortSignal?: AbortSignal }).abortSignal;
|
||||
await new Promise<void>((resolve) => {
|
||||
if (!signal) return resolve();
|
||||
if (signal.aborted) return resolve();
|
||||
signal.addEventListener("abort", () => resolve(), { once: true });
|
||||
});
|
||||
});
|
||||
|
||||
const abortedEventP = onceMessage(
|
||||
ws,
|
||||
(o) => o.type === "event" && o.event === "chat" && o.payload?.state === "aborted",
|
||||
);
|
||||
|
||||
ws.send(
|
||||
JSON.stringify({
|
||||
type: "req",
|
||||
id: "send-abort-1",
|
||||
method: "chat.send",
|
||||
params: {
|
||||
sessionKey: "main",
|
||||
message: "hello",
|
||||
idempotencyKey: "idem-abort-1",
|
||||
timeoutMs: 30_000,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await new Promise((r) => setTimeout(r, 10));
|
||||
|
||||
ws.send(
|
||||
JSON.stringify({
|
||||
type: "req",
|
||||
id: "abort-1",
|
||||
method: "chat.abort",
|
||||
params: { sessionKey: "main", runId: "idem-abort-1" },
|
||||
}),
|
||||
);
|
||||
|
||||
const abortRes = await onceMessage(
|
||||
ws,
|
||||
(o) => o.type === "res" && o.id === "abort-1",
|
||||
);
|
||||
expect(abortRes.ok).toBe(true);
|
||||
|
||||
const sendRes = await onceMessage(
|
||||
ws,
|
||||
(o) => o.type === "res" && o.id === "send-abort-1",
|
||||
);
|
||||
expect(sendRes.ok).toBe(true);
|
||||
|
||||
const evt = await abortedEventP;
|
||||
expect(evt.payload?.runId).toBe("idem-abort-1");
|
||||
expect(evt.payload?.sessionKey).toBe("main");
|
||||
|
||||
ws.close();
|
||||
await server.close();
|
||||
});
|
||||
|
||||
test("bridge RPC chat.history returns session messages", async () => {
|
||||
const dir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdis-gw-"));
|
||||
testSessionStorePath = path.join(dir, "sessions.json");
|
||||
@@ -2029,6 +2110,54 @@ describe("gateway server", () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
test("bridge RPC sessions.list returns session rows", async () => {
|
||||
const dir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdis-gw-"));
|
||||
testSessionStorePath = path.join(dir, "sessions.json");
|
||||
await fs.writeFile(
|
||||
testSessionStorePath,
|
||||
JSON.stringify(
|
||||
{
|
||||
main: {
|
||||
sessionId: "sess-main",
|
||||
updatedAt: Date.now(),
|
||||
},
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
"utf-8",
|
||||
);
|
||||
|
||||
const port = await getFreePort();
|
||||
const server = await startGatewayServer(port);
|
||||
const bridgeCall = bridgeStartCalls.at(-1);
|
||||
expect(bridgeCall?.onRequest).toBeDefined();
|
||||
|
||||
const res = await bridgeCall?.onRequest?.("ios-node", {
|
||||
id: "r1",
|
||||
method: "sessions.list",
|
||||
paramsJSON: JSON.stringify({
|
||||
includeGlobal: true,
|
||||
includeUnknown: false,
|
||||
limit: 50,
|
||||
}),
|
||||
});
|
||||
|
||||
expect(res?.ok).toBe(true);
|
||||
const payload = JSON.parse(
|
||||
String((res as { payloadJSON?: string }).payloadJSON ?? "{}"),
|
||||
) as {
|
||||
sessions?: unknown[];
|
||||
count?: number;
|
||||
path?: string;
|
||||
};
|
||||
expect(Array.isArray(payload.sessions)).toBe(true);
|
||||
expect(typeof payload.count).toBe("number");
|
||||
expect(typeof payload.path).toBe("string");
|
||||
|
||||
await server.close();
|
||||
});
|
||||
|
||||
test("bridge chat events are pushed to subscribed nodes", async () => {
|
||||
const dir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdis-gw-"));
|
||||
testSessionStorePath = path.join(dir, "sessions.json");
|
||||
@@ -2092,6 +2221,13 @@ describe("gateway server", () => {
|
||||
// Wait a tick for the bridge send to happen.
|
||||
await new Promise((r) => setTimeout(r, 25));
|
||||
|
||||
expect(bridgeSendEvent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
nodeId: "ios-node",
|
||||
event: "agent",
|
||||
}),
|
||||
);
|
||||
|
||||
expect(bridgeSendEvent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
nodeId: "ios-node",
|
||||
|
||||
@@ -103,6 +103,7 @@ import {
|
||||
type SessionsPatchParams,
|
||||
type Snapshot,
|
||||
validateAgentParams,
|
||||
validateChatAbortParams,
|
||||
validateChatHistoryParams,
|
||||
validateChatSendParams,
|
||||
validateConnectParams,
|
||||
@@ -208,6 +209,7 @@ const METHODS = [
|
||||
"agent",
|
||||
// WebChat WebSocket-native chat methods
|
||||
"chat.history",
|
||||
"chat.abort",
|
||||
"chat.send",
|
||||
];
|
||||
|
||||
@@ -282,7 +284,11 @@ const chatRunSessions = new Map<
|
||||
string,
|
||||
{ sessionKey: string; clientRunId: string }
|
||||
>();
|
||||
const chatRunBuffers = new Map<string, string[]>();
|
||||
const chatRunBuffers = new Map<string, string>();
|
||||
const chatAbortControllers = new Map<
|
||||
string,
|
||||
{ controller: AbortController; sessionId: string; sessionKey: string }
|
||||
>();
|
||||
|
||||
const getGatewayToken = () => process.env.CLAWDIS_GATEWAY_TOKEN;
|
||||
|
||||
@@ -903,6 +909,120 @@ export async function startGatewayServer(
|
||||
const snap = await refreshHealthSnapshot({ probe: false });
|
||||
return { ok: true, payloadJSON: JSON.stringify(snap) };
|
||||
}
|
||||
case "sessions.list": {
|
||||
const params = parseParams();
|
||||
if (!validateSessionsListParams(params)) {
|
||||
return {
|
||||
ok: false,
|
||||
error: {
|
||||
code: ErrorCodes.INVALID_REQUEST,
|
||||
message: `invalid sessions.list params: ${formatValidationErrors(validateSessionsListParams.errors)}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
const p = params as SessionsListParams;
|
||||
const cfg = loadConfig();
|
||||
const storePath = resolveStorePath(cfg.inbound?.session?.store);
|
||||
const store = loadSessionStore(storePath);
|
||||
const result = listSessionsFromStore({
|
||||
cfg,
|
||||
storePath,
|
||||
store,
|
||||
opts: p,
|
||||
});
|
||||
return { ok: true, payloadJSON: JSON.stringify(result) };
|
||||
}
|
||||
case "sessions.patch": {
|
||||
const params = parseParams();
|
||||
if (!validateSessionsPatchParams(params)) {
|
||||
return {
|
||||
ok: false,
|
||||
error: {
|
||||
code: ErrorCodes.INVALID_REQUEST,
|
||||
message: `invalid sessions.patch params: ${formatValidationErrors(validateSessionsPatchParams.errors)}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const p = params as SessionsPatchParams;
|
||||
const key = String(p.key ?? "").trim();
|
||||
if (!key) {
|
||||
return {
|
||||
ok: false,
|
||||
error: { code: ErrorCodes.INVALID_REQUEST, message: "key required" },
|
||||
};
|
||||
}
|
||||
|
||||
const cfg = loadConfig();
|
||||
const storePath = resolveStorePath(cfg.inbound?.session?.store);
|
||||
const store = loadSessionStore(storePath);
|
||||
const now = Date.now();
|
||||
|
||||
const existing = store[key];
|
||||
const next: SessionEntry = existing
|
||||
? {
|
||||
...existing,
|
||||
updatedAt: Math.max(existing.updatedAt ?? 0, now),
|
||||
}
|
||||
: { sessionId: randomUUID(), updatedAt: now };
|
||||
|
||||
if ("thinkingLevel" in p) {
|
||||
const raw = p.thinkingLevel;
|
||||
if (raw === null) {
|
||||
delete next.thinkingLevel;
|
||||
} else if (raw !== undefined) {
|
||||
const normalized = normalizeThinkLevel(String(raw));
|
||||
if (!normalized) {
|
||||
return {
|
||||
ok: false,
|
||||
error: {
|
||||
code: ErrorCodes.INVALID_REQUEST,
|
||||
message: `invalid thinkingLevel: ${String(raw)}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
next.thinkingLevel = normalized;
|
||||
}
|
||||
}
|
||||
|
||||
if ("verboseLevel" in p) {
|
||||
const raw = p.verboseLevel;
|
||||
if (raw === null) {
|
||||
delete next.verboseLevel;
|
||||
} else if (raw !== undefined) {
|
||||
const normalized = normalizeVerboseLevel(String(raw));
|
||||
if (!normalized) {
|
||||
return {
|
||||
ok: false,
|
||||
error: {
|
||||
code: ErrorCodes.INVALID_REQUEST,
|
||||
message: `invalid verboseLevel: ${String(raw)}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
next.verboseLevel = normalized;
|
||||
}
|
||||
}
|
||||
|
||||
if ("syncing" in p) {
|
||||
const raw = p.syncing;
|
||||
if (raw === null) {
|
||||
delete next.syncing;
|
||||
} else if (raw !== undefined) {
|
||||
next.syncing = raw as boolean | string;
|
||||
}
|
||||
}
|
||||
|
||||
store[key] = next;
|
||||
await saveSessionStore(storePath, store);
|
||||
const payload: SessionsPatchResult = {
|
||||
ok: true,
|
||||
path: storePath,
|
||||
key,
|
||||
entry: next,
|
||||
};
|
||||
return { ok: true, payloadJSON: JSON.stringify(payload) };
|
||||
}
|
||||
case "chat.history": {
|
||||
const params = parseParams();
|
||||
if (!validateChatHistoryParams(params)) {
|
||||
@@ -945,6 +1065,60 @@ export async function startGatewayServer(
|
||||
}),
|
||||
};
|
||||
}
|
||||
case "chat.abort": {
|
||||
const params = parseParams();
|
||||
if (!validateChatAbortParams(params)) {
|
||||
return {
|
||||
ok: false,
|
||||
error: {
|
||||
code: ErrorCodes.INVALID_REQUEST,
|
||||
message: `invalid chat.abort params: ${formatValidationErrors(validateChatAbortParams.errors)}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const { sessionKey, runId } = params as {
|
||||
sessionKey: string;
|
||||
runId: string;
|
||||
};
|
||||
const active = chatAbortControllers.get(runId);
|
||||
if (!active) {
|
||||
return {
|
||||
ok: true,
|
||||
payloadJSON: JSON.stringify({ ok: true, aborted: false }),
|
||||
};
|
||||
}
|
||||
if (active.sessionKey !== sessionKey) {
|
||||
return {
|
||||
ok: false,
|
||||
error: {
|
||||
code: ErrorCodes.INVALID_REQUEST,
|
||||
message: "runId does not match sessionKey",
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
active.controller.abort();
|
||||
chatAbortControllers.delete(runId);
|
||||
chatRunBuffers.delete(runId);
|
||||
const current = chatRunSessions.get(active.sessionId);
|
||||
if (current?.clientRunId === runId && current.sessionKey === sessionKey) {
|
||||
chatRunSessions.delete(active.sessionId);
|
||||
}
|
||||
|
||||
const payload = {
|
||||
runId,
|
||||
sessionKey,
|
||||
seq: (agentRunSeq.get(active.sessionId) ?? 0) + 1,
|
||||
state: "aborted" as const,
|
||||
};
|
||||
broadcast("chat", payload);
|
||||
bridgeSendToSession(sessionKey, "chat", payload);
|
||||
return {
|
||||
ok: true,
|
||||
payloadJSON: JSON.stringify({ ok: true, aborted: true }),
|
||||
};
|
||||
}
|
||||
case "chat.send": {
|
||||
const params = parseParams();
|
||||
if (!validateChatSendParams(params)) {
|
||||
@@ -1052,6 +1226,13 @@ export async function startGatewayServer(
|
||||
};
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
chatAbortControllers.set(clientRunId, {
|
||||
controller: abortController,
|
||||
sessionId,
|
||||
sessionKey: p.sessionKey,
|
||||
});
|
||||
|
||||
try {
|
||||
await agentCommand(
|
||||
{
|
||||
@@ -1061,6 +1242,7 @@ export async function startGatewayServer(
|
||||
deliver: p.deliver,
|
||||
timeout: Math.ceil(timeoutMs / 1000).toString(),
|
||||
surface: `Iris(${nodeId})`,
|
||||
abortSignal: abortController.signal,
|
||||
},
|
||||
defaultRuntime,
|
||||
deps,
|
||||
@@ -1095,6 +1277,8 @@ export async function startGatewayServer(
|
||||
message: String(err),
|
||||
},
|
||||
};
|
||||
} finally {
|
||||
chatAbortControllers.delete(clientRunId);
|
||||
}
|
||||
}
|
||||
default:
|
||||
@@ -1504,26 +1688,25 @@ export async function startGatewayServer(
|
||||
agentRunSeq.set(evt.runId, evt.seq);
|
||||
broadcast("agent", evt);
|
||||
|
||||
const chatLink = chatRunSessions.get(evt.runId);
|
||||
if (chatLink) {
|
||||
// Map agent bus events to chat events for WS WebChat clients.
|
||||
// Use clientRunId so the webchat can correlate with its pending promise.
|
||||
const { sessionKey, clientRunId } = chatLink;
|
||||
const chatLink = chatRunSessions.get(evt.runId);
|
||||
if (chatLink) {
|
||||
// Map agent bus events to chat events for WS WebChat clients.
|
||||
// Use clientRunId so the webchat can correlate with its pending promise.
|
||||
const { sessionKey, clientRunId } = chatLink;
|
||||
bridgeSendToSession(sessionKey, "agent", evt);
|
||||
const base = {
|
||||
runId: clientRunId,
|
||||
sessionKey,
|
||||
seq: evt.seq,
|
||||
};
|
||||
if (evt.stream === "assistant" && typeof evt.data?.text === "string") {
|
||||
const buf = chatRunBuffers.get(clientRunId) ?? [];
|
||||
buf.push(evt.data.text);
|
||||
chatRunBuffers.set(clientRunId, buf);
|
||||
chatRunBuffers.set(clientRunId, evt.data.text);
|
||||
} else if (
|
||||
evt.stream === "job" &&
|
||||
typeof evt.data?.state === "string" &&
|
||||
(evt.data.state === "done" || evt.data.state === "error")
|
||||
) {
|
||||
const text = chatRunBuffers.get(clientRunId)?.join("\n").trim() ?? "";
|
||||
const text = chatRunBuffers.get(clientRunId)?.trim() ?? "";
|
||||
chatRunBuffers.delete(clientRunId);
|
||||
if (evt.data.state === "done") {
|
||||
const payload = {
|
||||
@@ -1962,6 +2145,62 @@ export async function startGatewayServer(
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "chat.abort": {
|
||||
const params = (req.params ?? {}) as Record<string, unknown>;
|
||||
if (!validateChatAbortParams(params)) {
|
||||
respond(
|
||||
false,
|
||||
undefined,
|
||||
errorShape(
|
||||
ErrorCodes.INVALID_REQUEST,
|
||||
`invalid chat.abort params: ${formatValidationErrors(validateChatAbortParams.errors)}`,
|
||||
),
|
||||
);
|
||||
break;
|
||||
}
|
||||
const { sessionKey, runId } = params as {
|
||||
sessionKey: string;
|
||||
runId: string;
|
||||
};
|
||||
const active = chatAbortControllers.get(runId);
|
||||
if (!active) {
|
||||
respond(true, { ok: true, aborted: false });
|
||||
break;
|
||||
}
|
||||
if (active.sessionKey !== sessionKey) {
|
||||
respond(
|
||||
false,
|
||||
undefined,
|
||||
errorShape(
|
||||
ErrorCodes.INVALID_REQUEST,
|
||||
"runId does not match sessionKey",
|
||||
),
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
active.controller.abort();
|
||||
chatAbortControllers.delete(runId);
|
||||
chatRunBuffers.delete(runId);
|
||||
const current = chatRunSessions.get(active.sessionId);
|
||||
if (
|
||||
current?.clientRunId === runId &&
|
||||
current.sessionKey === sessionKey
|
||||
) {
|
||||
chatRunSessions.delete(active.sessionId);
|
||||
}
|
||||
|
||||
const payload = {
|
||||
runId,
|
||||
sessionKey,
|
||||
seq: (agentRunSeq.get(active.sessionId) ?? 0) + 1,
|
||||
state: "aborted" as const,
|
||||
};
|
||||
broadcast("chat", payload);
|
||||
bridgeSendToSession(sessionKey, "chat", payload);
|
||||
respond(true, { ok: true, aborted: true });
|
||||
break;
|
||||
}
|
||||
case "chat.send": {
|
||||
const params = (req.params ?? {}) as Record<string, unknown>;
|
||||
if (!validateChatSendParams(params)) {
|
||||
@@ -2061,6 +2300,13 @@ export async function startGatewayServer(
|
||||
}
|
||||
|
||||
try {
|
||||
const abortController = new AbortController();
|
||||
chatAbortControllers.set(clientRunId, {
|
||||
controller: abortController,
|
||||
sessionId,
|
||||
sessionKey: p.sessionKey,
|
||||
});
|
||||
|
||||
await agentCommand(
|
||||
{
|
||||
message: messageWithAttachments,
|
||||
@@ -2069,6 +2315,7 @@ export async function startGatewayServer(
|
||||
deliver: p.deliver,
|
||||
timeout: Math.ceil(timeoutMs / 1000).toString(),
|
||||
surface: "WebChat",
|
||||
abortSignal: abortController.signal,
|
||||
},
|
||||
defaultRuntime,
|
||||
deps,
|
||||
@@ -2100,6 +2347,8 @@ export async function startGatewayServer(
|
||||
runId: clientRunId,
|
||||
error: formatForLog(err),
|
||||
});
|
||||
} finally {
|
||||
chatAbortControllers.delete(clientRunId);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user