feat: add OpenAI batch memory indexing
This commit is contained in:
@@ -15,6 +15,13 @@ export type EmbeddingProviderResult = {
|
||||
requestedProvider: "openai" | "local";
|
||||
fallbackFrom?: "local";
|
||||
fallbackReason?: string;
|
||||
openAi?: OpenAiEmbeddingClient;
|
||||
};
|
||||
|
||||
export type OpenAiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
model: string;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderOptions = {
|
||||
@@ -46,7 +53,45 @@ function normalizeOpenAiModel(model: string): string {
|
||||
|
||||
async function createOpenAiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<EmbeddingProvider> {
|
||||
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
|
||||
const client = await resolveOpenAiEmbeddingClient(options);
|
||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
|
||||
const embed = async (input: string[]): Promise<number[][]> => {
|
||||
if (input.length === 0) return [];
|
||||
const res = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: client.headers,
|
||||
body: JSON.stringify({ model: client.model, input }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`openai embeddings failed: ${res.status} ${text}`);
|
||||
}
|
||||
const payload = (await res.json()) as {
|
||||
data?: Array<{ embedding?: number[] }>;
|
||||
};
|
||||
const data = payload.data ?? [];
|
||||
return data.map((entry) => entry.embedding ?? []);
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: "openai",
|
||||
model: client.model,
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text]);
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: embed,
|
||||
},
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
async function resolveOpenAiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<OpenAiEmbeddingClient> {
|
||||
const remote = options.remote;
|
||||
const remoteApiKey = remote?.apiKey?.trim();
|
||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||
@@ -61,7 +106,6 @@ async function createOpenAiEmbeddingProvider(
|
||||
|
||||
const providerConfig = options.config.models?.providers?.openai;
|
||||
const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_OPENAI_BASE_URL;
|
||||
const url = `${baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
|
||||
const headers: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
@@ -69,34 +113,7 @@ async function createOpenAiEmbeddingProvider(
|
||||
...headerOverrides,
|
||||
};
|
||||
const model = normalizeOpenAiModel(options.model);
|
||||
|
||||
const embed = async (input: string[]): Promise<number[][]> => {
|
||||
if (input.length === 0) return [];
|
||||
const res = await fetch(url, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify({ model, input }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`openai embeddings failed: ${res.status} ${text}`);
|
||||
}
|
||||
const payload = (await res.json()) as {
|
||||
data?: Array<{ embedding?: number[] }>;
|
||||
};
|
||||
const data = payload.data ?? [];
|
||||
return data.map((entry) => entry.embedding ?? []);
|
||||
};
|
||||
|
||||
return {
|
||||
id: "openai",
|
||||
model,
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text]);
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: embed,
|
||||
};
|
||||
return { baseUrl, headers, model };
|
||||
}
|
||||
|
||||
async function createLocalEmbeddingProvider(
|
||||
@@ -159,12 +176,13 @@ export async function createEmbeddingProvider(
|
||||
const reason = formatLocalSetupError(err);
|
||||
if (options.fallback === "openai") {
|
||||
try {
|
||||
const provider = await createOpenAiEmbeddingProvider(options);
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider(options);
|
||||
return {
|
||||
provider,
|
||||
requestedProvider,
|
||||
fallbackFrom: "local",
|
||||
fallbackReason: reason,
|
||||
openAi: client,
|
||||
};
|
||||
} catch (fallbackErr) {
|
||||
throw new Error(`${reason}\n\nFallback to OpenAI failed: ${formatError(fallbackErr)}`);
|
||||
@@ -173,8 +191,8 @@ export async function createEmbeddingProvider(
|
||||
throw new Error(reason);
|
||||
}
|
||||
}
|
||||
const provider = await createOpenAiEmbeddingProvider(options);
|
||||
return { provider, requestedProvider };
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider(options);
|
||||
return { provider, requestedProvider, openAi: client };
|
||||
}
|
||||
|
||||
function formatError(err: unknown): string {
|
||||
|
||||
Reference in New Issue
Block a user