diff --git a/src/gateway/server.test.ts b/src/gateway/server.test.ts index 56298adc8..d9e64ecd0 100644 --- a/src/gateway/server.test.ts +++ b/src/gateway/server.test.ts @@ -10,7 +10,10 @@ import { emitAgentEvent } from "../infra/agent-events.js"; import { GatewayLockError } from "../infra/gateway-lock.js"; import { emitHeartbeatEvent } from "../infra/heartbeat-events.js"; import { PROTOCOL_VERSION } from "./protocol/index.js"; -import { startGatewayServer } from "./server.js"; +import { + __resetModelCatalogCacheForTest, + startGatewayServer, +} from "./server.js"; type BridgeClientInfo = { nodeId: string; @@ -58,6 +61,36 @@ const bridgeSendEvent = vi.hoisted(() => vi.fn()); const testTailnetIPv4 = vi.hoisted(() => ({ value: undefined as string | undefined, })); + +const piAiMock = vi.hoisted(() => ({ + enabled: false, + getModelsCalls: [] as string[], + providers: ["openai", "anthropic"], + modelsByProvider: {} as Record< + string, + Array<{ id: string; name?: string; contextWindow?: number }> + >, +})); + +vi.mock("@mariozechner/pi-ai", async () => { + const actual = await vi.importActual<{ + getProviders: () => string[]; + getModels: ( + provider: string, + ) => Array<{ id: string; name?: string; contextWindow?: number }>; + }>("@mariozechner/pi-ai"); + + return { + ...actual, + getProviders: () => + piAiMock.enabled ? piAiMock.providers : actual.getProviders(), + getModels: (provider: string) => { + if (!piAiMock.enabled) return actual.getModels(provider); + piAiMock.getModelsCalls.push(provider); + return piAiMock.modelsByProvider[provider] ?? []; + }, + }; +}); vi.mock("../infra/bridge/server.js", () => ({ startNodeBridgeServer: vi.fn(async (opts: BridgeStartOpts) => { bridgeStartCalls.push(opts); @@ -141,6 +174,11 @@ beforeEach(async () => { sessionStoreSaveDelayMs.value = 0; testTailnetIPv4.value = undefined; testGatewayBind = undefined; + __resetModelCatalogCacheForTest(); + piAiMock.enabled = false; + piAiMock.getModelsCalls.length = 0; + piAiMock.providers = ["openai", "anthropic"]; + piAiMock.modelsByProvider = { openai: [], anthropic: [] }; }); afterEach(async () => { @@ -346,6 +384,129 @@ describe("gateway server", () => { } }); + test("models.list returns model catalog", async () => { + piAiMock.enabled = true; + piAiMock.providers = ["openai", "anthropic"]; + piAiMock.modelsByProvider = { + openai: [ + { id: "gpt-test-z", contextWindow: 0 }, + { id: "gpt-test-a", name: "A-Model", contextWindow: 8000 }, + ], + anthropic: [ + { id: "claude-test-b", name: "B-Model", contextWindow: 1000 }, + { id: "claude-test-a", name: "A-Model", contextWindow: 200_000 }, + ], + }; + + const { server, ws } = await startServerWithClient(); + await connectOk(ws); + + const res1 = await rpcReq<{ + models: Array<{ + id: string; + name: string; + provider: string; + contextWindow?: number; + }>; + }>(ws, "models.list"); + + const res2 = await rpcReq<{ + models: Array<{ + id: string; + name: string; + provider: string; + contextWindow?: number; + }>; + }>(ws, "models.list"); + + expect(res1.ok).toBe(true); + expect(res2.ok).toBe(true); + + const models = res1.payload?.models ?? []; + expect(models).toEqual([ + { + id: "claude-test-a", + name: "A-Model", + provider: "anthropic", + contextWindow: 200_000, + }, + { + id: "claude-test-b", + name: "B-Model", + provider: "anthropic", + contextWindow: 1000, + }, + { + id: "gpt-test-a", + name: "A-Model", + provider: "openai", + contextWindow: 8000, + }, + { + id: "gpt-test-z", + name: "gpt-test-z", + provider: "openai", + }, + ]); + + // Cached across requests: should only call getModels once per provider. + expect(piAiMock.getModelsCalls).toEqual(["openai", "anthropic"]); + + ws.close(); + await server.close(); + }); + + test("models.list rejects unknown params", async () => { + piAiMock.providers = ["openai"]; + piAiMock.modelsByProvider = { openai: [{ id: "gpt-test-a", name: "A" }] }; + + const { server, ws } = await startServerWithClient(); + await connectOk(ws); + + const res = await rpcReq(ws, "models.list", { extra: true }); + expect(res.ok).toBe(false); + expect(res.error?.message ?? "").toMatch(/invalid models\.list params/i); + + ws.close(); + await server.close(); + }); + + test("bridge RPC supports models.list and validates params", async () => { + piAiMock.enabled = true; + piAiMock.providers = ["openai"]; + piAiMock.modelsByProvider = { openai: [{ id: "gpt-test-a", name: "A" }] }; + + const { server, ws } = await startServerWithClient(); + await connectOk(ws); + + const startCall = bridgeStartCalls.at(-1); + expect(startCall).toBeTruthy(); + + const okRes = await startCall?.onRequest?.("n1", { + id: "1", + method: "models.list", + paramsJSON: "{}", + }); + expect(okRes?.ok).toBe(true); + const okPayload = JSON.parse(String(okRes?.payloadJSON ?? "{}")) as { + models?: unknown; + }; + expect(Array.isArray(okPayload.models)).toBe(true); + + const badRes = await startCall?.onRequest?.("n1", { + id: "2", + method: "models.list", + paramsJSON: JSON.stringify({ extra: true }), + }); + expect(badRes?.ok).toBe(false); + expect(badRes && "error" in badRes ? badRes.error.code : "").toBe( + "INVALID_REQUEST", + ); + + ws.close(); + await server.close(); + }); + test("pushes voicewake.changed to nodes on connect and on updates", async () => { const homeDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdis-home-")); const prevHome = process.env.HOME; @@ -1702,19 +1863,26 @@ describe("gateway server", () => { ws, (o) => o.type === "res" && o.id === "presence1", ); + const providersP = onceMessage( + ws, + (o) => o.type === "res" && o.id === "providers1", + ); const sendReq = (id: string, method: string) => ws.send(JSON.stringify({ type: "req", id, method })); sendReq("health1", "health"); sendReq("status1", "status"); sendReq("presence1", "system-presence"); + sendReq("providers1", "providers.status"); const health = await healthP; const status = await statusP; const presence = await presenceP; + const providers = await providersP; expect(health.ok).toBe(true); expect(status.ok).toBe(true); expect(presence.ok).toBe(true); + expect(providers.ok).toBe(true); expect(Array.isArray(presence.payload)).toBe(true); ws.close();