feat: add memory embedding cache

Peter Steinberger
2026-01-18 01:35:58 +00:00
parent 568b8ee96c
commit 0fb2777c6d
9 changed files with 372 additions and 27 deletions
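
A note on the mechanics: embeddings are now cached per (provider, model, provider_key, content hash) in a new embedding_cache table, and every embedding path first splits chunks into cache hits and misses, embeds only the misses, and writes the fresh vectors back. A minimal standalone sketch of that partition pattern, with getCached, putCached, and embedBatch as illustrative stand-ins rather than the manager's actual API:

// Sketch of the hit/miss partition that embedChunksInBatches implements below.
type Chunk = { hash: string; text: string };

async function embedWithCache(
  chunks: Chunk[],
  getCached: (hash: string) => number[] | undefined,
  putCached: (hash: string, embedding: number[]) => void,
  embedBatch: (texts: string[]) => Promise<number[][]>,
): Promise<number[][]> {
  // Pre-size the result so cached and fresh vectors land at their original indices.
  const out: number[][] = Array.from({ length: chunks.length }, () => []);
  const missing: Array<{ index: number; chunk: Chunk }> = [];
  chunks.forEach((chunk, index) => {
    const hit = getCached(chunk.hash);
    if (hit && hit.length > 0) out[index] = hit;
    else missing.push({ index, chunk });
  });
  if (missing.length === 0) return out; // a forced reindex can be fully cache-served
  const fresh = await embedBatch(missing.map((m) => m.chunk.text));
  missing.forEach((m, j) => {
    const embedding = fresh[j] ?? [];
    out[m.index] = embedding;
    putCached(m.chunk.hash, embedding); // warm the cache for the next sync
  });
  return out;
}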

View File

@@ -6,12 +6,14 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { getMemorySearchManager, type MemoryIndexManager } from "./index.js";
let embedBatchCalls = 0;
vi.mock("./embeddings.js", () => {
const embedText = (text: string) => {
const lower = text.toLowerCase();
const alpha = lower.split("alpha").length - 1;
const beta = lower.split("beta").length - 1;
return [alpha, beta, 1];
return [alpha, beta];
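// With the constant third component gone, text mentioning neither "alpha" nor "beta" embeds to an all-zero vector.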
};
return {
createEmbeddingProvider: async (options: { model?: string }) => ({
@@ -20,7 +22,10 @@ vi.mock("./embeddings.js", () => {
id: "mock",
model: options.model ?? "mock-embed",
embedQuery: async (text: string) => embedText(text),
embedBatch: async (texts: string[]) => texts.map(embedText),
embedBatch: async (texts: string[]) => {
embedBatchCalls += 1;
return texts.map(embedText);
},
},
}),
};
@@ -32,12 +37,13 @@ describe("memory index", () => {
let manager: MemoryIndexManager | null = null;
beforeEach(async () => {
embedBatchCalls = 0;
workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-mem-"));
indexPath = path.join(workspaceDir, "index.sqlite");
await fs.mkdir(path.join(workspaceDir, "memory"));
await fs.writeFile(
path.join(workspaceDir, "memory", "2026-01-12.md"),
"# Log\nAlpha memory line.\nAnother line.",
"# Log\nAlpha memory line.\nZebra memory line.\nAnother line.",
);
await fs.writeFile(path.join(workspaceDir, "MEMORY.md"), "Beta knowledge base entry.");
});
@@ -146,6 +152,35 @@ describe("memory index", () => {
expect(results.length).toBeGreaterThan(0);
});
it("reuses cached embeddings on forced reindex", async () => {
const cfg = {
agents: {
defaults: {
workspace: workspaceDir,
memorySearch: {
provider: "openai",
model: "mock-embed",
store: { path: indexPath, vector: { enabled: false } },
sync: { watch: false, onSessionStart: false, onSearch: false },
query: { minScore: 0 },
cache: { enabled: true },
},
},
list: [{ id: "main", default: true }],
},
};
const result = await getMemorySearchManager({ cfg, agentId: "main" });
expect(result.manager).not.toBeNull();
if (!result.manager) throw new Error("manager missing");
manager = result.manager;
await manager.sync({ force: true });
const afterFirst = embedBatchCalls;
expect(afterFirst).toBeGreaterThan(0);
await manager.sync({ force: true });
expect(embedBatchCalls).toBe(afterFirst);
});
it("reports vector availability after probe", async () => {
const cfg = {
agents: {

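The new test exercises the cache end to end without peeking at SQL: the vi.mock factory bumps a module-level embedBatchCalls counter, the first sync({ force: true }) has to call embedBatch at least once to populate embedding_cache, and a second forced sync must leave the counter untouched. (vi.mock calls are hoisted above imports, but the factory runs lazily when the mocked module is first loaded, so closing over the top-level let is safe.)
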
View File

@@ -47,6 +47,7 @@ export type MemorySearchResult = {
type MemoryIndexMeta = {
model: string;
provider: string;
providerKey?: string;
chunkTokens: number;
chunkOverlap: number;
vectorDims?: number;
@@ -106,6 +107,7 @@ type OpenAiBatchOutputLine = {
const META_KEY = "memory_index_meta_v1";
const SNIPPET_MAX_CHARS = 700;
const VECTOR_TABLE = "chunks_vec";
const EMBEDDING_CACHE_TABLE = "embedding_cache";
const SESSION_DIRTY_DEBOUNCE_MS = 5000;
const EMBEDDING_BATCH_MAX_TOKENS = 8000;
const EMBEDDING_APPROX_CHARS_PER_TOKEN = 1;
@@ -143,6 +145,8 @@ export class MemoryIndexManager {
};
private readonly db: DatabaseSync;
private readonly sources: Set<MemorySource>;
private readonly providerKey: string;
private readonly cache: { enabled: boolean; maxEntries?: number };
private readonly vector: {
enabled: boolean;
available: boolean | null;
@@ -214,6 +218,11 @@ export class MemoryIndexManager {
this.openAi = params.providerResult.openAi;
this.sources = new Set(params.settings.sources);
this.db = this.openDatabase();
this.providerKey = this.computeProviderKey();
this.cache = {
enabled: params.settings.cache.enabled,
maxEntries: params.settings.cache.maxEntries,
};
this.ensureSchema();
this.vector = {
enabled: params.settings.store.vector.enabled,
@@ -266,19 +275,19 @@ export class MemoryIndexManager {
const minScore = opts?.minScore ?? this.settings.query.minScore;
const maxResults = opts?.maxResults ?? this.settings.query.maxResults;
const queryVec = await this.provider.embedQuery(cleaned);
if (queryVec.length === 0) return [];
if (!queryVec.some((v) => v !== 0)) return [];
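// An all-zero query vector carries no signal and cosine distance is ill-defined for it, so bail out before touching the vector table.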
if (await this.ensureVectorReady(queryVec.length)) {
const sourceFilter = this.buildSourceFilter("c");
const rows = this.db
.prepare(
`SELECT c.path, c.start_line, c.end_line, c.text,
c.source,
vec_distance_cosine(v.embedding, ?) AS dist
FROM ${VECTOR_TABLE} v
JOIN chunks c ON c.id = v.id
WHERE c.model = ?${sourceFilter.sql}
ORDER BY dist ASC
LIMIT ?`,
`SELECT c.path, c.start_line, c.end_line, c.text,\n` +
` c.source,\n` +
` vec_distance_cosine(v.embedding, ?) AS dist\n` +
` FROM ${VECTOR_TABLE} v\n` +
` JOIN chunks c ON c.id = v.id\n` +
` WHERE c.model = ?${sourceFilter.sql}\n` +
` ORDER BY dist ASC\n` +
` LIMIT ?`,
)
.all(
vectorToBlob(queryVec),
@@ -372,6 +381,7 @@ export class MemoryIndexManager {
requestedProvider: string;
sources: MemorySource[];
sourceCounts: Array<{ source: MemorySource; files: number; chunks: number }>;
cache?: { enabled: boolean; entries?: number; maxEntries?: number };
fallback?: { from: string; reason?: string };
vector?: {
enabled: boolean;
@@ -432,6 +442,16 @@ export class MemoryIndexManager {
requestedProvider: this.requestedProvider,
sources: Array.from(this.sources),
sourceCounts,
cache: this.cache.enabled
? {
enabled: true,
entries:
(this.db
.prepare(`SELECT COUNT(*) as c FROM ${EMBEDDING_CACHE_TABLE}`)
.get() as { c: number } | undefined)?.c ?? 0,
maxEntries: this.cache.maxEntries,
}
: { enabled: false, maxEntries: this.cache.maxEntries },
fallback: this.fallbackReason ? { from: "local", reason: this.fallbackReason } : undefined,
vector: {
enabled: this.vector.enabled,
@@ -603,6 +623,21 @@ export class MemoryIndexManager {
updated_at INTEGER NOT NULL
);
`);
this.db.exec(`
CREATE TABLE IF NOT EXISTS ${EMBEDDING_CACHE_TABLE} (
provider TEXT NOT NULL,
model TEXT NOT NULL,
provider_key TEXT NOT NULL,
hash TEXT NOT NULL,
embedding TEXT NOT NULL,
dims INTEGER,
updated_at INTEGER NOT NULL,
PRIMARY KEY (provider, model, provider_key, hash)
);
`);
this.db.exec(
`CREATE INDEX IF NOT EXISTS idx_embedding_cache_updated_at ON ${EMBEDDING_CACHE_TABLE}(updated_at);`,
);
this.ensureColumn("files", "source", "TEXT NOT NULL DEFAULT 'memory'");
this.ensureColumn("chunks", "source", "TEXT NOT NULL DEFAULT 'memory'");
this.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_path ON chunks(path);`);
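
The cache schema is self-contained: the composite primary key scopes entries to an exact provider/model/endpoint fingerprint, the embedding is stored as JSON text with dims alongside as a sanity check, and the updated_at index exists purely to make pruning cheap. A standalone illustration of the upsert semantics, assuming DatabaseSync here is Node's built-in node:sqlite and reusing the test's mock-embed model name:

import { DatabaseSync } from "node:sqlite";

const db = new DatabaseSync(":memory:");
db.exec(`CREATE TABLE IF NOT EXISTS embedding_cache (
  provider TEXT NOT NULL,
  model TEXT NOT NULL,
  provider_key TEXT NOT NULL,
  hash TEXT NOT NULL,
  embedding TEXT NOT NULL,
  dims INTEGER,
  updated_at INTEGER NOT NULL,
  PRIMARY KEY (provider, model, provider_key, hash)
);`);
const upsert = db.prepare(
  `INSERT INTO embedding_cache (provider, model, provider_key, hash, embedding, dims, updated_at)
   VALUES (?, ?, ?, ?, ?, ?, ?)
   ON CONFLICT(provider, model, provider_key, hash) DO UPDATE SET
     embedding=excluded.embedding, dims=excluded.dims, updated_at=excluded.updated_at`,
);
// Same key twice: the second run updates in place instead of inserting a duplicate row.
upsert.run("openai", "mock-embed", "pk", "h1", JSON.stringify([0.1, 0.2]), 2, Date.now());
upsert.run("openai", "mock-embed", "pk", "h1", JSON.stringify([0.3, 0.4]), 2, Date.now());
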
@@ -681,6 +716,7 @@ export class MemoryIndexManager {
}
private listChunks(): Array<{
id: string;
path: string;
startLine: number;
endLine: number;
@@ -691,11 +727,12 @@ export class MemoryIndexManager {
const sourceFilter = this.buildSourceFilter();
const rows = this.db
.prepare(
`SELECT path, start_line, end_line, text, embedding, source
`SELECT id, path, start_line, end_line, text, embedding, source
FROM chunks
WHERE model = ?${sourceFilter.sql}`,
)
.all(this.provider.model, ...sourceFilter.params) as Array<{
id: string;
path: string;
start_line: number;
end_line: number;
@@ -704,6 +741,7 @@ export class MemoryIndexManager {
source: MemorySource;
}>;
return rows.map((row) => ({
id: row.id,
path: row.path,
startLine: row.start_line,
endLine: row.end_line,
@@ -779,6 +817,13 @@ export class MemoryIndexManager {
for (const stale of staleRows) {
if (activePaths.has(stale.path)) continue;
this.db.prepare(`DELETE FROM files WHERE path = ? AND source = ?`).run(stale.path, "memory");
try {
this.db
.prepare(
`DELETE FROM ${VECTOR_TABLE} WHERE id IN (SELECT id FROM chunks WHERE path = ? AND source = ?)`,
)
.run(stale.path, "memory");
} catch {}
this.db.prepare(`DELETE FROM chunks WHERE path = ? AND source = ?`).run(stale.path, "memory");
}
}
@@ -860,6 +905,13 @@ export class MemoryIndexManager {
this.db
.prepare(`DELETE FROM files WHERE path = ? AND source = ?`)
.run(stale.path, "sessions");
try {
this.db
.prepare(
`DELETE FROM ${VECTOR_TABLE} WHERE id IN (SELECT id FROM chunks WHERE path = ? AND source = ?)`,
)
.run(stale.path, "sessions");
} catch {}
this.db
.prepare(`DELETE FROM chunks WHERE path = ? AND source = ?`)
.run(stale.path, "sessions");
@@ -902,6 +954,7 @@ export class MemoryIndexManager {
!meta ||
meta.model !== this.provider.model ||
meta.provider !== this.provider.id ||
meta.providerKey !== this.providerKey ||
meta.chunkTokens !== this.settings.chunking.tokens ||
meta.chunkOverlap !== this.settings.chunking.overlap ||
(vectorReady && !meta?.vectorDims);
@@ -929,6 +982,7 @@ export class MemoryIndexManager {
const nextMeta: MemoryIndexMeta = {
model: this.provider.model,
provider: this.provider.id,
providerKey: this.providerKey,
chunkTokens: this.settings.chunking.tokens,
chunkOverlap: this.settings.chunking.overlap,
};
@@ -938,6 +992,9 @@ export class MemoryIndexManager {
if (shouldSyncMemory || shouldSyncSessions || needsFullReindex) {
this.writeMeta(nextMeta);
}
if (shouldSyncMemory || shouldSyncSessions || needsFullReindex) {
this.pruneEmbeddingCacheIfNeeded();
}
}
private resetIndex() {
@@ -1091,16 +1148,121 @@ export class MemoryIndexManager {
return batches;
}
private loadEmbeddingCache(hashes: string[]): Map<string, number[]> {
if (!this.cache.enabled) return new Map();
if (hashes.length === 0) return new Map();
const unique: string[] = [];
const seen = new Set<string>();
for (const hash of hashes) {
if (!hash) continue;
if (seen.has(hash)) continue;
seen.add(hash);
unique.push(hash);
}
if (unique.length === 0) return new Map();
const out = new Map<string, number[]>();
const baseParams = [this.provider.id, this.provider.model, this.providerKey];
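// SQLite caps bound parameters per statement (historically 999), so hashes are looked up in slices of 400 on top of the three fixed key columns.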
const batchSize = 400;
for (let start = 0; start < unique.length; start += batchSize) {
const batch = unique.slice(start, start + batchSize);
const placeholders = batch.map(() => "?").join(", ");
const rows = this.db
.prepare(
`SELECT hash, embedding FROM ${EMBEDDING_CACHE_TABLE}\n` +
` WHERE provider = ? AND model = ? AND provider_key = ? AND hash IN (${placeholders})`,
)
.all(...baseParams, ...batch) as Array<{ hash: string; embedding: string }>;
for (const row of rows) {
out.set(row.hash, parseEmbedding(row.embedding));
}
}
return out;
}
private upsertEmbeddingCache(entries: Array<{ hash: string; embedding: number[] }>): void {
if (!this.cache.enabled) return;
if (entries.length === 0) return;
const now = Date.now();
const stmt = this.db.prepare(
`INSERT INTO ${EMBEDDING_CACHE_TABLE} (provider, model, provider_key, hash, embedding, dims, updated_at)\n` +
` VALUES (?, ?, ?, ?, ?, ?, ?)\n` +
` ON CONFLICT(provider, model, provider_key, hash) DO UPDATE SET\n` +
` embedding=excluded.embedding,\n` +
` dims=excluded.dims,\n` +
` updated_at=excluded.updated_at`,
);
for (const entry of entries) {
const embedding = entry.embedding ?? [];
stmt.run(
this.provider.id,
this.provider.model,
this.providerKey,
entry.hash,
JSON.stringify(embedding),
embedding.length,
now,
);
}
}
private pruneEmbeddingCacheIfNeeded(): void {
if (!this.cache.enabled) return;
const max = this.cache.maxEntries;
if (!max || max <= 0) return;
const row = this.db
.prepare(`SELECT COUNT(*) as c FROM ${EMBEDDING_CACHE_TABLE}`)
.get() as { c: number } | undefined;
const count = row?.c ?? 0;
if (count <= max) return;
const excess = count - max;
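// Oldest-written rows go first: upserts refresh updated_at but reads never touch it, so eviction is least-recently-written rather than strict LRU.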
this.db
.prepare(
`DELETE FROM ${EMBEDDING_CACHE_TABLE}\n` +
` WHERE rowid IN (\n` +
` SELECT rowid FROM ${EMBEDDING_CACHE_TABLE}\n` +
` ORDER BY updated_at ASC\n` +
` LIMIT ?\n` +
` )`,
)
.run(excess);
}
private async embedChunksInBatches(chunks: MemoryChunk[]): Promise<number[][]> {
if (chunks.length === 0) return [];
const batches = this.buildEmbeddingBatches(chunks);
const embeddings: number[][] = [];
const cached = this.loadEmbeddingCache(chunks.map((chunk) => chunk.hash));
const embeddings: number[][] = Array.from({ length: chunks.length }, () => []);
const missing: Array<{ index: number; chunk: MemoryChunk }> = [];
for (let i = 0; i < chunks.length; i += 1) {
const chunk = chunks[i];
const hit = chunk?.hash ? cached.get(chunk.hash) : undefined;
if (hit && hit.length > 0) {
embeddings[i] = hit;
} else if (chunk) {
missing.push({ index: i, chunk });
}
}
if (missing.length === 0) return embeddings;
const missingChunks = missing.map((m) => m.chunk);
const batches = this.buildEmbeddingBatches(missingChunks);
const toCache: Array<{ hash: string; embedding: number[] }> = [];
let cursor = 0;
for (const batch of batches) {
const batchEmbeddings = await this.embedBatchWithRetry(batch.map((chunk) => chunk.text));
for (let i = 0; i < batch.length; i += 1) {
embeddings.push(batchEmbeddings[i] ?? []);
const item = missing[cursor + i];
const embedding = batchEmbeddings[i] ?? [];
if (item) {
embeddings[item.index] = embedding;
toCache.push({ hash: item.chunk.hash, embedding });
}
}
cursor += batch.length;
}
this.upsertEmbeddingCache(toCache);
return embeddings;
}
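
This is the partition pattern sketched under the commit header above: the result array is pre-sized so cached and freshly embedded vectors land at their original chunk positions, and cursor tracks how far through the flattened miss list each provider batch has advanced.
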
@@ -1121,6 +1283,24 @@ export class MemoryIndexManager {
return headers;
}
private computeProviderKey(): string {
if (this.provider.id === "openai" && this.openAi) {
const entries = Object.entries(this.openAi.headers)
.filter(([key]) => key.toLowerCase() !== "authorization")
.sort(([a], [b]) => a.localeCompare(b))
.map(([key, value]) => [key, value]);
return hashText(
JSON.stringify({
provider: "openai",
baseUrl: this.openAi.baseUrl,
model: this.openAi.model,
headers: entries,
}),
);
}
return hashText(JSON.stringify({ provider: this.provider.id, model: this.provider.model }));
}
private buildOpenAiBatchRequests(
chunks: MemoryChunk[],
entry: MemoryFileEntry | SessionFileEntry,
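
computeProviderKey fingerprints everything that can change what a given endpoint returns (base URL, model, and any extra headers) while deliberately dropping Authorization, so rotating an API key keeps the cache warm but repointing the base URL invalidates it. A standalone sketch of the same derivation, approximating hashText with a SHA-256 hex digest (an assumption; the module's own helper is not shown in this diff):

import { createHash } from "node:crypto";

// Stand-in for the module's hashText helper.
const hashText = (text: string) => createHash("sha256").update(text).digest("hex");

const headers: Record<string, string> = {
  Authorization: "Bearer sk-rotates-often", // excluded below: key rotation must not bust the cache
  "OpenAI-Organization": "org_example", // illustrative extra header
};
const entries = Object.entries(headers)
  .filter(([key]) => key.toLowerCase() !== "authorization")
  .sort(([a], [b]) => a.localeCompare(b))
  .map(([key, value]) => [key, value]);
const providerKey = hashText(
  JSON.stringify({
    provider: "openai",
    baseUrl: "https://api.openai.com/v1", // changing this yields a new providerKey
    model: "mock-embed",
    headers: entries,
  }),
);
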
@@ -1300,8 +1480,40 @@ export class MemoryIndexManager {
return this.embedChunksInBatches(chunks);
}
if (chunks.length === 0) return [];
const cached = this.loadEmbeddingCache(chunks.map((chunk) => chunk.hash));
const embeddings: number[][] = Array.from({ length: chunks.length }, () => []);
const missing: Array<{ index: number; chunk: MemoryChunk }> = [];
const { requests, mapping } = this.buildOpenAiBatchRequests(chunks, entry, source);
for (let i = 0; i < chunks.length; i += 1) {
const chunk = chunks[i];
const hit = chunk?.hash ? cached.get(chunk.hash) : undefined;
if (hit && hit.length > 0) {
embeddings[i] = hit;
} else if (chunk) {
missing.push({ index: i, chunk });
}
}
if (missing.length === 0) return embeddings;
const requests: OpenAiBatchRequest[] = [];
const mapping = new Map<string, number>();
for (const item of missing) {
const chunk = item.chunk;
const customId = hashText(
`${source}:${entry.path}:${chunk.startLine}:${chunk.endLine}:${chunk.hash}:${item.index}`,
);
mapping.set(customId, item.index);
requests.push({
custom_id: customId,
method: "POST",
url: OPENAI_BATCH_ENDPOINT,
body: {
model: this.openAi?.model ?? this.provider.model,
input: chunk.text,
},
});
}
const groups = this.splitOpenAiBatchRequests(requests);
log.debug("memory embeddings: openai batch submit", {
source,
@@ -1313,7 +1525,7 @@ export class MemoryIndexManager {
pollIntervalMs: this.batch.pollIntervalMs,
timeoutMs: this.batch.timeoutMs,
});
const embeddings: number[][] = Array.from({ length: chunks.length }, () => []);
const toCache: Array<{ hash: string; embedding: number[] }> = [];
const tasks = groups.map((group, groupIndex) => async () => {
const batchInfo = await this.submitOpenAiBatch(group);
@@ -1373,6 +1585,8 @@ export class MemoryIndexManager {
continue;
}
embeddings[index] = embedding;
const chunk = chunks[index];
if (chunk) toCache.push({ hash: chunk.hash, embedding });
}
if (errors.length > 0) {
throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
@@ -1385,6 +1599,7 @@ export class MemoryIndexManager {
});
await this.runWithConcurrency(tasks, this.batch.concurrency);
this.upsertEmbeddingCache(toCache);
return embeddings;
}
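
The OpenAI batch path gets the same treatment: cache hits are filtered out before any batch request is built, so a fully cached file never submits a batch at all; each custom_id hash folds in the original chunk index, so responses map back to the right slot in the pre-sized embeddings array; and fresh vectors are written back in a single upsert pass once every group completes.
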
@@ -1463,9 +1678,16 @@ export class MemoryIndexManager {
const sample = embeddings.find((embedding) => embedding.length > 0);
const vectorReady = sample ? await this.ensureVectorReady(sample.length) : false;
const now = Date.now();
this.db
.prepare(`DELETE FROM chunks WHERE path = ? AND source = ?`)
.run(entry.path, options.source);
if (vectorReady) {
try {
this.db
.prepare(
`DELETE FROM ${VECTOR_TABLE} WHERE id IN (SELECT id FROM chunks WHERE path = ? AND source = ?)`,
)
.run(entry.path, options.source);
} catch {}
}
this.db.prepare(`DELETE FROM chunks WHERE path = ? AND source = ?`).run(entry.path, options.source);
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
const embedding = embeddings[i] ?? [];