add remove bg model selection

This commit is contained in:
Qing
2024-02-08 16:49:54 +08:00
parent cf9ceea4e6
commit 8060e16c70
19 changed files with 915 additions and 222 deletions

View File

@@ -20,8 +20,8 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"
import { useEffect, useState } from "react"
import { cn } from "@/lib/utils"
import { useQuery } from "@tanstack/react-query"
import { fetchModelInfos, switchModel } from "@/lib/api"
import { ModelInfo } from "@/lib/types"
import { getServerConfig, switchModel, switchPluginModel } from "@/lib/api"
import { ModelInfo, PluginName } from "@/lib/types"
import { useStore } from "@/lib/states"
import { ScrollArea } from "./ui/scroll-area"
import { useToast } from "./ui/use-toast"
@@ -39,6 +39,14 @@ import {
MODEL_TYPE_OTHER,
} from "@/lib/const"
import useHotKey from "@/hooks/useHotkey"
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectTrigger,
SelectValue,
} from "./ui/select"
const formSchema = z.object({
enableFileManager: z.boolean(),
@@ -48,42 +56,45 @@ const formSchema = z.object({
enableManualInpainting: z.boolean(),
enableUploadMask: z.boolean(),
enableAutoExtractPrompt: z.boolean(),
removeBGModel: z.string(),
})
const TAB_GENERAL = "General"
const TAB_MODEL = "Model"
const TAB_PLUGINS = "Plugins"
// const TAB_FILE_MANAGER = "File Manager"
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL]
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL, TAB_PLUGINS]
export function SettingsDialog() {
const [open, toggleOpen] = useToggle(false)
const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false)
const [tab, setTab] = useState(TAB_MODEL)
const [
updateAppState,
settings,
updateSettings,
fileManagerState,
updateFileManagerState,
setAppModel,
setServerConfig,
] = useStore((state) => [
state.updateAppState,
state.settings,
state.updateSettings,
state.fileManagerState,
state.updateFileManagerState,
state.setModel,
state.setServerConfig,
])
const { toast } = useToast()
const [model, setModel] = useState<ModelInfo>(settings.model)
const [modelSwitchingTexts, setModelSwitchingTexts] = useState<string[]>([])
const openModelSwitching = modelSwitchingTexts.length > 0
useEffect(() => {
setModel(settings.model)
}, [settings.model])
const { data: modelInfos, status } = useQuery({
queryKey: ["modelInfos"],
queryFn: fetchModelInfos,
const { data: serverConfig, status } = useQuery({
queryKey: ["serverConfig"],
queryFn: getServerConfig,
})
// 1. Define your form.
@@ -96,9 +107,17 @@ export function SettingsDialog() {
enableAutoExtractPrompt: settings.enableAutoExtractPrompt,
inputDirectory: fileManagerState.inputDirectory,
outputDirectory: fileManagerState.outputDirectory,
removeBGModel: serverConfig?.removeBGModel,
},
})
useEffect(() => {
if (serverConfig) {
setServerConfig(serverConfig)
form.setValue("removeBGModel", serverConfig.removeBGModel)
}
}, [form, serverConfig])
async function onSubmit(values: z.infer<typeof formSchema>) {
// Do something with the form values. ✅ This will be type-safe and validated.
updateSettings({
@@ -109,29 +128,67 @@ export function SettingsDialog() {
})
// TODO: validate input/output Directory
updateFileManagerState({
inputDirectory: values.inputDirectory,
outputDirectory: values.outputDirectory,
})
if (model.name !== settings.model.name) {
toggleOpenModelSwitching()
updateAppState({ disableShortCuts: true })
try {
const newModel = await switchModel(model.name)
toast({
title: `Switch to ${newModel.name} success`,
})
setAppModel(model)
} catch (error: any) {
toast({
variant: "destructive",
title: `Switch to ${model.name} failed: ${error}`,
})
setModel(settings.model)
} finally {
toggleOpenModelSwitching()
updateAppState({ disableShortCuts: false })
// updateFileManagerState({
// inputDirectory: values.inputDirectory,
// outputDirectory: values.outputDirectory,
// })
const shouldSwitchModel = model.name !== settings.model.name
const shouldSwitchRemoveBGModel =
serverConfig?.removeBGModel !== values.removeBGModel
const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel
if (showModelSwitching) {
const newModelSwitchingTexts: string[] = []
if (shouldSwitchModel) {
newModelSwitchingTexts.push(
`Switching model from ${settings.model.name} to ${model.name}`
)
}
if (shouldSwitchRemoveBGModel) {
newModelSwitchingTexts.push(
`Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
)
}
setModelSwitchingTexts(newModelSwitchingTexts)
updateAppState({ disableShortCuts: true })
if (shouldSwitchModel) {
try {
const newModel = await switchModel(model.name)
toast({
title: `Switch to ${newModel.name} success`,
})
setAppModel(model)
} catch (error: any) {
toast({
variant: "destructive",
title: `Switch to ${model.name} failed: ${error}`,
})
setModel(settings.model)
}
}
if (shouldSwitchRemoveBGModel) {
try {
const res = await switchPluginModel(
PluginName.RemoveBG,
values.removeBGModel
)
if (res.status !== 200) {
throw new Error(res.statusText)
}
} catch (error: any) {
toast({
variant: "destructive",
title: `Switch removebg model to ${model.name} failed: ${error}`,
})
}
}
setModelSwitchingTexts([])
updateAppState({ disableShortCuts: false })
}
}
@@ -143,7 +200,17 @@ export function SettingsDialog() {
onSubmit(form.getValues())
}
},
[open, form, model]
[open, form, model, serverConfig]
)
if (status !== "success") {
return <></>
}
const modelInfos = serverConfig.modelInfos
const plugins = serverConfig.plugins
const removeBGEnabled = plugins.some(
(plugin) => plugin.name === PluginName.RemoveBG
)
function onOpenChange(value: boolean) {
@@ -186,10 +253,6 @@ export function SettingsDialog() {
}
function renderModelSettings() {
if (status !== "success") {
return <></>
}
let defaultTab = MODEL_TYPE_INPAINT
for (let info of modelInfos) {
if (model.name === info.name) {
@@ -356,6 +419,44 @@ export function SettingsDialog() {
)
}
function renderPluginsSettings() {
return (
<div className="space-y-4 w-[510px]">
<FormField
control={form.control}
name="removeBGModel"
render={({ field }) => (
<FormItem className="flex items-center justify-between">
<div className="space-y-0.5">
<FormLabel>Remove Background</FormLabel>
<FormDescription>Remove background model</FormDescription>
</div>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
disabled={!removeBGEnabled}
>
<FormControl>
<SelectTrigger className="w-[200px]">
<SelectValue placeholder="Select removebg model" />
</SelectTrigger>
</FormControl>
<SelectContent align="end">
<SelectGroup>
{serverConfig?.removeBGModels.map((model) => (
<SelectItem key={model} value={model}>
{model}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormItem>
)}
/>
</div>
)
}
// function renderFileManagerSettings() {
// return (
// <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
@@ -446,7 +547,9 @@ export function SettingsDialog() {
<span className="sr-only">Loading...</span>
</div>
<div>Switching to {model.name}</div>
{modelSwitchingTexts.map((text, index) => (
<div key={index}>{text}</div>
))}
</div>
{/* </AlertDialogDescription> */}
</AlertDialogHeader>
@@ -473,6 +576,7 @@ export function SettingsDialog() {
<Button
key={item}
variant="ghost"
disabled={item === TAB_PLUGINS && !removeBGEnabled}
onClick={() => setTab(item)}
className={cn(
tab === item ? "bg-muted " : "hover:bg-muted",
@@ -489,6 +593,7 @@ export function SettingsDialog() {
<form onSubmit={form.handleSubmit(onSubmit)}>
{tab === TAB_MODEL ? renderModelSettings() : <></>}
{tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
{tab === TAB_PLUGINS ? renderPluginsSettings() : <></>}
{/* {tab === TAB_FILE_MANAGER ? (
renderFileManagerSettings()
) : (

View File

@@ -12,7 +12,7 @@ export default function useInputImage() {
fetch(`${API_ENDPOINT}/inputimage`, { headers })
.then(async (res) => {
if (!res.ok) {
throw new Error("No input image found")
return
}
const filename = res.headers
.get("content-disposition")

View File

@@ -104,15 +104,18 @@ export async function switchModel(name: string): Promise<ModelInfo> {
return res.data
}
export async function switchPluginModel(
plugin_name: string,
model_name: string
) {
return api.post(`/switch_plugin_model`, { plugin_name, model_name })
}
export async function currentModel(): Promise<ModelInfo> {
const res = await api.get("/model")
return res.data
}
export function fetchModelInfos(): Promise<ModelInfo[]> {
return api.get("/models").then((response) => response.data)
}
export async function runPlugin(
genMask: boolean,
name: string,

View File

@@ -14,6 +14,9 @@ export interface PluginInfo {
export interface ServerConfig {
plugins: PluginInfo[]
modelInfos: ModelInfo[]
removeBGModel: string
removeBGModels: string[]
enableFileManager: boolean
enableAutoSaving: boolean
enableControlnet: boolean