fix: memory search remote overrides (#819) (thanks @mukhtharcm)
This commit is contained in:
110
src/memory/embeddings.test.ts
Normal file
110
src/memory/embeddings.test.ts
Normal file
@@ -0,0 +1,110 @@
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
vi.mock("../agents/model-auth.js", () => ({
|
||||
resolveApiKeyForProvider: vi.fn(),
|
||||
}));
|
||||
|
||||
const createFetchMock = () =>
|
||||
vi.fn(async () => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
|
||||
})) as unknown as typeof fetch;
|
||||
|
||||
describe("embedding provider remote overrides", () => {
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("uses remote baseUrl/apiKey and merges headers", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
const authModule = await import("../agents/model-auth.js");
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "provider-key",
|
||||
});
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://provider.example/v1",
|
||||
headers: {
|
||||
"X-Provider": "p",
|
||||
"X-Shared": "provider",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://remote.example/v1",
|
||||
apiKey: " remote-key ",
|
||||
headers: {
|
||||
"X-Shared": "remote",
|
||||
"X-Remote": "r",
|
||||
},
|
||||
},
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
await result.provider.embedQuery("hello");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
|
||||
const [url, init] = fetchMock.mock.calls[0] ?? [];
|
||||
expect(url).toBe("https://remote.example/v1/embeddings");
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-key");
|
||||
expect(headers["Content-Type"]).toBe("application/json");
|
||||
expect(headers["X-Provider"]).toBe("p");
|
||||
expect(headers["X-Shared"]).toBe("remote");
|
||||
expect(headers["X-Remote"]).toBe("r");
|
||||
});
|
||||
|
||||
it("falls back to resolved api key when remote apiKey is blank", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
const authModule = await import("../agents/model-auth.js");
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "provider-key",
|
||||
});
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://provider.example/v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://remote.example/v1",
|
||||
apiKey: " ",
|
||||
},
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
await result.provider.embedQuery("hello");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1);
|
||||
const headers =
|
||||
(fetchMock.mock.calls[0]?.[1]?.headers as Record<string, string>) ?? {};
|
||||
expect(headers.Authorization).toBe("Bearer provider-key");
|
||||
});
|
||||
});
|
||||
@@ -47,10 +47,12 @@ function normalizeOpenAiModel(model: string): string {
|
||||
async function createOpenAiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<EmbeddingProvider> {
|
||||
const remote = options.config.agents?.defaults?.memorySearch?.remote;
|
||||
const remote = options.remote;
|
||||
const remoteApiKey = remote?.apiKey?.trim();
|
||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||
|
||||
const { apiKey } = remote?.apiKey
|
||||
? { apiKey: remote.apiKey }
|
||||
const { apiKey } = remoteApiKey
|
||||
? { apiKey: remoteApiKey }
|
||||
: await resolveApiKeyForProvider({
|
||||
provider: "openai",
|
||||
cfg: options.config,
|
||||
@@ -59,11 +61,13 @@ async function createOpenAiEmbeddingProvider(
|
||||
|
||||
const providerConfig = options.config.models?.providers?.openai;
|
||||
const baseUrl =
|
||||
remote?.baseUrl?.trim() ||
|
||||
providerConfig?.baseUrl?.trim() ||
|
||||
DEFAULT_OPENAI_BASE_URL;
|
||||
remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_OPENAI_BASE_URL;
|
||||
const url = `${baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
const headerOverrides = remote?.headers ?? providerConfig?.headers ?? {};
|
||||
const headerOverrides = Object.assign(
|
||||
{},
|
||||
providerConfig?.headers,
|
||||
remote?.headers,
|
||||
);
|
||||
const headers: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
|
||||
Reference in New Issue
Block a user