refactor: migrate embedded pi to sdk

This commit is contained in:
Peter Steinberger
2025-12-22 18:05:44 +01:00
parent 79c0fd27a0
commit 2d7c5f8c53
12 changed files with 276 additions and 386 deletions

View File

@@ -63,32 +63,28 @@ const testTailnetIPv4 = vi.hoisted(() => ({
value: undefined as string | undefined,
}));
const piAiMock = vi.hoisted(() => ({
const piSdkMock = vi.hoisted(() => ({
enabled: false,
getModelsCalls: [] as string[],
providers: ["openai", "anthropic"],
modelsByProvider: {} as Record<
string,
Array<{ id: string; name?: string; contextWindow?: number }>
>,
discoverCalls: 0,
models: [] as Array<{
id: string;
name?: string;
provider: string;
contextWindow?: number;
}>,
}));
vi.mock("@mariozechner/pi-ai", async () => {
const actual = await vi.importActual<{
getProviders: () => string[];
getModels: (
provider: string,
) => Array<{ id: string; name?: string; contextWindow?: number }>;
}>("@mariozechner/pi-ai");
vi.mock("@mariozechner/pi-coding-agent", async () => {
const actual = await vi.importActual<
typeof import("@mariozechner/pi-coding-agent")
>("@mariozechner/pi-coding-agent");
return {
...actual,
getProviders: () =>
piAiMock.enabled ? piAiMock.providers : actual.getProviders(),
getModels: (provider: string) => {
if (!piAiMock.enabled) return actual.getModels(provider);
piAiMock.getModelsCalls.push(provider);
return piAiMock.modelsByProvider[provider] ?? [];
discoverModels: () => {
if (!piSdkMock.enabled) return actual.discoverModels();
piSdkMock.discoverCalls += 1;
return piSdkMock.models;
},
};
});
@@ -252,10 +248,9 @@ beforeEach(async () => {
testGatewayBind = undefined;
testGatewayAuth = undefined;
__resetModelCatalogCacheForTest();
piAiMock.enabled = false;
piAiMock.getModelsCalls.length = 0;
piAiMock.providers = ["openai", "anthropic"];
piAiMock.modelsByProvider = { openai: [], anthropic: [] };
piSdkMock.enabled = false;
piSdkMock.discoverCalls = 0;
piSdkMock.models = [];
});
afterEach(async () => {
@@ -471,18 +466,28 @@ describe("gateway server", () => {
});
test("models.list returns model catalog", async () => {
piAiMock.enabled = true;
piAiMock.providers = ["openai", "anthropic"];
piAiMock.modelsByProvider = {
openai: [
{ id: "gpt-test-z", contextWindow: 0 },
{ id: "gpt-test-a", name: "A-Model", contextWindow: 8000 },
],
anthropic: [
{ id: "claude-test-b", name: "B-Model", contextWindow: 1000 },
{ id: "claude-test-a", name: "A-Model", contextWindow: 200_000 },
],
};
piSdkMock.enabled = true;
piSdkMock.models = [
{ id: "gpt-test-z", provider: "openai", contextWindow: 0 },
{
id: "gpt-test-a",
name: "A-Model",
provider: "openai",
contextWindow: 8000,
},
{
id: "claude-test-b",
name: "B-Model",
provider: "anthropic",
contextWindow: 1000,
},
{
id: "claude-test-a",
name: "A-Model",
provider: "anthropic",
contextWindow: 200_000,
},
];
const { server, ws } = await startServerWithClient();
await connectOk(ws);
@@ -535,16 +540,16 @@ describe("gateway server", () => {
},
]);
// Cached across requests: should only call getModels once per provider.
expect(piAiMock.getModelsCalls).toEqual(["openai", "anthropic"]);
// Cached across requests: should only call discoverModels once.
expect(piSdkMock.discoverCalls).toBe(1);
ws.close();
await server.close();
});
test("models.list rejects unknown params", async () => {
piAiMock.providers = ["openai"];
piAiMock.modelsByProvider = { openai: [{ id: "gpt-test-a", name: "A" }] };
piSdkMock.enabled = true;
piSdkMock.models = [{ id: "gpt-test-a", name: "A", provider: "openai" }];
const { server, ws } = await startServerWithClient();
await connectOk(ws);
@@ -558,9 +563,8 @@ describe("gateway server", () => {
});
test("bridge RPC supports models.list and validates params", async () => {
piAiMock.enabled = true;
piAiMock.providers = ["openai"];
piAiMock.modelsByProvider = { openai: [{ id: "gpt-test-a", name: "A" }] };
piSdkMock.enabled = true;
piSdkMock.models = [{ id: "gpt-test-a", name: "A", provider: "openai" }];
const { server, ws } = await startServerWithClient();
await connectOk(ws);
@@ -2503,7 +2507,7 @@ describe("gateway server", () => {
await server.close();
});
test("chat.history caps payload bytes", async () => {
test("chat.history caps payload bytes", { timeout: 15_000 }, async () => {
const dir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdis-gw-"));
testSessionStorePath = path.join(dir, "sessions.json");
await fs.writeFile(
@@ -2524,9 +2528,9 @@ describe("gateway server", () => {
const { server, ws } = await startServerWithClient();
await connectOk(ws);
const bigText = "x".repeat(300_000);
const bigText = "x".repeat(200_000);
const largeLines: string[] = [];
for (let i = 0; i < 60; i += 1) {
for (let i = 0; i < 40; i += 1) {
largeLines.push(
JSON.stringify({
message: {