import { create } from "zustand"
import { persist } from "zustand/middleware"
import { shallow } from "zustand/shallow"
import { immer } from "zustand/middleware/immer"
import { createWithEqualityFn } from "zustand/traditional"
import {
  CV2Flag,
  FreeuConfig,
  LDMSampler,
  Line,
  LineGroup,
  ModelInfo,
  Point,
  SDSampler,
  Size,
  SortBy,
  SortOrder,
} from "./types"
import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const"

/** File manager view preferences; persisted across sessions (see `partialize`). */
type FileManagerState = {
  sortBy: SortBy
  sortOrder: SortOrder
  layout: "rows" | "masonry"
  searchText: string
  inputDirectory: string
  outputDirectory: string
}

/** Rectangle tracked for the cropper. */
type CropperState = {
  x: number
  y: number
  width: number
  height: number
}

/** User-tunable settings; persisted across sessions (see `partialize`). */
export type Settings = {
  model: ModelInfo
  enableDownloadMask: boolean
  enableManualInpainting: boolean
  enableUploadMask: boolean
  showCroper: boolean

  // For LDM
  ldmSteps: number
  ldmSampler: LDMSampler

  // For ZITS
  zitsWireframe: boolean

  // For OpenCV2
  cv2Radius: number
  cv2Flag: CV2Flag

  // For diffusion models
  prompt: string
  negativePrompt: string
  seed: number
  seedFixed: boolean

  // For SD
  sdMaskBlur: number
  sdStrength: number
  sdSteps: number
  sdGuidanceScale: number
  sdSampler: SDSampler
  sdMatchHistograms: boolean
  sdScale: number

  // Paint by Example
  paintByExampleSteps: number
  paintByExampleGuidanceScale: number
  paintByExampleMaskBlur: number
  paintByExampleMatchHistograms: boolean

  // InstructPix2Pix
  p2pSteps: number
  p2pImageGuidanceScale: number
  p2pGuidanceScale: number

  // ControlNet
  controlnetConditioningScale: number
  controlnetMethod: string

  enableLCMLora: boolean
  enableFreeu: boolean
  freeuConfig: FreeuConfig
}

type ServerConfig = {
  plugins: string[]
  enableFileManager: boolean
  enableAutoSaving: boolean
}

type InteractiveSegState = {
  isInteractiveSeg: boolean
  isInteractiveSegRunning: boolean
  clicks: number[][]
}

/** Canvas editing state: brush sizing, drawn strokes and the undo/redo stacks. */
type EditorState = {
  baseBrushSize: number
  brushSizeScale: number
  renders: HTMLImageElement[]
  lineGroups: LineGroup[]
  lastLineGroup: LineGroup
  curLineGroup: LineGroup
  // Redo-related state
  redoRenders: HTMLImageElement[]
  redoCurLines: Line[]
  redoLineGroups: LineGroup[]
}

type AppState = {
  file: File | null
  customMask: File | null
  imageHeight: number
  imageWidth: number
  isInpainting: boolean
  isPluginRunning: boolean
  windowSize: Size

  editorState: EditorState
  interactiveSegState: InteractiveSegState
  fileManagerState: FileManagerState
  cropperState: CropperState
  serverConfig: ServerConfig
  settings: Settings
}

type AppAction = {
  updateAppState: (newState: Partial<AppState>) => void
  setFile: (file: File) => void
  setCustomFile: (file: File) => void
  setIsInpainting: (newValue: boolean) => void
  setIsPluginRunning: (newValue: boolean) => void
  setBaseBrushSize: (newValue: number) => void
  getBrushSize: () => number
  setImageSize: (width: number, height: number) => void

  setCropperX: (newValue: number) => void
  setCropperY: (newValue: number) => void
  setCropperWidth: (newValue: number) => void
  setCropperHeight: (newValue: number) => void

  setServerConfig: (newValue: ServerConfig) => void
  setSeed: (newValue: number) => void
  updateSettings: (newSettings: Partial<Settings>) => void
  updateFileManagerState: (newState: Partial<FileManagerState>) => void
  updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
  resetInteractiveSegState: () => void
  showPromptInput: () => boolean
  showSidePanel: () => boolean

  // EditorState
  updateEditorState: (newState: Partial<EditorState>) => void
  runMannually: () => boolean
  handleCanvasMouseDown: (point: Point) => void
  handleCanvasMouseMove: (point: Point) => void
  cleanCurLineGroup: () => void
  resetRedoState: () => void
  undo: () => void
  redo: () => void
  undoDisabled: () => boolean
  redoDisabled: () => boolean
}

const defaultValues: AppState = {
  file: null,
  customMask: null,
  imageHeight: 0,
  imageWidth: 0,
  isInpainting: false,
  isPluginRunning: false,
  windowSize: {
    height: 600,
    width: 800,
  },
  editorState: {
    baseBrushSize: DEFAULT_BRUSH_SIZE,
    brushSizeScale: 1,
    renders: [],
    lineGroups: [],
    lastLineGroup: [],
    curLineGroup: [],
    redoRenders: [],
    redoCurLines: [],
    redoLineGroups: [],
  },

  interactiveSegState: {
    isInteractiveSeg: false,
    isInteractiveSegRunning: false,
    clicks: [],
  },

  cropperState: {
    x: 0,
    y: 0,
    width: 512,
    height: 512,
  },
  fileManagerState: {
    sortBy: SortBy.CTIME,
    sortOrder: SortOrder.DESCENDING,
    layout: "masonry",
    searchText: "",
    inputDirectory: "",
    outputDirectory: "",
  },
  serverConfig: {
    plugins: [],
    enableFileManager: false,
    enableAutoSaving: false,
  },
  settings: {
    model: {
      name: "lama",
      path: "lama",
      model_type: "inpaint",
      support_controlnet: false,
      controlnets: [],
      support_freeu: false,
      support_lcm_lora: false,
      is_single_file_diffusers: false,
      need_prompt: false,
    },
    showCroper: false,
    enableDownloadMask: false,
    enableManualInpainting: false,
    enableUploadMask: false,
    ldmSteps: 30,
    ldmSampler: LDMSampler.ddim,
    zitsWireframe: true,
    cv2Radius: 5,
    cv2Flag: CV2Flag.INPAINT_NS,
    prompt: "",
    negativePrompt: "",
    seed: 42,
    seedFixed: false,
    sdMaskBlur: 5,
    sdStrength: 1.0,
    sdSteps: 50,
    sdGuidanceScale: 7.5,
    sdSampler: SDSampler.uni_pc,
    sdMatchHistograms: false,
    sdScale: 100,
    paintByExampleSteps: 50,
    paintByExampleGuidanceScale: 7.5,
    paintByExampleMaskBlur: 5,
    paintByExampleMatchHistograms: false,
    p2pSteps: 50,
    p2pImageGuidanceScale: 1.5,
    p2pGuidanceScale: 7.5,
    controlnetConditioningScale: 0.4,
    controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
    enableLCMLora: false,
    enableFreeu: false,
    freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
  },
}

/** Model types for which `showPromptInput` / `showSidePanel` return true. */
const DIFFUSERS_MODEL_TYPES: readonly string[] = [
  "diffusers_sd",
  "diffusers_sd_inpaint",
]

/**
 * Global application store (zustand + immer + persist, compared with `shallow`).
 *
 * Only `fileManagerState` and `settings` are written to storage under the key
 * "ZUSTAND_STATE"; on hydration they are merged key-by-key over the in-code
 * defaults so fields added after the state was saved still get a default.
 */
export const useStore = createWithEqualityFn<AppState & AppAction>()(
  persist(
    immer((set, get) => ({
      ...defaultValues,

      // Editor state

      /** Shallow-merges `newState` into `editorState`. */
      updateEditorState: (newState: Partial<EditorState>) => {
        set((state) => {
          state.editorState = {
            ...state.editorState,
            ...newState,
          }
        })
      },

      cleanCurLineGroup: () => {
        get().updateEditorState({ curLineGroup: [] })
      },

      /**
       * Starts a new stroke at `point` sized by `getBrushSize()`. In manual mode
       * the stroke is appended to the current line group; otherwise it starts a
       * fresh line group.
       */
      handleCanvasMouseDown: (point: Point) => {
        let lineGroup: LineGroup = []
        const state = get()
        if (state.runMannually()) {
          lineGroup = [...state.editorState.curLineGroup]
        }
        lineGroup.push({ size: state.getBrushSize(), pts: [point] })
        set((state) => {
          state.editorState.curLineGroup = lineGroup
        })
      },

      /** Appends `point` to the last stroke of the current line group, if any. */
      handleCanvasMouseMove: (point: Point) => {
        set((state) => {
          const curLineGroup = state.editorState.curLineGroup
          if (curLineGroup.length) {
            curLineGroup[curLineGroup.length - 1].pts.push(point)
          }
        })
      },

      /**
       * True when manual inpainting is enabled or the selected model is not of
       * type `MODEL_TYPE_INPAINT`.
       */
      runMannually: (): boolean => {
        const state = get()
        return (
          state.settings.enableManualInpainting ||
          state.settings.model.model_type !== MODEL_TYPE_INPAINT
        )
      },

      // undo/redo

      /**
       * Undo is enabled when there is a render to undo or, in manual mode, a
       * stroke in the current line group.
       */
      undoDisabled: (): boolean => {
        const editorState = get().editorState
        if (editorState.renders.length > 0) {
          return false
        }
        if (get().runMannually()) {
          if (editorState.curLineGroup.length === 0) {
            return true
          }
        } else if (editorState.renders.length === 0) {
          return true
        }
        return false
      },

      /**
       * In manual mode with pending strokes, moves the last stroke of
       * `curLineGroup` onto `redoCurLines`. Otherwise moves the latest line group
       * and render onto their redo stacks and clears the current strokes.
       */
      undo: () => {
        if (
          get().runMannually() &&
          get().editorState.curLineGroup.length !== 0
        ) {
          // Undo a single stroke
          set((state) => {
            const editorState = state.editorState
            if (editorState.curLineGroup.length === 0) {
              return
            }
            editorState.lastLineGroup = []
            const lastLine = editorState.curLineGroup.pop()!
            editorState.redoCurLines.push(lastLine)
          })
        } else {
          // Undo a whole render together with the line group that produced it
          set((state) => {
            const editorState = state.editorState
            if (
              editorState.renders.length === 0 ||
              editorState.lineGroups.length === 0
            ) {
              return
            }
            const lastLineGroup = editorState.lineGroups.pop()!
            editorState.redoLineGroups.push(lastLineGroup)
            editorState.redoCurLines = []
            editorState.curLineGroup = []

            const lastRender = editorState.renders.pop()!
            editorState.redoRenders.push(lastRender)
          })
        }
      },

      /**
       * Redo is enabled when there is a render to redo or, in manual mode, an
       * undone stroke in `redoCurLines`.
       */
      redoDisabled: (): boolean => {
        const editorState = get().editorState
        if (editorState.redoRenders.length > 0) {
          return false
        }
        if (get().runMannually()) {
          if (editorState.redoCurLines.length === 0) {
            return true
          }
        } else if (editorState.redoRenders.length === 0) {
          return true
        }
        return false
      },

      /** Inverse of `undo`: restores a stroke, or a line group plus its render. */
      redo: () => {
        if (
          get().runMannually() &&
          get().editorState.redoCurLines.length !== 0
        ) {
          set((state) => {
            const editorState = state.editorState
            if (editorState.redoCurLines.length === 0) {
              return
            }
            const line = editorState.redoCurLines.pop()!
            editorState.curLineGroup.push(line)
          })
        } else {
          set((state) => {
            const editorState = state.editorState
            if (
              editorState.redoRenders.length === 0 ||
              editorState.redoLineGroups.length === 0
            ) {
              return
            }
            const lastLineGroup = editorState.redoLineGroups.pop()!
            editorState.lineGroups.push(lastLineGroup)
            editorState.curLineGroup = []

            const lastRender = editorState.redoRenders.pop()!
            editorState.renders.push(lastRender)
          })
        }
      },

      /** Empties all three redo stacks. */
      resetRedoState: () => {
        set((state) => {
          state.editorState.redoCurLines = []
          state.editorState.redoLineGroups = []
          state.editorState.redoRenders = []
        })
      },

      // App-level state

      updateAppState: (newState: Partial<AppState>) => {
        set(() => newState)
      },

      getBrushSize: (): number => {
        const { baseBrushSize, brushSizeScale } = get().editorState
        return baseBrushSize * brushSizeScale
      },

      showPromptInput: (): boolean => {
        const model_type = get().settings.model.model_type
        return DIFFUSERS_MODEL_TYPES.includes(model_type)
      },

      showSidePanel: (): boolean => {
        const model_type = get().settings.model.model_type
        return DIFFUSERS_MODEL_TYPES.includes(model_type)
      },

      setServerConfig: (newValue: ServerConfig) => {
        set((state) => {
          state.serverConfig = newValue
        })
      },

      updateSettings: (newSettings: Partial<Settings>) => {
        set((state) => {
          state.settings = {
            ...state.settings,
            ...newSettings,
          }
        })
      },

      updateFileManagerState: (newState: Partial<FileManagerState>) => {
        set((state) => {
          state.fileManagerState = {
            ...state.fileManagerState,
            ...newState,
          }
        })
      },

      updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => {
        set((state) => {
          state.interactiveSegState = {
            ...state.interactiveSegState,
            ...newState,
          }
        })
      },

      /** Restores the interactive-segmentation defaults (fresh copy, no aliasing). */
      resetInteractiveSegState: () => {
        set((state) => {
          state.interactiveSegState = {
            ...defaultValues.interactiveSegState,
            clicks: [],
          }
        })
      },

      setIsInpainting: (newValue: boolean) =>
        set((state) => {
          state.isInpainting = newValue
        }),

      setIsPluginRunning: (newValue: boolean) =>
        set((state) => {
          state.isPluginRunning = newValue
        }),

      setFile: (file: File) =>
        set((state) => {
          // TODO: reset the various per-image states
          state.file = file
        }),

      setCustomFile: (file: File) =>
        set((state) => {
          state.customMask = file
        }),

      setBaseBrushSize: (newValue: number) =>
        set((state) => {
          state.editorState.baseBrushSize = newValue
        }),

      /**
       * Records the image size and scales the brush relative to it: images whose
       * shorter side is at most 512px get scale 1, larger ones scale linearly.
       */
      setImageSize: (width: number, height: number) => {
        set((state) => {
          state.imageWidth = width
          state.imageHeight = height
          state.editorState.brushSizeScale =
            Math.max(Math.min(width, height), 512) / 512
        })
      },

      setCropperX: (newValue: number) =>
        set((state) => {
          state.cropperState.x = newValue
        }),

      setCropperY: (newValue: number) =>
        set((state) => {
          state.cropperState.y = newValue
        }),

      setCropperWidth: (newValue: number) =>
        set((state) => {
          state.cropperState.width = newValue
        }),

      setCropperHeight: (newValue: number) =>
        set((state) => {
          state.cropperState.height = newValue
        }),

      setSeed: (newValue: number) =>
        set((state) => {
          state.settings.seed = newValue
        }),
    })),
    {
      name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
      version: 0,
      // Persist only these two slices.
      partialize: (state) => ({
        fileManagerState: state.fileManagerState,
        settings: state.settings,
      }),
      // The default merge is shallow: a stored `settings` object would replace
      // the defaults wholesale, leaving fields added since it was saved
      // undefined. Merge the persisted slices key-by-key over the defaults.
      merge: (persistedState, currentState) => {
        const persisted = (persistedState ?? {}) as Partial<AppState>
        return {
          ...currentState,
          ...persisted,
          fileManagerState: {
            ...currentState.fileManagerState,
            ...persisted.fileManagerState,
          },
          settings: {
            ...currentState.settings,
            ...persisted.settings,
          },
        }
      },
    }
  ),
  shallow
)

// export const useStore = <U>(selector: (state: AppState & AppAction) => U) => {
//   return createWithEqualityFn(selector, shallow)
// }

// export const useStore = createWithEqualityFn(useBaseStore, shallow)