feat: add gemini memory embeddings

This commit is contained in:
Peter Steinberger
2026-01-18 09:09:13 +00:00
parent b015c7e5ad
commit a3a4996adb
13 changed files with 482 additions and 24 deletions

View File

@@ -107,6 +107,39 @@ describe("embedding provider remote overrides", () => {
const headers = (fetchMock.mock.calls[0]?.[1]?.headers as Record<string, string>) ?? {};
expect(headers.Authorization).toBe("Bearer provider-key");
});
it("uses gemini embedContent endpoint with x-goog-api-key", async () => {
  // Stub global fetch so the provider's HTTP request is captured instead of sent.
  const mockedFetch = vi.fn(async () => ({
    ok: true,
    status: 200,
    json: async () => ({ embedding: { values: [1, 2, 3] } }),
  })) as unknown as typeof fetch;
  vi.stubGlobal("fetch", mockedFetch);
  const { createEmbeddingProvider } = await import("./embeddings.js");
  const authModule = await import("../agents/model-auth.js");
  vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
    apiKey: "gemini-key",
  });
  const result = await createEmbeddingProvider({
    config: {} as never,
    provider: "gemini",
    remote: {
      baseUrl: "https://gemini.example/v1beta",
    },
    model: "gemini-embedding-001",
    fallback: "openai",
  });
  await result.provider.embedQuery("hello");
  // Inspect the captured request: endpoint URL, API-key header, and content type.
  const firstCall = mockedFetch.mock.calls[0] ?? [];
  const [requestUrl, requestInit] = firstCall;
  expect(requestUrl).toBe("https://gemini.example/v1beta/models/gemini-embedding-001:embedContent");
  const sentHeaders = (requestInit?.headers ?? {}) as Record<string, string>;
  expect(sentHeaders["x-goog-api-key"]).toBe("gemini-key");
  expect(sentHeaders["Content-Type"]).toBe("application/json");
});
});
describe("embedding provider local fallback", () => {

View File

@@ -12,10 +12,11 @@ export type EmbeddingProvider = {
export type EmbeddingProviderResult = {
provider: EmbeddingProvider;
requestedProvider: "openai" | "local";
requestedProvider: "openai" | "gemini" | "local";
fallbackFrom?: "local";
fallbackReason?: string;
openAi?: OpenAiEmbeddingClient;
gemini?: GeminiEmbeddingClient;
};
export type OpenAiEmbeddingClient = {
@@ -24,10 +25,16 @@ export type OpenAiEmbeddingClient = {
model: string;
};
/** Resolved HTTP settings for calling the Gemini embeddings REST API. */
export type GeminiEmbeddingClient = {
// API root, e.g. "https://generativelanguage.googleapis.com/v1beta" (no trailing slash expected).
baseUrl: string;
// Request headers, including "x-goog-api-key" and "Content-Type".
headers: Record<string, string>;
// Bare model id without the "models/" prefix (callers add it when building URLs).
model: string;
};
export type EmbeddingProviderOptions = {
config: ClawdbotConfig;
agentDir?: string;
provider: "openai" | "local";
provider: "openai" | "gemini" | "local";
remote?: {
baseUrl?: string;
apiKey?: string;
@@ -43,6 +50,8 @@ export type EmbeddingProviderOptions = {
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
const DEFAULT_LOCAL_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
// Google Generative Language API root used when no base URL override is configured.
const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
// Default Gemini embedding model id (stored without the "models/" prefix).
const DEFAULT_GEMINI_MODEL = "gemini-embedding-001";
function normalizeOpenAiModel(model: string): string {
const trimmed = model.trim();
@@ -51,6 +60,14 @@ function normalizeOpenAiModel(model: string): string {
return trimmed;
}
/**
 * Normalize a user-supplied Gemini embedding model reference to a bare model id.
 *
 * Accepts plain ids ("gemini-embedding-001") as well as prefixed forms such as
 * "models/gemini-embedding-001", "google/gemini-embedding-001", and the combined
 * "google/models/gemini-embedding-001". Callers prepend "models/" when building
 * request URLs, so every leading prefix must be removed here — the previous
 * single-prefix check left "google/models/x" as "models/x", yielding broken
 * "models/models/x" endpoints. Falls back to DEFAULT_GEMINI_MODEL for blank input.
 */
function normalizeGeminiModel(model: string): string {
  let trimmed = model.trim();
  if (!trimmed) return DEFAULT_GEMINI_MODEL;
  // Strip "google/" first so the combined "google/models/…" form is fully reduced.
  if (trimmed.startsWith("google/")) trimmed = trimmed.slice("google/".length);
  if (trimmed.startsWith("models/")) trimmed = trimmed.slice("models/".length);
  return trimmed;
}
async function createOpenAiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
@@ -89,6 +106,83 @@ async function createOpenAiEmbeddingProvider(
};
}
/**
 * Pull the numeric vector out of a single Gemini embedding entry.
 * Accepts both `{ values: [...] }` and the wrapped `{ embedding: { values: [...] } }`
 * shapes; non-numeric elements are dropped and malformed input yields [].
 */
function extractGeminiEmbeddingValues(entry: unknown): number[] {
  if (typeof entry !== "object" || entry === null) return [];
  const candidate = entry as { values?: unknown; embedding?: { values?: unknown } };
  let rawValues = candidate.values;
  if (rawValues == null) rawValues = candidate.embedding?.values;
  if (!Array.isArray(rawValues)) return [];
  const numbers: number[] = [];
  for (const item of rawValues) {
    if (typeof item === "number") numbers.push(item);
  }
  return numbers;
}
/**
 * Convert a Gemini embed response payload into a list of vectors.
 * Handles the batch shape (`{ embeddings: [...] }`) and the single shape
 * (`{ embedding: {...} }`); anything else yields an empty list.
 */
function parseGeminiEmbeddings(payload: unknown): number[][] {
  if (typeof payload !== "object" || payload === null) return [];
  const body = payload as { embedding?: unknown; embeddings?: unknown[] };
  const batch = body.embeddings;
  if (Array.isArray(batch)) {
    const vectors: number[][] = [];
    for (const entry of batch) {
      vectors.push(extractGeminiEmbeddingValues(entry));
    }
    return vectors;
  }
  return body.embedding ? [extractGeminiEmbeddingValues(body.embedding)] : [];
}
/**
 * Build an EmbeddingProvider backed by the Gemini REST API.
 * Single queries hit `:embedContent`; batches hit `:batchEmbedContents`.
 * Non-OK responses raise with the status code and response body included.
 */
async function createGeminiEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
  const client = await resolveGeminiEmbeddingClient(options);
  const root = client.baseUrl.replace(/\/$/, "");
  const modelPath = `models/${client.model}`;
  // Shared POST + error handling for both endpoints.
  const postJson = async (endpoint: string, body: unknown): Promise<unknown> => {
    const res = await fetch(endpoint, {
      method: "POST",
      headers: client.headers,
      body: JSON.stringify(body),
    });
    if (!res.ok) {
      const text = await res.text();
      throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
    }
    return res.json();
  };
  const embedSingle = async (input: string): Promise<number[]> => {
    const payload = await postJson(`${root}/${modelPath}:embedContent`, {
      model: modelPath,
      content: { parts: [{ text: input }] },
    });
    const vectors = parseGeminiEmbeddings(payload);
    return vectors[0] ?? [];
  };
  const embedMany = async (input: string[]): Promise<number[][]> => {
    if (input.length === 0) return [];
    const payload = await postJson(`${root}/${modelPath}:batchEmbedContents`, {
      requests: input.map((text) => ({
        model: modelPath,
        content: { parts: [{ text }] },
      })),
    });
    return parseGeminiEmbeddings(payload);
  };
  return {
    provider: {
      id: "gemini",
      model: client.model,
      embedQuery: embedSingle,
      embedBatch: embedMany,
    },
    client,
  };
}
async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> {
@@ -116,6 +210,33 @@ async function resolveOpenAiEmbeddingClient(
return { baseUrl, headers, model };
}
/**
 * Resolve the HTTP client settings (base URL, headers, normalized model) for
 * Gemini embeddings. Precedence: explicit remote overrides, then the configured
 * google provider, then built-in defaults.
 *
 * NOTE(review): header overrides are spread last, so provider/remote headers may
 * replace "x-goog-api-key" or "Content-Type" — presumably intentional to allow
 * proxy setups; confirm this matches the OpenAI resolver's behavior.
 */
async function resolveGeminiEmbeddingClient(
  options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
  const remote = options.remote;
  const explicitKey = remote?.apiKey?.trim();
  const explicitBaseUrl = remote?.baseUrl?.trim();
  // An explicit (non-empty) remote key wins; otherwise fall back to the shared resolver.
  const apiKey =
    explicitKey ||
    (
      await resolveApiKeyForProvider({
        provider: "google",
        cfg: options.config,
        agentDir: options.agentDir,
      })
    ).apiKey;
  const providerConfig = options.config.models?.providers?.google;
  const baseUrl = explicitBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
  const overrides = { ...providerConfig?.headers, ...remote?.headers };
  const headers: Record<string, string> = {
    "Content-Type": "application/json",
    "x-goog-api-key": apiKey,
    ...overrides,
  };
  return { baseUrl, headers, model: normalizeGeminiModel(options.model) };
}
async function createLocalEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProvider> {
@@ -168,6 +289,10 @@ export async function createEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
const requestedProvider = options.provider;
if (options.provider === "gemini") {
const { provider, client } = await createGeminiEmbeddingProvider(options);
return { provider, requestedProvider, gemini: client };
}
if (options.provider === "local") {
try {
const provider = await createLocalEmbeddingProvider(options);

View File

@@ -0,0 +1,91 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { getMemorySearchManager, type MemoryIndexManager } from "./index.js";
// Flag read by the embeddings mock below; tests flip it to force embedBatch to fail.
let shouldFail = false;
// Replace chokidar with an inert watcher so tests never start real filesystem watching.
vi.mock("chokidar", () => ({
default: {
watch: vi.fn(() => ({
on: vi.fn(),
close: vi.fn(async () => undefined),
})),
},
}));
// Mock the embedding provider with deterministic vectors plus a failure mode
// (toggled via shouldFail) used to exercise reindex-failure handling.
vi.mock("./embeddings.js", () => {
return {
createEmbeddingProvider: async () => ({
requestedProvider: "openai",
provider: {
id: "mock",
model: "mock-embed",
embedQuery: async () => [0.1, 0.2, 0.3],
embedBatch: async (texts: string[]) => {
if (shouldFail) {
// Simulates a provider outage mid-reindex.
throw new Error("embedding failure");
}
// One distinct, deterministic vector per input chunk.
return texts.map((_, index) => [index + 1, 0, 0]);
},
},
}),
};
});
describe("memory manager atomic reindex", () => {
  let tempWorkspace: string;
  let sqlitePath: string;
  let activeManager: MemoryIndexManager | null = null;

  beforeEach(async () => {
    shouldFail = false;
    // Fresh throwaway workspace with one memory file per test.
    tempWorkspace = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-mem-"));
    sqlitePath = path.join(tempWorkspace, "index.sqlite");
    await fs.mkdir(path.join(tempWorkspace, "memory"));
    await fs.writeFile(path.join(tempWorkspace, "MEMORY.md"), "Hello memory.");
  });

  afterEach(async () => {
    if (activeManager) {
      await activeManager.close();
      activeManager = null;
    }
    await fs.rm(tempWorkspace, { recursive: true, force: true });
  });

  it("keeps the prior index when a full reindex fails", async () => {
    const cfg = {
      agents: {
        defaults: {
          workspace: tempWorkspace,
          memorySearch: {
            provider: "openai",
            model: "mock-embed",
            store: { path: sqlitePath },
            sync: { watch: false, onSessionStart: false, onSearch: false },
          },
        },
        list: [{ id: "main", default: true }],
      },
    };
    const result = await getMemorySearchManager({ cfg, agentId: "main" });
    expect(result.manager).not.toBeNull();
    if (!result.manager) throw new Error("manager missing");
    activeManager = result.manager;
    // Seed the index and verify search works before injecting the failure.
    await activeManager.sync({ force: true });
    const hitsBefore = await activeManager.search("Hello");
    expect(hitsBefore.length).toBeGreaterThan(0);
    // A failing full reindex must reject...
    shouldFail = true;
    await expect(activeManager.sync({ force: true })).rejects.toThrow("embedding failure");
    // ...while the previously-built index stays queryable.
    const hitsAfter = await activeManager.search("Hello");
    expect(hitsAfter.length).toBeGreaterThan(0);
  });
});

View File

@@ -16,6 +16,7 @@ import {
createEmbeddingProvider,
type EmbeddingProvider,
type EmbeddingProviderResult,
type GeminiEmbeddingClient,
type OpenAiEmbeddingClient,
} from "./embeddings.js";
import {
@@ -104,9 +105,10 @@ export class MemoryIndexManager {
private readonly workspaceDir: string;
private readonly settings: ResolvedMemorySearchConfig;
private readonly provider: EmbeddingProvider;
private readonly requestedProvider: "openai" | "local";
private readonly requestedProvider: "openai" | "gemini" | "local";
private readonly fallbackReason?: string;
private readonly openAi?: OpenAiEmbeddingClient;
private readonly gemini?: GeminiEmbeddingClient;
private readonly batch: {
enabled: boolean;
wait: boolean;
@@ -114,7 +116,7 @@ export class MemoryIndexManager {
pollIntervalMs: number;
timeoutMs: number;
};
private readonly db: DatabaseSync;
private db: DatabaseSync;
private readonly sources: Set<MemorySource>;
private readonly providerKey: string;
private readonly cache: { enabled: boolean; maxEntries?: number };
@@ -142,6 +144,7 @@ export class MemoryIndexManager {
private sessionsDirtyFiles = new Set<string>();
private sessionWarm = new Set<string>();
private syncing: Promise<void> | null = null;
private readonly allowAtomicReindex: boolean;
static async get(params: {
cfg: ClawdbotConfig;
@@ -182,6 +185,7 @@ export class MemoryIndexManager {
workspaceDir: string;
settings: ResolvedMemorySearchConfig;
providerResult: EmbeddingProviderResult;
options?: { allowAtomicReindex?: boolean; enableBackgroundSync?: boolean };
}) {
this.cacheKey = params.cacheKey;
this.cfg = params.cfg;
@@ -192,6 +196,8 @@ export class MemoryIndexManager {
this.requestedProvider = params.providerResult.requestedProvider;
this.fallbackReason = params.providerResult.fallbackReason;
this.openAi = params.providerResult.openAi;
this.gemini = params.providerResult.gemini;
this.allowAtomicReindex = params.options?.allowAtomicReindex ?? true;
this.sources = new Set(params.settings.sources);
this.db = this.openDatabase();
this.providerKey = computeEmbeddingProviderKey({
@@ -200,6 +206,13 @@ export class MemoryIndexManager {
openAi: this.openAi
? { baseUrl: this.openAi.baseUrl, model: this.openAi.model, headers: this.openAi.headers }
: undefined,
gemini: this.gemini
? {
baseUrl: this.gemini.baseUrl,
model: this.gemini.model,
headers: this.gemini.headers,
}
: undefined,
});
this.cache = {
enabled: params.settings.cache.enabled,
@@ -216,9 +229,12 @@ export class MemoryIndexManager {
if (meta?.vectorDims) {
this.vector.dims = meta.vectorDims;
}
this.ensureWatcher();
this.ensureSessionListener();
this.ensureIntervalSync();
const enableBackgroundSync = params.options?.enableBackgroundSync ?? true;
if (enableBackgroundSync) {
this.ensureWatcher();
this.ensureSessionListener();
this.ensureIntervalSync();
}
this.dirty = this.sources.has("memory");
if (this.sources.has("sessions")) {
this.sessionsDirty = true;
@@ -782,7 +798,7 @@ export class MemoryIndexManager {
force?: boolean;
progress?: (update: MemorySyncProgressUpdate) => void;
}) {
const progress = params?.progress ? this.createSyncProgress(params.progress) : undefined;
const progressCallback = params?.progress;
const vectorReady = await this.ensureVectorReady();
const meta = this.readMeta();
const needsFullReindex =
@@ -794,6 +810,12 @@ export class MemoryIndexManager {
meta.chunkTokens !== this.settings.chunking.tokens ||
meta.chunkOverlap !== this.settings.chunking.overlap ||
(vectorReady && !meta?.vectorDims);
if (needsFullReindex && this.allowAtomicReindex) {
await this.runAtomicReindex({ reason: params?.reason, progress: progressCallback });
return;
}
const progress = progressCallback ? this.createSyncProgress(progressCallback) : undefined;
if (needsFullReindex) {
this.resetIndex();
}
@@ -833,6 +855,126 @@ export class MemoryIndexManager {
}
}
/**
 * Create a throwaway MemoryIndexManager that builds into `tempPath` instead of the
 * live index. File watching, interval sync, and nested atomic reindexing are all
 * disabled so the scratch instance performs exactly one foreground rebuild.
 */
private createScratchManager(tempPath: string): MemoryIndexManager {
  const scratchSettings: ResolvedMemorySearchConfig = {
    ...this.settings,
    store: { ...this.settings.store, path: tempPath },
    sync: { ...this.settings.sync, watch: false, intervalMinutes: 0 },
  };
  // Reuse the already-resolved embedding provider so no re-authentication happens.
  const providerResult: EmbeddingProviderResult = {
    provider: this.provider,
    requestedProvider: this.requestedProvider,
    fallbackReason: this.fallbackReason,
    openAi: this.openAi,
    gemini: this.gemini,
  };
  return new MemoryIndexManager({
    // Unique cache key keeps the scratch instance out of the shared manager cache.
    cacheKey: `${this.cacheKey}:scratch:${Date.now()}`,
    cfg: this.cfg,
    agentId: this.agentId,
    workspaceDir: this.workspaceDir,
    settings: scratchSettings,
    providerResult,
    options: { allowAtomicReindex: false, enableBackgroundSync: false },
  });
}
/**
 * Derive a unique sibling temp path next to the configured index file,
 * creating the parent directory if needed. The timestamp+random suffix keeps
 * concurrent rebuilds from colliding.
 */
private buildTempIndexPath(): string {
  const indexPath = resolveUserPath(this.settings.store.path);
  const parentDir = path.dirname(indexPath);
  ensureDir(parentDir);
  const suffix = `${Date.now()}-${Math.random().toString(16).slice(2, 10)}`;
  return path.join(parentDir, `${path.basename(indexPath)}.tmp-${suffix}`);
}
// Reopen the sqlite handle after an index-file swap and reset every cached
// capability flag (FTS availability, vector extension state, dims) so they are
// re-derived from the new database file. Statement order matters: the schema
// must be ensured before meta can be read back.
private reopenDatabase() {
this.db = this.openDatabase();
// Force FTS re-detection against the new file.
this.fts.available = false;
this.fts.loadError = undefined;
this.ensureSchema();
// Force vector-extension re-detection; dims come from the new file's meta.
this.vector.available = null;
this.vector.loadError = undefined;
this.vectorReady = null;
this.vector.dims = undefined;
const meta = this.readMeta();
if (meta?.vectorDims) {
this.vector.dims = meta.vectorDims;
}
}
// Atomically replace the live index file with the freshly-built temp file.
// Strategy: rename live -> backup, temp -> live; on failure rename the backup
// back. The sqlite handle is closed up front and reopened (via reopenDatabase)
// whenever a usable file is believed to exist at the live path afterwards.
private async swapIndexFile(tempPath: string): Promise<void> {
const dbPath = resolveUserPath(this.settings.store.path);
const backupPath = `${dbPath}.bak-${Date.now()}`;
let hasBackup = false;
let shouldReopen = false;
this.db.close();
try {
try {
// Move the current index aside; ENOENT just means no index exists yet.
await fs.rename(dbPath, backupPath);
hasBackup = true;
} catch (err) {
const code = (err as NodeJS.ErrnoException).code;
if (code !== "ENOENT") throw err;
}
// Promote the freshly-built index. rename is atomic on the same filesystem;
// NOTE(review): temp files are created in the same directory, so this holds.
await fs.rename(tempPath, dbPath);
shouldReopen = true;
if (hasBackup) {
// Swap succeeded — the backup is no longer needed.
await fs.rm(backupPath, { force: true });
}
} catch (err) {
if (hasBackup) {
try {
// Best-effort rollback to the previous index file.
await fs.rename(backupPath, dbPath);
shouldReopen = true;
} catch {}
}
if (!shouldReopen) {
// Rollback failed or never ran: only reopen if some index file remains.
try {
await fs.access(dbPath);
shouldReopen = true;
} catch {}
}
throw err;
} finally {
// Always drop the temp file; reopen the handle when a db file is present.
await fs.rm(tempPath, { force: true });
if (shouldReopen) {
this.reopenDatabase();
}
}
}
/**
 * Rebuild the full index into a temp sqlite file via a scratch manager, then
 * atomically swap it over the live index. On any failure the temp file is
 * removed and the live index is left untouched.
 */
private async runAtomicReindex(params: {
  reason?: string;
  progress?: (update: MemorySyncProgressUpdate) => void;
}) {
  const tempPath = this.buildTempIndexPath();
  const scratch = this.createScratchManager(tempPath);
  try {
    await scratch.sync({ reason: params.reason, force: true, progress: params.progress });
  } catch (err) {
    // Close the scratch manager BEFORE deleting its database file: unlinking a
    // still-open sqlite file fails on Windows (EBUSY/EPERM) and would leave the
    // temp file behind. (Previously the rm ran while scratch was still open.)
    await scratch.close().catch(() => undefined);
    await fs.rm(tempPath, { force: true });
    throw err;
  }
  // Success path: release the temp db before renaming it over the live index.
  await scratch.close().catch(() => undefined);
  await this.swapIndexFile(tempPath);
  // The fresh index reflects all sources; clear the dirty bookkeeping.
  this.dirty = false;
  this.sessionsDirty = false;
  this.sessionsDirtyFiles.clear();
}
private resetIndex() {
this.db.exec(`DELETE FROM files`);
this.db.exec(`DELETE FROM chunks`);

View File

@@ -5,6 +5,7 @@ export function computeEmbeddingProviderKey(params: {
providerId: string;
providerModel: string;
openAi?: { baseUrl: string; model: string; headers: Record<string, string> };
gemini?: { baseUrl: string; model: string; headers: Record<string, string> };
}): string {
if (params.openAi) {
const headerNames = fingerprintHeaderNames(params.openAi.headers);
@@ -17,5 +18,16 @@ export function computeEmbeddingProviderKey(params: {
}),
);
}
if (params.gemini) {
const headerNames = fingerprintHeaderNames(params.gemini.headers);
return hashText(
JSON.stringify({
provider: "gemini",
baseUrl: params.gemini.baseUrl,
model: params.gemini.model,
headerNames,
}),
);
}
return hashText(JSON.stringify({ provider: params.providerId, model: params.providerModel }));
}