diff --git a/src/gateway/server.ts b/src/gateway/server.ts index 06376c5b2..c920d2c03 100644 --- a/src/gateway/server.ts +++ b/src/gateway/server.ts @@ -5578,6 +5578,83 @@ export async function startGatewayServer( } } + if ("model" in p) { + const raw = p.model; + if (raw === null) { + delete next.providerOverride; + delete next.modelOverride; + } else if (raw !== undefined) { + const trimmed = String(raw).trim(); + if (!trimmed) { + respond( + false, + undefined, + errorShape( + ErrorCodes.INVALID_REQUEST, + "invalid model: empty", + ), + ); + break; + } + const resolvedDefault = resolveConfiguredModelRef({ + cfg, + defaultProvider: DEFAULT_PROVIDER, + defaultModel: DEFAULT_MODEL, + }); + const aliasIndex = buildModelAliasIndex({ + cfg, + defaultProvider: resolvedDefault.provider, + }); + const resolved = resolveModelRefFromString({ + raw: trimmed, + defaultProvider: resolvedDefault.provider, + aliasIndex, + }); + if (!resolved) { + respond( + false, + undefined, + errorShape( + ErrorCodes.INVALID_REQUEST, + `invalid model: ${trimmed}`, + ), + ); + break; + } + const catalog = await loadGatewayModelCatalog(); + const allowed = buildAllowedModelSet({ + cfg, + catalog, + defaultProvider: resolvedDefault.provider, + }); + const key = modelKey( + resolved.ref.provider, + resolved.ref.model, + ); + if (!allowed.allowAny && !allowed.allowedKeys.has(key)) { + respond( + false, + undefined, + errorShape( + ErrorCodes.INVALID_REQUEST, + `model not allowed: ${key}`, + ), + ); + break; + } + if ( + resolved.ref.provider === resolvedDefault.provider && + resolved.ref.model === resolvedDefault.model + ) { + delete next.providerOverride; + delete next.modelOverride; + } else { + next.providerOverride = resolved.ref.provider; + next.modelOverride = resolved.ref.model; + } + } + } + if ("groupActivation" in p) { const raw = p.groupActivation; if (raw === null) { diff --git a/src/tui/commands.ts b/src/tui/commands.ts index 2c1665e4f..a33332065 100644 --- a/src/tui/commands.ts +++ b/src/tui/commands.ts @@ -48,9 +48,9 @@ export function getSlashCommands(): SlashCommand[] { name: "activation", description: "Set group activation", getArgumentCompletions: (prefix) => - ACTIVATION_LEVELS.filter((v) => - v.startsWith(prefix.toLowerCase()), - ).map((value) => ({ value, label: value })), + ACTIVATION_LEVELS.filter((v) => v.startsWith(prefix.toLowerCase())).map( + (value) => ({ value, label: value }), + ), }, { name: "deliver", diff --git a/src/tui/components/chat-log.ts b/src/tui/components/chat-log.ts index 700a670a5..8129707f6 100644 --- a/src/tui/components/chat-log.ts +++ b/src/tui/components/chat-log.ts @@ -1,8 +1,8 @@ import { Container, Spacer, Text } from "@mariozechner/pi-tui"; +import { theme } from "../theme/theme.js"; import { AssistantMessageComponent } from "./assistant-message.js"; import { ToolExecutionComponent } from "./tool-execution.js"; import { UserMessageComponent } from "./user-message.js"; -import { theme } from "../theme/theme.js"; export class ChatLog extends Container { private toolById = new Map(); @@ -46,10 +46,7 @@ export class ChatLog extends Container { } finalizeAssistant(text: string, runId?: string) { - if ( - this.streamingAssistant && - (!runId || runId === this.streamingRunId) - ) { + if (this.streamingAssistant && (!runId || runId === this.streamingRunId)) { this.streamingAssistant.setText(text); } else { this.startAssistant(text, runId); diff --git a/src/tui/components/selectors.ts b/src/tui/components/selectors.ts index e5204d4a8..aa613c623 100644 --- a/src/tui/components/selectors.ts +++ b/src/tui/components/selectors.ts @@ -1,8 +1,8 @@ import { - SelectList, type SelectItem, - SettingsList, + SelectList, type SettingItem, + SettingsList, } from "@mariozechner/pi-tui"; import { selectListTheme, settingsListTheme } from "../theme/theme.js"; diff --git a/src/tui/gateway-chat.ts b/src/tui/gateway-chat.ts index 4638aa835..2ae0145c6 100644 --- a/src/tui/gateway-chat.ts +++ b/src/tui/gateway-chat.ts @@ -42,10 +42,18 @@ export type GatewaySessionList = { verboseLevel?: string; model?: string; contextTokens?: number | null; + totalTokens?: number | null; displayName?: string; }>; }; +export type GatewayModelChoice = { + id: string; + name: string; + provider: string; + contextWindow?: number; +}; + export class GatewayChatClient { private client: GatewayClient; private readyPromise: Promise; @@ -161,16 +169,11 @@ export class GatewayChatClient { return await this.client.request("status"); } - async listModels(): Promise< - Array<{ - id: string; - name: string; - provider: string; - contextWindow?: number; - }> - > { - const res = await this.client.request<{ models?: unknown }>("models.list"); - return Array.isArray(res?.models) ? (res.models as Array) : []; + async listModels(): Promise { + const res = await this.client.request<{ models?: GatewayModelChoice[] }>( + "models.list", + ); + return Array.isArray(res?.models) ? res.models : []; } } diff --git a/src/tui/theme/theme.ts b/src/tui/theme/theme.ts index abfce79e3..03eeff80e 100644 --- a/src/tui/theme/theme.ts +++ b/src/tui/theme/theme.ts @@ -1,10 +1,10 @@ -import chalk from "chalk"; import type { EditorTheme, MarkdownTheme, SelectListTheme, SettingsListTheme, } from "@mariozechner/pi-tui"; +import chalk from "chalk"; const palette = { text: "#E8E3D5", diff --git a/src/tui/tui.ts b/src/tui/tui.ts index e83ab8f4c..31801c5cf 100644 --- a/src/tui/tui.ts +++ b/src/tui/tui.ts @@ -1,16 +1,19 @@ import { CombinedAutocompleteProvider, + type Component, Container, ProcessTerminal, Text, TUI, - type Component, } from "@mariozechner/pi-tui"; import { loadConfig } from "../config/config.js"; +import { getSlashCommands, helpText, parseCommand } from "./commands.js"; import { ChatLog } from "./components/chat-log.js"; import { CustomEditor } from "./components/custom-editor.js"; -import { createSelectList, createSettingsList } from "./components/selectors.js"; -import { getSlashCommands, helpText, parseCommand } from "./commands.js"; +import { + createSelectList, + createSettingsList, +} from "./components/selectors.js"; import { GatewayChatClient } from "./gateway-chat.js"; import { editorTheme, theme } from "./theme/theme.js"; @@ -180,7 +183,9 @@ export async function runTui(opts: TuiOptions) { includeGlobal: false, includeUnknown: false, }); - const entry = result.sessions.find((row) => row.key === currentSessionKey); + const entry = result.sessions.find( + (row) => row.key === currentSessionKey, + ); sessionInfo = { thinkingLevel: entry?.thinkingLevel, verboseLevel: entry?.verboseLevel, @@ -210,7 +215,8 @@ export async function runTui(opts: TuiOptions) { }; currentSessionId = typeof record.sessionId === "string" ? record.sessionId : null; - sessionInfo.thinkingLevel = record.thinkingLevel ?? sessionInfo.thinkingLevel; + sessionInfo.thinkingLevel = + record.thinkingLevel ?? sessionInfo.thinkingLevel; chatLog.clearAll(); chatLog.addSystem(`session ${currentSessionKey}`); for (const entry of record.messages ?? []) {