fix: improve compaction queueing and oauth flows
@@ -6,6 +6,7 @@ import {
   type OAuthCredentials,
   type OAuthProvider,
 } from "@mariozechner/pi-ai";
+import lockfile from "proper-lockfile";
 
 import type { ClawdbotConfig } from "../config/config.js";
 import { resolveOAuthPath } from "../config/paths.js";
@@ -68,6 +69,83 @@ function saveJsonFile(pathname: string, data: unknown) {
   fs.chmodSync(pathname, 0o600);
 }
 
+function ensureAuthStoreFile(pathname: string) {
+  if (fs.existsSync(pathname)) return;
+  const payload: AuthProfileStore = {
+    version: AUTH_STORE_VERSION,
+    profiles: {},
+  };
+  saveJsonFile(pathname, payload);
+}
+
+function buildOAuthApiKey(
+  provider: OAuthProvider,
+  credentials: OAuthCredentials,
+): string {
+  const needsProjectId =
+    provider === "google-gemini-cli" || provider === "google-antigravity";
+  return needsProjectId
+    ? JSON.stringify({
+        token: credentials.access,
+        projectId: credentials.projectId,
+      })
+    : credentials.access;
+}
+
+async function refreshOAuthTokenWithLock(params: {
+  profileId: string;
+  provider: OAuthProvider;
+}): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
+  const authPath = resolveAuthStorePath();
+  ensureAuthStoreFile(authPath);
+
+  let release: (() => Promise<void>) | undefined;
+  try {
+    release = await lockfile.lock(authPath, {
+      retries: {
+        retries: 10,
+        factor: 2,
+        minTimeout: 100,
+        maxTimeout: 10_000,
+        randomize: true,
+      },
+      stale: 30_000,
+    });
+
+    const store = ensureAuthProfileStore();
+    const cred = store.profiles[params.profileId];
+    if (!cred || cred.type !== "oauth") return null;
+
+    if (Date.now() < cred.expires) {
+      return {
+        apiKey: buildOAuthApiKey(cred.provider, cred),
+        newCredentials: cred,
+      };
+    }
+
+    const oauthCreds: Record<string, OAuthCredentials> = {
+      [cred.provider]: cred,
+    };
+    const result = await getOAuthApiKey(cred.provider, oauthCreds);
+    if (!result) return null;
+    store.profiles[params.profileId] = {
+      ...cred,
+      ...result.newCredentials,
+      type: "oauth",
+    };
+    saveAuthProfileStore(store);
+    return result;
+  } finally {
+    if (release) {
+      try {
+        await release();
+      } catch {
+        // ignore unlock errors
+      }
+    }
+  }
+}
+
 function coerceLegacyStore(raw: unknown): LegacyAuthStore | null {
   if (!raw || typeof raw !== "object") return null;
   const record = raw as Record<string, unknown>;
@@ -323,23 +401,41 @@ export async function resolveApiKeyForProfile(params: {
   if (cred.type === "api_key") {
     return { apiKey: cred.key, provider: cred.provider, email: cred.email };
   }
+  if (Date.now() < cred.expires) {
+    return {
+      apiKey: buildOAuthApiKey(cred.provider, cred),
+      provider: cred.provider,
+      email: cred.email,
+    };
+  }
 
-  const oauthCreds: Record<string, OAuthCredentials> = {
-    [cred.provider]: cred,
-  };
-  const result = await getOAuthApiKey(cred.provider, oauthCreds);
-  if (!result) return null;
-  store.profiles[profileId] = {
-    ...cred,
-    ...result.newCredentials,
-    type: "oauth",
-  };
-  saveAuthProfileStore(store);
-  return {
-    apiKey: result.apiKey,
-    provider: cred.provider,
-    email: cred.email,
-  };
+  try {
+    const result = await refreshOAuthTokenWithLock({
+      profileId,
+      provider: cred.provider,
+    });
+    if (!result) return null;
+    return {
+      apiKey: result.apiKey,
+      provider: cred.provider,
+      email: cred.email,
+    };
+  } catch (error) {
+    const refreshedStore = ensureAuthProfileStore();
+    const refreshed = refreshedStore.profiles[profileId];
+    if (refreshed?.type === "oauth" && Date.now() < refreshed.expires) {
+      return {
+        apiKey: buildOAuthApiKey(refreshed.provider, refreshed),
+        provider: refreshed.provider,
+        email: refreshed.email ?? cred.email,
+      };
+    }
+    const message = error instanceof Error ? error.message : String(error);
+    throw new Error(
+      `OAuth token refresh failed for ${cred.provider}: ${message}. ` +
+        "Please try again or re-authenticate.",
+    );
+  }
 }
 
 export function markAuthProfileGood(params: {

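Note: the refresh path above serializes concurrent token refreshes through proper-lockfile. The following is a minimal standalone sketch of that locking pattern only; the target path and the withStoreLock helper are illustrative and not part of the commit.

import fs from "node:fs";
import lockfile from "proper-lockfile";

// Illustrative file; proper-lockfile locks an existing file on disk.
const target = "/tmp/example-store.json";

async function withStoreLock<T>(fn: () => Promise<T>): Promise<T> {
  if (!fs.existsSync(target)) fs.writeFileSync(target, "{}");
  // Same option shape as the commit: exponential backoff on contention plus
  // a stale timeout so a crashed process cannot hold the lock forever.
  const release = await lockfile.lock(target, {
    retries: { retries: 10, factor: 2, minTimeout: 100, maxTimeout: 10_000, randomize: true },
    stale: 30_000,
  });
  try {
    return await fn();
  } finally {
    await release().catch(() => {
      // Ignore unlock errors, mirroring the commit's finally block.
    });
  }
}
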
@@ -113,6 +113,7 @@ export type EmbeddedPiCompactResult = {
 type EmbeddedPiQueueHandle = {
   queueMessage: (text: string) => Promise<void>;
   isStreaming: () => boolean;
+  isCompacting: () => boolean;
   abort: () => void;
 };
 
@@ -212,6 +213,7 @@ export function queueEmbeddedPiMessage(
   const handle = ACTIVE_EMBEDDED_RUNS.get(sessionId);
   if (!handle) return false;
   if (!handle.isStreaming()) return false;
+  if (handle.isCompacting()) return false;
   void handle.queueMessage(text);
   return true;
 }
@@ -810,21 +812,7 @@ export async function runEmbeddedPiAgent(params: {
     aborted = true;
     void session.abort();
   };
-  const queueHandle: EmbeddedPiQueueHandle = {
-    queueMessage: async (text: string) => {
-      await session.steer(text);
-    },
-    isStreaming: () => session.isStreaming,
-    abort: abortRun,
-  };
-  ACTIVE_EMBEDDED_RUNS.set(params.sessionId, queueHandle);
-
-  const {
-    assistantTexts,
-    toolMetas,
-    unsubscribe,
-    waitForCompactionRetry,
-  } = subscribeEmbeddedPiSession({
+  const subscription = subscribeEmbeddedPiSession({
     session,
     runId: params.runId,
     verboseLevel: params.verboseLevel,
@@ -837,6 +825,22 @@ export async function runEmbeddedPiAgent(params: {
     onAgentEvent: params.onAgentEvent,
     enforceFinalTag: params.enforceFinalTag,
   });
+  const {
+    assistantTexts,
+    toolMetas,
+    unsubscribe,
+    waitForCompactionRetry,
+  } = subscription;
+
+  const queueHandle: EmbeddedPiQueueHandle = {
+    queueMessage: async (text: string) => {
+      await session.steer(text);
+    },
+    isStreaming: () => session.isStreaming,
+    isCompacting: () => subscription.isCompacting(),
+    abort: abortRun,
+  };
+  ACTIVE_EMBEDDED_RUNS.set(params.sessionId, queueHandle);
+
   let abortWarnTimer: NodeJS.Timeout | undefined;
   const abortTimer = setTimeout(

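Note: with the isCompacting accessor in place, queueEmbeddedPiMessage (second hunk above) now refuses to queue while the session is compacting. Below is a standalone model of that gating with a hypothetical caller-side buffer for refused messages; the buffer and the simplified names are illustrative, not part of the commit.

type QueueHandle = {
  queueMessage: (text: string) => Promise<void>;
  isStreaming: () => boolean;
  isCompacting: () => boolean;
};

const runs = new Map<string, QueueHandle>();
const deferred: string[] = [];

// Only queue into a run that is actively streaming and not compacting;
// otherwise hold the message so the caller can retry later.
function tryQueue(sessionId: string, text: string): boolean {
  const handle = runs.get(sessionId);
  if (!handle || !handle.isStreaming() || handle.isCompacting()) {
    deferred.push(text);
    return false;
  }
  void handle.queueMessage(text);
  return true;
}
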
@@ -968,6 +968,7 @@ describe("subscribeEmbeddedPiSession", () => {
       });
     }
 
+    expect(subscription.isCompacting()).toBe(true);
     expect(subscription.assistantTexts.length).toBe(0);
 
     let resolved = false;
@@ -1004,6 +1005,8 @@ describe("subscribeEmbeddedPiSession", () => {
       listener({ type: "auto_compaction_start" });
     }
 
+    expect(subscription.isCompacting()).toBe(true);
+
     let resolved = false;
     const waitPromise = subscription.waitForCompactionRetry().then(() => {
       resolved = true;
@@ -1018,6 +1021,7 @@ describe("subscribeEmbeddedPiSession", () => {
 
     await waitPromise;
     expect(resolved).toBe(true);
+    expect(subscription.isCompacting()).toBe(false);
   });
 
   it("waits for multiple compaction retries before resolving", async () => {

@@ -604,6 +604,7 @@ export function subscribeEmbeddedPiSession(params: {
     assistantTexts,
     toolMetas,
     unsubscribe,
+    isCompacting: () => compactionInFlight || pendingCompactionRetry > 0,
     waitForCompactionRetry: () => {
       if (compactionInFlight || pendingCompactionRetry > 0) {
         ensureCompactionPromise();
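Note: the tests above drive isCompacting() through auto_compaction_start events and compaction retries. A rough standalone sketch of the kind of bookkeeping that can back isCompacting() and waitForCompactionRetry() follows; the event handlers and promise wiring here are assumptions for illustration, not the module's actual implementation.

// Compaction counts as active while a compaction runs or a retry is still owed.
let compactionInFlight = false;
let pendingCompactionRetry = 0;
let waiters: Array<() => void> = [];

const isCompacting = () => compactionInFlight || pendingCompactionRetry > 0;

function onAutoCompactionStart() {
  // Assumed event handler: mark compaction active and owe one retry.
  compactionInFlight = true;
  pendingCompactionRetry += 1;
}

function onCompactionRetryFinished() {
  // Assumed event handler: settle one retry and wake waiters once idle.
  compactionInFlight = false;
  pendingCompactionRetry = Math.max(0, pendingCompactionRetry - 1);
  if (!isCompacting()) {
    for (const resolve of waiters) resolve();
    waiters = [];
  }
}

function waitForCompactionRetry(): Promise<void> {
  if (!isCompacting()) return Promise.resolve();
  return new Promise<void>((resolve) => waiters.push(resolve));
}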