feat: add OpenAI batch memory indexing

Peter Steinberger
2026-01-17 22:31:12 +00:00
parent acc3eb11d0
commit a31a79396b
11 changed files with 587 additions and 38 deletions


@@ -15,6 +15,13 @@ export type EmbeddingProviderResult = {
   requestedProvider: "openai" | "local";
   fallbackFrom?: "local";
   fallbackReason?: string;
+  openAi?: OpenAiEmbeddingClient;
 };
+export type OpenAiEmbeddingClient = {
+  baseUrl: string;
+  headers: Record<string, string>;
+  model: string;
+};
 export type EmbeddingProviderOptions = {
@@ -46,7 +53,45 @@ function normalizeOpenAiModel(model: string): string {
 async function createOpenAiEmbeddingProvider(
   options: EmbeddingProviderOptions,
-): Promise<EmbeddingProvider> {
+): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
+  const client = await resolveOpenAiEmbeddingClient(options);
+  const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
+  const embed = async (input: string[]): Promise<number[][]> => {
+    if (input.length === 0) return [];
+    const res = await fetch(url, {
+      method: "POST",
+      headers: client.headers,
+      body: JSON.stringify({ model: client.model, input }),
+    });
+    if (!res.ok) {
+      const text = await res.text();
+      throw new Error(`openai embeddings failed: ${res.status} ${text}`);
+    }
+    const payload = (await res.json()) as {
+      data?: Array<{ embedding?: number[] }>;
+    };
+    const data = payload.data ?? [];
+    return data.map((entry) => entry.embedding ?? []);
+  };
+  return {
+    provider: {
+      id: "openai",
+      model: client.model,
+      embedQuery: async (text) => {
+        const [vec] = await embed([text]);
+        return vec ?? [];
+      },
+      embedBatch: embed,
+    },
+    client,
+  };
+}
+async function resolveOpenAiEmbeddingClient(
+  options: EmbeddingProviderOptions,
+): Promise<OpenAiEmbeddingClient> {
   const remote = options.remote;
   const remoteApiKey = remote?.apiKey?.trim();
   const remoteBaseUrl = remote?.baseUrl?.trim();
@@ -61,7 +106,6 @@ async function createOpenAiEmbeddingProvider
   const providerConfig = options.config.models?.providers?.openai;
   const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_OPENAI_BASE_URL;
-  const url = `${baseUrl.replace(/\/$/, "")}/embeddings`;
   const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
   const headers: Record<string, string> = {
     "Content-Type": "application/json",
@@ -69,34 +113,7 @@ async function createOpenAiEmbeddingProvider
     ...headerOverrides,
   };
   const model = normalizeOpenAiModel(options.model);
-  const embed = async (input: string[]): Promise<number[][]> => {
-    if (input.length === 0) return [];
-    const res = await fetch(url, {
-      method: "POST",
-      headers,
-      body: JSON.stringify({ model, input }),
-    });
-    if (!res.ok) {
-      const text = await res.text();
-      throw new Error(`openai embeddings failed: ${res.status} ${text}`);
-    }
-    const payload = (await res.json()) as {
-      data?: Array<{ embedding?: number[] }>;
-    };
-    const data = payload.data ?? [];
-    return data.map((entry) => entry.embedding ?? []);
-  };
-  return {
-    id: "openai",
-    model,
-    embedQuery: async (text) => {
-      const [vec] = await embed([text]);
-      return vec ?? [];
-    },
-    embedBatch: embed,
-  };
+  return { baseUrl, headers, model };
 }
 async function createLocalEmbeddingProvider(
@@ -159,12 +176,13 @@ export async function createEmbeddingProvider(
       const reason = formatLocalSetupError(err);
       if (options.fallback === "openai") {
         try {
-          const provider = await createOpenAiEmbeddingProvider(options);
+          const { provider, client } = await createOpenAiEmbeddingProvider(options);
           return {
             provider,
             requestedProvider,
             fallbackFrom: "local",
             fallbackReason: reason,
+            openAi: client,
           };
         } catch (fallbackErr) {
           throw new Error(`${reason}\n\nFallback to OpenAI failed: ${formatError(fallbackErr)}`);
@@ -173,8 +191,8 @@ export async function createEmbeddingProvider(
       throw new Error(reason);
     }
   }
-  const provider = await createOpenAiEmbeddingProvider(options);
-  return { provider, requestedProvider };
+  const { provider, client } = await createOpenAiEmbeddingProvider(options);
+  return { provider, requestedProvider, openAi: client };
 }
 function formatError(err: unknown): string {
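
For context, below is a minimal sketch of how a caller might consume the newly exposed openAi client (baseUrl, headers, model) to drive batch indexing through the OpenAI Batch API. It is illustrative only and not code from this commit: the helper name submitEmbeddingBatch, the chunk shape, and the import path are assumptions; the /files upload plus /batches creation flow follows the documented OpenAI Batch API usage.

import type { OpenAiEmbeddingClient } from "./embeddings"; // import path is an assumption

// Hypothetical helper (not part of this commit): submit one OpenAI Batch API job
// that embeds many memory chunks using the exposed client settings.
async function submitEmbeddingBatch(
  client: OpenAiEmbeddingClient,
  chunks: Array<{ id: string; text: string }>,
): Promise<string> {
  const base = client.baseUrl.replace(/\/$/, "");

  // One JSONL line per chunk, each a /v1/embeddings request (standard Batch API input format).
  const jsonl = chunks
    .map((chunk) =>
      JSON.stringify({
        custom_id: chunk.id,
        method: "POST",
        url: "/v1/embeddings",
        body: { model: client.model, input: chunk.text },
      }),
    )
    .join("\n");

  // Upload the request file. Drop the JSON Content-Type so fetch can set the multipart boundary.
  const { "Content-Type": _contentType, ...authHeaders } = client.headers;
  const form = new FormData();
  form.set("purpose", "batch");
  form.set("file", new Blob([jsonl], { type: "application/jsonl" }), "embeddings.jsonl");
  const fileRes = await fetch(`${base}/files`, { method: "POST", headers: authHeaders, body: form });
  if (!fileRes.ok) throw new Error(`file upload failed: ${fileRes.status} ${await fileRes.text()}`);
  const file = (await fileRes.json()) as { id: string };

  // Create the batch job; embeddings are read later from the batch's output file.
  const batchRes = await fetch(`${base}/batches`, {
    method: "POST",
    headers: client.headers,
    body: JSON.stringify({
      input_file_id: file.id,
      endpoint: "/v1/embeddings",
      completion_window: "24h",
    }),
  });
  if (!batchRes.ok) throw new Error(`batch create failed: ${batchRes.status} ${await batchRes.text()}`);
  const batch = (await batchRes.json()) as { id: string };
  return batch.id;
}

The rest of the indexing flow (polling the batch status and downloading the output embeddings file) is omitted from this sketch.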