Files
IOPaint/web_app/src/lib/states.ts
2023-12-05 12:40:04 +08:00

563 lines
14 KiB
TypeScript

import { create } from "zustand"
import { persist } from "zustand/middleware"
import { shallow } from "zustand/shallow"
import { immer } from "zustand/middleware/immer"
import { createWithEqualityFn } from "zustand/traditional"
import {
CV2Flag,
FreeuConfig,
LDMSampler,
Line,
LineGroup,
ModelInfo,
Point,
SDSampler,
Size,
SortBy,
SortOrder,
} from "./types"
import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const"
// State of the file manager: sort settings, layout, search text and the
// input/output directories. This slice is one of the two kept by the store's
// persist `partialize` filter ("fileManagerState" and "settings").
type FileManagerState = {
sortBy: SortBy
sortOrder: SortOrder
// One of the two supported layout names.
layout: "rows" | "masonry"
searchText: string
inputDirectory: string
outputDirectory: string
}
// Position (x, y) and size (width, height) of the cropper rectangle.
// Each field has its own setter on the store (setCropperX, setCropperY,
// setCropperWidth, setCropperHeight).
// NOTE(review): units are presumably image pixels — confirm against the cropper component.
type CropperState = {
x: number
y: number
width: number
height: number
}
// User-adjustable settings, grouped by the model family they apply to.
// This slice is one of the two kept by the store's persist `partialize`
// filter ("fileManagerState" and "settings").
export type Settings = {
// Currently selected model; `model.model_type` drives runMannually(),
// showPromptInput() and showSidePanel() in the store.
model: ModelInfo
enableDownloadMask: boolean
// When true, runMannually() in the store returns true regardless of model type.
enableManualInpainting: boolean
enableUploadMask: boolean
showCroper: boolean
// For LDM
ldmSteps: number
ldmSampler: LDMSampler
// For ZITS
zitsWireframe: boolean
// For OpenCV2
cv2Radius: number
cv2Flag: CV2Flag
// For diffusion models
prompt: string
negativePrompt: string
// Written by the store's setSeed action.
seed: number
seedFixed: boolean
// For SD
sdMaskBlur: number
sdStrength: number
sdSteps: number
sdGuidanceScale: number
sdSampler: SDSampler
sdMatchHistograms: boolean
sdScale: number
// Paint by Example
paintByExampleSteps: number
paintByExampleGuidanceScale: number
paintByExampleMaskBlur: number
paintByExampleMatchHistograms: boolean
// InstructPix2Pix
p2pSteps: number
p2pImageGuidanceScale: number
p2pGuidanceScale: number
// ControlNet
controlnetConditioningScale: number
controlnetMethod: string
enableLCMLora: boolean
enableFreeu: boolean
freeuConfig: FreeuConfig
}
// Server-side configuration stored via setServerConfig. It is not included in
// the persisted subset of the store.
type ServerConfig = {
plugins: string[]
enableFileManager: boolean
enableAutoSaving: boolean
}
// State of the interactive segmentation tool. Updated through
// updateInteractiveSegState and restored to its defaults by
// resetInteractiveSegState.
type InteractiveSegState = {
isInteractiveSeg: boolean
isInteractiveSegRunning: boolean
// NOTE(review): the layout of each inner number[] is not visible in this file —
// confirm against the code that records clicks.
clicks: number[][]
}
// Drawing/editing state: brush sizing, render results, drawn line groups, and
// the stacks used by undo/redo.
type EditorState = {
// getBrushSize() returns baseBrushSize * brushSizeScale.
baseBrushSize: number
// Recomputed by setImageSize as max(min(width, height), 512) / 512.
brushSizeScale: number
// Stack of render results; undo pops from it, redo pushes back onto it.
renders: HTMLImageElement[]
// Stack of line groups popped/pushed together with `renders` by undo/redo.
lineGroups: LineGroup[]
lastLineGroup: LineGroup
// Strokes of the group currently being drawn.
curLineGroup: LineGroup
// Redo-related state
redoRenders: HTMLImageElement[]
redoCurLines: Line[]
redoLineGroups: LineGroup[]
}
// Data portion of the store (the actions live in AppAction).
type AppState = {
// Set by setFile; null in the default state.
file: File | null
// Set by setCustomFile; null in the default state.
customMask: File | null
imageHeight: number
imageWidth: number
isInpainting: boolean
isPluginRunning: boolean
windowSize: Size
editorState: EditorState
interactiveSegState: InteractiveSegState
fileManagerState: FileManagerState
cropperState: CropperState
serverConfig: ServerConfig
settings: Settings
}
// Actions and derived getters exposed by the store alongside AppState.
type AppAction = {
// Applies the given top-level AppState fields to the store.
updateAppState: (newState: Partial<AppState>) => void
setFile: (file: File) => void
// Stores the file as `customMask`.
setCustomFile: (file: File) => void
setIsInpainting: (newValue: boolean) => void
setIsPluginRunning: (newValue: boolean) => void
setBaseBrushSize: (newValue: number) => void
// Returns editorState.baseBrushSize * editorState.brushSizeScale.
getBrushSize: () => number
// Stores the image size and recomputes editorState.brushSizeScale.
setImageSize: (width: number, height: number) => void
setCropperX: (newValue: number) => void
setCropperY: (newValue: number) => void
setCropperWidth: (newValue: number) => void
setCropperHeight: (newValue: number) => void
setServerConfig: (newValue: ServerConfig) => void
setSeed: (newValue: number) => void
// The three update* actions below shallow-merge the given fields into their slice.
updateSettings: (newSettings: Partial<Settings>) => void
updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
// Restores interactiveSegState to its default value.
resetInteractiveSegState: () => void
// Both return true when the model type is "diffusers_sd" or "diffusers_sd_inpaint".
showPromptInput: () => boolean
showSidePanel: () => boolean
// EditorState
updateEditorState: (newState: Partial<EditorState>) => void
// True when settings.enableManualInpainting is set or the model type is not
// MODEL_TYPE_INPAINT.
runMannually: () => boolean
// Starts a new stroke at `point` using the current brush size.
handleCanvasMouseDown: (point: Point) => void
// Appends `point` to the most recent stroke, if there is one.
handleCanvasMouseMove: (point: Point) => void
// Empties editorState.curLineGroup.
cleanCurLineGroup: () => void
// Empties the three redo stacks.
resetRedoState: () => void
undo: () => void
redo: () => void
undoDisabled: () => boolean
redoDisabled: () => boolean
}
// Initial value of every AppState field. The store spreads this object into
// its initial state, and resetInteractiveSegState reads
// `defaultValues.interactiveSegState` to restore that slice.
const defaultValues: AppState = {
file: null,
customMask: null,
imageHeight: 0,
imageWidth: 0,
isInpainting: false,
isPluginRunning: false,
windowSize: {
height: 600,
width: 800,
},
editorState: {
baseBrushSize: DEFAULT_BRUSH_SIZE,
// Scale of 1 until setImageSize recomputes it.
brushSizeScale: 1,
renders: [],
lineGroups: [],
lastLineGroup: [],
curLineGroup: [],
redoRenders: [],
redoCurLines: [],
redoLineGroups: [],
},
interactiveSegState: {
isInteractiveSeg: false,
isInteractiveSegRunning: false,
clicks: [],
},
cropperState: {
x: 0,
y: 0,
width: 512,
height: 512,
},
fileManagerState: {
sortBy: SortBy.CTIME,
sortOrder: SortOrder.DESCENDING,
layout: "masonry",
searchText: "",
inputDirectory: "",
outputDirectory: "",
},
serverConfig: {
plugins: [],
enableFileManager: false,
enableAutoSaving: false,
},
settings: {
// Default model.
model: {
name: "lama",
path: "lama",
// NOTE(review): the store's runMannually() compares model_type against
// MODEL_TYPE_INPAINT, while this default is a string literal — confirm
// the constant's value is "inpaint" so the two stay in agreement.
model_type: "inpaint",
support_controlnet: false,
controlnets: [],
support_freeu: false,
support_lcm_lora: false,
is_single_file_diffusers: false,
need_prompt: false,
},
showCroper: false,
enableDownloadMask: false,
enableManualInpainting: false,
enableUploadMask: false,
// LDM
ldmSteps: 30,
ldmSampler: LDMSampler.ddim,
// ZITS
zitsWireframe: true,
// OpenCV2
cv2Radius: 5,
cv2Flag: CV2Flag.INPAINT_NS,
// Diffusion models
prompt: "",
negativePrompt: "",
seed: 42,
seedFixed: false,
// SD
sdMaskBlur: 5,
sdStrength: 1.0,
sdSteps: 50,
sdGuidanceScale: 7.5,
sdSampler: SDSampler.uni_pc,
sdMatchHistograms: false,
sdScale: 100,
// Paint by Example
paintByExampleSteps: 50,
paintByExampleGuidanceScale: 7.5,
paintByExampleMaskBlur: 5,
paintByExampleMatchHistograms: false,
// InstructPix2Pix
p2pSteps: 50,
p2pImageGuidanceScale: 1.5,
p2pGuidanceScale: 7.5,
// ControlNet
controlnetConditioningScale: 0.4,
controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
enableLCMLora: false,
enableFreeu: false,
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
},
}
/**
 * Global application store.
 *
 * Middleware stack, outermost first:
 *  - `createWithEqualityFn(..., shallow)`: selectors are compared with
 *    `shallow` by default.
 *  - `persist`: stores the `fileManagerState` and `settings` slices under the
 *    "ZUSTAND_STATE" key (see `partialize`) and layers them over the defaults
 *    on hydration (see `merge`).
 *  - `immer`: updaters passed to `set` receive a draft they may mutate.
 */
export const useStore = createWithEqualityFn<AppState & AppAction>()(
  persist(
    immer((set, get) => ({
      ...defaultValues,

      // Editor state //

      /** Shallow-merges `newState` into `editorState`. */
      updateEditorState: (newState: Partial<EditorState>) => {
        set((state) => {
          // Mutate the draft like every other updater in this store instead of
          // returning a rebuilt copy of the whole root state from the recipe.
          Object.assign(state.editorState, newState)
        })
      },

      /** Drops all strokes of the group currently being drawn. */
      cleanCurLineGroup: () => {
        get().updateEditorState({ curLineGroup: [] })
      },

      /**
       * Starts a new stroke at `point` with the current brush size.
       * In manual mode the stroke is appended to the existing group; otherwise
       * it starts a fresh group.
       */
      handleCanvasMouseDown: (point: Point) => {
        let lineGroup: LineGroup = []
        const state = get()
        if (state.runMannually()) {
          // Copy so the array held by the current state is not mutated.
          lineGroup = [...state.editorState.curLineGroup]
        }
        lineGroup.push({ size: state.getBrushSize(), pts: [point] })
        set((state) => {
          state.editorState.curLineGroup = lineGroup
        })
      },

      /** Appends `point` to the most recent stroke; no-op when there is none. */
      handleCanvasMouseMove: (point: Point) => {
        set((state) => {
          const curLineGroup = state.editorState.curLineGroup
          if (curLineGroup.length) {
            curLineGroup[curLineGroup.length - 1].pts.push(point)
          }
        })
      },

      /**
       * True when manual inpainting is enabled in the settings or the selected
       * model's type is not MODEL_TYPE_INPAINT.
       */
      runMannually: (): boolean => {
        const state = get()
        return (
          state.settings.enableManualInpainting ||
          state.settings.model.model_type !== MODEL_TYPE_INPAINT
        )
      },

      // undo/redo

      /**
       * Undo is available when there is at least one render, or when running
       * manually with at least one pending stroke.
       */
      undoDisabled: (): boolean => {
        const { renders, curLineGroup } = get().editorState
        if (renders.length > 0) {
          return false
        }
        // No renders left: only pending strokes in manual mode can be undone.
        return !(get().runMannually() && curLineGroup.length > 0)
      },

      /**
       * In manual mode with pending strokes: moves the last stroke to
       * `redoCurLines`. Otherwise: moves the last line group and render to
       * their redo stacks (no-op if either stack is empty).
       */
      undo: () => {
        if (
          get().runMannually() &&
          get().editorState.curLineGroup.length !== 0
        ) {
          // Undo a single stroke.
          set((state) => {
            const editorState = state.editorState
            if (editorState.curLineGroup.length === 0) {
              return
            }
            editorState.lastLineGroup = []
            const lastLine = editorState.curLineGroup.pop()!
            editorState.redoCurLines.push(lastLine)
          })
        } else {
          // Undo a whole render together with its line group.
          set((state) => {
            const editorState = state.editorState
            if (
              editorState.renders.length === 0 ||
              editorState.lineGroups.length === 0
            ) {
              return
            }
            const lastLineGroup = editorState.lineGroups.pop()!
            editorState.redoLineGroups.push(lastLineGroup)
            editorState.redoCurLines = []
            editorState.curLineGroup = []
            const lastRender = editorState.renders.pop()!
            editorState.redoRenders.push(lastRender)
          })
        }
      },

      /**
       * Redo is available when there is at least one undone render, or when
       * running manually with at least one undone stroke.
       */
      redoDisabled: (): boolean => {
        const { redoRenders, redoCurLines } = get().editorState
        if (redoRenders.length > 0) {
          return false
        }
        // No undone renders: only undone strokes in manual mode can be redone.
        return !(get().runMannually() && redoCurLines.length > 0)
      },

      /**
       * In manual mode with undone strokes: restores the last undone stroke.
       * Otherwise: restores the last undone line group and render (no-op if
       * either redo stack is empty).
       */
      redo: () => {
        if (
          get().runMannually() &&
          get().editorState.redoCurLines.length !== 0
        ) {
          set((state) => {
            const editorState = state.editorState
            if (editorState.redoCurLines.length === 0) {
              return
            }
            const line = editorState.redoCurLines.pop()!
            editorState.curLineGroup.push(line)
          })
        } else {
          set((state) => {
            const editorState = state.editorState
            if (
              editorState.redoRenders.length === 0 ||
              editorState.redoLineGroups.length === 0
            ) {
              return
            }
            const lastLineGroup = editorState.redoLineGroups.pop()!
            editorState.lineGroups.push(lastLineGroup)
            editorState.curLineGroup = []
            const lastRender = editorState.redoRenders.pop()!
            editorState.renders.push(lastRender)
          })
        }
      },

      /** Empties all three redo stacks. */
      resetRedoState: () => {
        set((state) => {
          state.editorState.redoCurLines = []
          state.editorState.redoLineGroups = []
          state.editorState.redoRenders = []
        })
      },

      //****//

      /** Applies the given top-level fields to the store. */
      updateAppState: (newState: Partial<AppState>) => {
        set(() => newState)
      },

      /** Effective brush size: base size multiplied by the image-derived scale. */
      getBrushSize: (): number => {
        return (
          get().editorState.baseBrushSize * get().editorState.brushSizeScale
        )
      },

      /** True for the "diffusers_sd" and "diffusers_sd_inpaint" model types. */
      showPromptInput: (): boolean => {
        const model_type = get().settings.model.model_type
        return ["diffusers_sd", "diffusers_sd_inpaint"].includes(model_type)
      },

      /** True for the "diffusers_sd" and "diffusers_sd_inpaint" model types. */
      showSidePanel: (): boolean => {
        const model_type = get().settings.model.model_type
        return ["diffusers_sd", "diffusers_sd_inpaint"].includes(model_type)
      },

      setServerConfig: (newValue: ServerConfig) => {
        set((state) => {
          state.serverConfig = newValue
        })
      },

      /** Shallow-merges `newSettings` into `settings`. */
      updateSettings: (newSettings: Partial<Settings>) => {
        set((state) => {
          state.settings = {
            ...state.settings,
            ...newSettings,
          }
        })
      },

      /** Shallow-merges `newState` into `fileManagerState`. */
      updateFileManagerState: (newState: Partial<FileManagerState>) => {
        set((state) => {
          state.fileManagerState = {
            ...state.fileManagerState,
            ...newState,
          }
        })
      },

      /** Shallow-merges `newState` into `interactiveSegState`. */
      updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => {
        set((state) => {
          state.interactiveSegState = {
            ...state.interactiveSegState,
            ...newState,
          }
        })
      },

      /** Restores `interactiveSegState` to its default value. */
      resetInteractiveSegState: () => {
        set((state) => {
          state.interactiveSegState = defaultValues.interactiveSegState
        })
      },

      setIsInpainting: (newValue: boolean) =>
        set((state) => {
          state.isInpainting = newValue
        }),

      setIsPluginRunning: (newValue: boolean) =>
        set((state) => {
          state.isPluginRunning = newValue
        }),

      setFile: (file: File) =>
        set((state) => {
          // TODO: clear the various pieces of state that belong to the previous file
          state.file = file
        }),

      /** Stores `file` as the custom mask. */
      setCustomFile: (file: File) =>
        set((state) => {
          state.customMask = file
        }),

      setBaseBrushSize: (newValue: number) =>
        set((state) => {
          state.editorState.baseBrushSize = newValue
        }),

      /** Stores the image size and adapts the brush-size scale to it. */
      setImageSize: (width: number, height: number) => {
        set((state) => {
          state.imageWidth = width
          state.imageHeight = height
          // Scale grows with the shorter image side and never drops below 1
          // (sides shorter than 512 are treated as 512).
          state.editorState.brushSizeScale =
            Math.max(Math.min(width, height), 512) / 512
        })
      },

      setCropperX: (newValue: number) =>
        set((state) => {
          state.cropperState.x = newValue
        }),

      setCropperY: (newValue: number) =>
        set((state) => {
          state.cropperState.y = newValue
        }),

      setCropperWidth: (newValue: number) =>
        set((state) => {
          state.cropperState.width = newValue
        }),

      setCropperHeight: (newValue: number) =>
        set((state) => {
          state.cropperState.height = newValue
        }),

      setSeed: (newValue: number) =>
        set((state) => {
          state.settings.seed = newValue
        }),
    })),
    {
      name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
      version: 0,
      // Persist only the file-manager state and the settings.
      partialize: (state) =>
        Object.fromEntries(
          Object.entries(state).filter(([key]) =>
            ["fileManagerState", "settings"].includes(key)
          )
        ),
      // The default persist merge is shallow: a stored `settings` or
      // `fileManagerState` object would replace the defaults wholesale, so any
      // key added after the data was stored would come back as `undefined`.
      // Layer the stored values over the current defaults instead.
      merge: (persistedState, currentState) => {
        const persisted = (
          typeof persistedState === "object" && persistedState !== null
            ? persistedState
            : {}
        ) as Partial<Pick<AppState, "fileManagerState" | "settings">>
        return {
          ...currentState,
          ...persisted,
          fileManagerState: {
            ...currentState.fileManagerState,
            ...persisted.fileManagerState,
          },
          settings: {
            ...currentState.settings,
            ...persisted.settings,
          },
        }
      },
    }
  ),
  shallow
)
// export const useStore = <U>(selector: (state: AppState & AppAction) => U) => {
// return createWithEqualityFn(selector, shallow)
// }
// export const useStore = createWithEqualityFn(useBaseStore, shallow)