|
|
|
|
@@ -4,6 +4,7 @@ import { immer } from "zustand/middleware/immer"
|
|
|
|
|
import { castDraft } from "immer"
|
|
|
|
|
import { createWithEqualityFn } from "zustand/traditional"
|
|
|
|
|
import {
|
|
|
|
|
AdjustMaskOperate,
|
|
|
|
|
CV2Flag,
|
|
|
|
|
ExtenderDirection,
|
|
|
|
|
FreeuConfig,
|
|
|
|
|
@@ -27,13 +28,14 @@ import {
|
|
|
|
|
PAINT_BY_EXAMPLE,
|
|
|
|
|
} from "./const"
|
|
|
|
|
import {
|
|
|
|
|
blobToImage,
|
|
|
|
|
canvasToImage,
|
|
|
|
|
dataURItoBlob,
|
|
|
|
|
generateMask,
|
|
|
|
|
loadImage,
|
|
|
|
|
srcToFile,
|
|
|
|
|
} from "./utils"
|
|
|
|
|
import inpaint, { getGenInfo, runPlugin } from "./api"
|
|
|
|
|
import inpaint, { getGenInfo, postAdjustMask, runPlugin } from "./api"
|
|
|
|
|
import { toast } from "@/components/ui/use-toast"
|
|
|
|
|
|
|
|
|
|
type FileManagerState = {
|
|
|
|
|
@@ -102,13 +104,14 @@ export type Settings = {
|
|
|
|
|
|
|
|
|
|
// PowerPaint
|
|
|
|
|
powerpaintTask: PowerPaintTask
|
|
|
|
|
|
|
|
|
|
// AdjustMask
|
|
|
|
|
adjustMaskKernelSize: number
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type InteractiveSegState = {
|
|
|
|
|
isInteractiveSeg: boolean
|
|
|
|
|
interactiveSegMask: HTMLImageElement | null
|
|
|
|
|
tmpInteractiveSegMask: HTMLImageElement | null
|
|
|
|
|
prevInteractiveSegMask: HTMLImageElement | null
|
|
|
|
|
clicks: number[][]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -119,8 +122,12 @@ type EditorState = {
|
|
|
|
|
lineGroups: LineGroup[]
|
|
|
|
|
lastLineGroup: LineGroup
|
|
|
|
|
curLineGroup: LineGroup
|
|
|
|
|
// 只用来显示
|
|
|
|
|
|
|
|
|
|
// mask from interactive-seg or other segmentation models
|
|
|
|
|
extraMasks: HTMLImageElement[]
|
|
|
|
|
prevExtraMasks: HTMLImageElement[]
|
|
|
|
|
|
|
|
|
|
temporaryMasks: HTMLImageElement[]
|
|
|
|
|
// redo 相关
|
|
|
|
|
redoRenders: HTMLImageElement[]
|
|
|
|
|
redoCurLines: Line[]
|
|
|
|
|
@@ -135,6 +142,7 @@ type AppState = {
|
|
|
|
|
imageWidth: number
|
|
|
|
|
isInpainting: boolean
|
|
|
|
|
isPluginRunning: boolean
|
|
|
|
|
isAdjustingMask: boolean
|
|
|
|
|
windowSize: Size
|
|
|
|
|
editorState: EditorState
|
|
|
|
|
disableShortCuts: boolean
|
|
|
|
|
@@ -209,6 +217,9 @@ type AppAction = {
|
|
|
|
|
redo: () => void
|
|
|
|
|
undoDisabled: () => boolean
|
|
|
|
|
redoDisabled: () => boolean
|
|
|
|
|
|
|
|
|
|
adjustMask: (operate: AdjustMaskOperate) => Promise<void>
|
|
|
|
|
clearMask: () => void
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const defaultValues: AppState = {
|
|
|
|
|
@@ -219,6 +230,7 @@ const defaultValues: AppState = {
|
|
|
|
|
imageWidth: 0,
|
|
|
|
|
isInpainting: false,
|
|
|
|
|
isPluginRunning: false,
|
|
|
|
|
isAdjustingMask: false,
|
|
|
|
|
disableShortCuts: false,
|
|
|
|
|
|
|
|
|
|
windowSize: {
|
|
|
|
|
@@ -230,6 +242,8 @@ const defaultValues: AppState = {
|
|
|
|
|
brushSizeScale: 1,
|
|
|
|
|
renders: [],
|
|
|
|
|
extraMasks: [],
|
|
|
|
|
prevExtraMasks: [],
|
|
|
|
|
temporaryMasks: [],
|
|
|
|
|
lineGroups: [],
|
|
|
|
|
lastLineGroup: [],
|
|
|
|
|
curLineGroup: [],
|
|
|
|
|
@@ -240,9 +254,7 @@ const defaultValues: AppState = {
|
|
|
|
|
|
|
|
|
|
interactiveSegState: {
|
|
|
|
|
isInteractiveSeg: false,
|
|
|
|
|
interactiveSegMask: null,
|
|
|
|
|
tmpInteractiveSegMask: null,
|
|
|
|
|
prevInteractiveSegMask: null,
|
|
|
|
|
clicks: [],
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
@@ -323,6 +335,7 @@ const defaultValues: AppState = {
|
|
|
|
|
enableFreeu: false,
|
|
|
|
|
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
|
|
|
|
|
powerpaintTask: PowerPaintTask.text_guided,
|
|
|
|
|
adjustMaskKernelSize: 12,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -335,10 +348,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
if (get().settings.showExtender) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
const { lastLineGroup, curLineGroup } = get().editorState
|
|
|
|
|
const { prevInteractiveSegMask, interactiveSegMask } =
|
|
|
|
|
get().interactiveSegState
|
|
|
|
|
if (curLineGroup.length !== 0 || interactiveSegMask !== null) {
|
|
|
|
|
const { lastLineGroup, curLineGroup, prevExtraMasks, extraMasks } =
|
|
|
|
|
get().editorState
|
|
|
|
|
if (curLineGroup.length !== 0 || extraMasks.length !== 0) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
const { imageWidth, imageHeight } = get()
|
|
|
|
|
@@ -347,13 +359,13 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
imageWidth,
|
|
|
|
|
imageHeight,
|
|
|
|
|
[lastLineGroup],
|
|
|
|
|
prevInteractiveSegMask ? [prevInteractiveSegMask] : [],
|
|
|
|
|
prevExtraMasks,
|
|
|
|
|
BRUSH_COLOR
|
|
|
|
|
)
|
|
|
|
|
try {
|
|
|
|
|
const maskImage = await canvasToImage(maskCanvas)
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.editorState.extraMasks.push(castDraft(maskImage))
|
|
|
|
|
state.editorState.temporaryMasks.push(castDraft(maskImage))
|
|
|
|
|
})
|
|
|
|
|
} catch (e) {
|
|
|
|
|
console.error(e)
|
|
|
|
|
@@ -362,7 +374,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
},
|
|
|
|
|
hidePrevMask: () => {
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.editorState.extraMasks = []
|
|
|
|
|
state.editorState.temporaryMasks = []
|
|
|
|
|
})
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
@@ -408,33 +420,36 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const { lastLineGroup, curLineGroup, lineGroups, renders } =
|
|
|
|
|
get().editorState
|
|
|
|
|
|
|
|
|
|
const { interactiveSegMask, prevInteractiveSegMask } =
|
|
|
|
|
get().interactiveSegState
|
|
|
|
|
const {
|
|
|
|
|
lastLineGroup,
|
|
|
|
|
curLineGroup,
|
|
|
|
|
lineGroups,
|
|
|
|
|
renders,
|
|
|
|
|
prevExtraMasks,
|
|
|
|
|
extraMasks,
|
|
|
|
|
} = get().editorState
|
|
|
|
|
|
|
|
|
|
const useLastLineGroup =
|
|
|
|
|
curLineGroup.length === 0 &&
|
|
|
|
|
interactiveSegMask === null &&
|
|
|
|
|
extraMasks.length === 0 &&
|
|
|
|
|
!settings.showExtender
|
|
|
|
|
|
|
|
|
|
// useLastLineGroup 的影响
|
|
|
|
|
// 1. 使用上一次的 mask
|
|
|
|
|
// 2. 结果替换当前 render
|
|
|
|
|
let maskImage = null
|
|
|
|
|
let maskImages: HTMLImageElement[] = []
|
|
|
|
|
let maskLineGroup: LineGroup = []
|
|
|
|
|
if (useLastLineGroup === true) {
|
|
|
|
|
maskLineGroup = lastLineGroup
|
|
|
|
|
maskImage = prevInteractiveSegMask
|
|
|
|
|
maskImages = prevExtraMasks
|
|
|
|
|
} else {
|
|
|
|
|
maskLineGroup = curLineGroup
|
|
|
|
|
maskImage = interactiveSegMask
|
|
|
|
|
maskImages = extraMasks
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
|
|
maskLineGroup.length === 0 &&
|
|
|
|
|
maskImage === null &&
|
|
|
|
|
maskImages === null &&
|
|
|
|
|
!settings.showExtender
|
|
|
|
|
) {
|
|
|
|
|
toast({
|
|
|
|
|
@@ -474,7 +489,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
imageWidth,
|
|
|
|
|
imageHeight,
|
|
|
|
|
[maskLineGroup],
|
|
|
|
|
maskImage ? [maskImage] : []
|
|
|
|
|
maskImages
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
@@ -500,6 +515,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
lineGroups: newLineGroups,
|
|
|
|
|
lastLineGroup: maskLineGroup,
|
|
|
|
|
curLineGroup: [],
|
|
|
|
|
extraMasks: [],
|
|
|
|
|
prevExtraMasks: maskImages,
|
|
|
|
|
})
|
|
|
|
|
} catch (e: any) {
|
|
|
|
|
toast({
|
|
|
|
|
@@ -512,15 +529,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.isInpainting = false
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
const newInteractiveSegState = {
|
|
|
|
|
...defaultValues.interactiveSegState,
|
|
|
|
|
prevInteractiveSegMask: maskImage,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.interactiveSegState = castDraft(newInteractiveSegState)
|
|
|
|
|
})
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
runRenderablePlugin: async (
|
|
|
|
|
@@ -557,8 +565,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
} else {
|
|
|
|
|
const newMask = new Image()
|
|
|
|
|
await loadImage(newMask, blob)
|
|
|
|
|
get().updateInteractiveSegState({
|
|
|
|
|
interactiveSegMask: newMask,
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.editorState.extraMasks.push(castDraft(newMask))
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
const end = new Date()
|
|
|
|
|
@@ -618,7 +626,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
getIsProcessing: (): boolean => {
|
|
|
|
|
return get().isInpainting || get().isPluginRunning
|
|
|
|
|
return (
|
|
|
|
|
get().isInpainting || get().isPluginRunning || get().isAdjustingMask
|
|
|
|
|
)
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
isSD: (): boolean => {
|
|
|
|
|
@@ -809,14 +819,14 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
|
|
|
|
|
handleInteractiveSegAccept: () => {
|
|
|
|
|
set((state) => {
|
|
|
|
|
return {
|
|
|
|
|
...state,
|
|
|
|
|
interactiveSegState: {
|
|
|
|
|
...defaultValues.interactiveSegState,
|
|
|
|
|
interactiveSegMask:
|
|
|
|
|
state.interactiveSegState.tmpInteractiveSegMask,
|
|
|
|
|
},
|
|
|
|
|
if (state.interactiveSegState.tmpInteractiveSegMask) {
|
|
|
|
|
state.editorState.extraMasks.push(
|
|
|
|
|
castDraft(state.interactiveSegState.tmpInteractiveSegMask)
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
state.interactiveSegState = castDraft({
|
|
|
|
|
...defaultValues.interactiveSegState,
|
|
|
|
|
})
|
|
|
|
|
})
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
@@ -986,10 +996,54 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.settings.seed = newValue
|
|
|
|
|
}),
|
|
|
|
|
|
|
|
|
|
adjustMask: async (operate: AdjustMaskOperate) => {
|
|
|
|
|
const { imageWidth, imageHeight } = get()
|
|
|
|
|
const { curLineGroup, extraMasks } = get().editorState
|
|
|
|
|
const { adjustMaskKernelSize } = get().settings
|
|
|
|
|
if (curLineGroup.length === 0 && extraMasks.length === 0) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.isAdjustingMask = true
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
const maskCanvas = generateMask(
|
|
|
|
|
imageWidth,
|
|
|
|
|
imageHeight,
|
|
|
|
|
[curLineGroup],
|
|
|
|
|
extraMasks,
|
|
|
|
|
BRUSH_COLOR
|
|
|
|
|
)
|
|
|
|
|
const maskBlob = dataURItoBlob(maskCanvas.toDataURL())
|
|
|
|
|
const newMaskBlob = await postAdjustMask(
|
|
|
|
|
maskBlob,
|
|
|
|
|
operate,
|
|
|
|
|
adjustMaskKernelSize
|
|
|
|
|
)
|
|
|
|
|
const newMask = await blobToImage(newMaskBlob)
|
|
|
|
|
|
|
|
|
|
// TODO: currently ignore stroke undo/redo
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.editorState.extraMasks = [castDraft(newMask)]
|
|
|
|
|
state.editorState.curLineGroup = []
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.isAdjustingMask = false
|
|
|
|
|
})
|
|
|
|
|
},
|
|
|
|
|
clearMask: () => {
|
|
|
|
|
set((state) => {
|
|
|
|
|
state.editorState.extraMasks = []
|
|
|
|
|
state.editorState.curLineGroup = []
|
|
|
|
|
})
|
|
|
|
|
},
|
|
|
|
|
})),
|
|
|
|
|
{
|
|
|
|
|
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
|
|
|
|
version: 0,
|
|
|
|
|
version: 1,
|
|
|
|
|
partialize: (state) =>
|
|
|
|
|
Object.fromEntries(
|
|
|
|
|
Object.entries(state).filter(([key]) =>
|
|
|
|
|
|