add adjust mask feature

This commit is contained in:
Qing
2024-01-05 14:57:30 +08:00
parent 2996544e75
commit e889e527ab
18 changed files with 507 additions and 76 deletions

View File

@@ -201,3 +201,28 @@ export async function getSamplers(): Promise<string[]> {
const res = await api.post("/samplers")
return res.data
}
export async function postAdjustMask(
mask: File | Blob,
operate: "expand" | "shrink",
kernel_size: number
) {
const maskBase64 = await convertToBase64(mask)
const res = await fetch(`${API_ENDPOINT}/adjust_mask`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
mask: maskBase64,
operate: operate,
kernel_size: kernel_size,
}),
})
if (res.ok) {
const blob = await res.blob()
return blob
}
const errMsg = await res.json()
throw new Error(errMsg)
}

View File

@@ -4,6 +4,7 @@ import { immer } from "zustand/middleware/immer"
import { castDraft } from "immer"
import { createWithEqualityFn } from "zustand/traditional"
import {
AdjustMaskOperate,
CV2Flag,
ExtenderDirection,
FreeuConfig,
@@ -27,13 +28,14 @@ import {
PAINT_BY_EXAMPLE,
} from "./const"
import {
blobToImage,
canvasToImage,
dataURItoBlob,
generateMask,
loadImage,
srcToFile,
} from "./utils"
import inpaint, { getGenInfo, runPlugin } from "./api"
import inpaint, { getGenInfo, postAdjustMask, runPlugin } from "./api"
import { toast } from "@/components/ui/use-toast"
type FileManagerState = {
@@ -102,13 +104,14 @@ export type Settings = {
// PowerPaint
powerpaintTask: PowerPaintTask
// AdjustMask
adjustMaskKernelSize: number
}
type InteractiveSegState = {
isInteractiveSeg: boolean
interactiveSegMask: HTMLImageElement | null
tmpInteractiveSegMask: HTMLImageElement | null
prevInteractiveSegMask: HTMLImageElement | null
clicks: number[][]
}
@@ -119,8 +122,12 @@ type EditorState = {
lineGroups: LineGroup[]
lastLineGroup: LineGroup
curLineGroup: LineGroup
// 只用来显示
// mask from interactive-seg or other segmentation models
extraMasks: HTMLImageElement[]
prevExtraMasks: HTMLImageElement[]
temporaryMasks: HTMLImageElement[]
// redo 相关
redoRenders: HTMLImageElement[]
redoCurLines: Line[]
@@ -135,6 +142,7 @@ type AppState = {
imageWidth: number
isInpainting: boolean
isPluginRunning: boolean
isAdjustingMask: boolean
windowSize: Size
editorState: EditorState
disableShortCuts: boolean
@@ -209,6 +217,9 @@ type AppAction = {
redo: () => void
undoDisabled: () => boolean
redoDisabled: () => boolean
adjustMask: (operate: AdjustMaskOperate) => Promise<void>
clearMask: () => void
}
const defaultValues: AppState = {
@@ -219,6 +230,7 @@ const defaultValues: AppState = {
imageWidth: 0,
isInpainting: false,
isPluginRunning: false,
isAdjustingMask: false,
disableShortCuts: false,
windowSize: {
@@ -230,6 +242,8 @@ const defaultValues: AppState = {
brushSizeScale: 1,
renders: [],
extraMasks: [],
prevExtraMasks: [],
temporaryMasks: [],
lineGroups: [],
lastLineGroup: [],
curLineGroup: [],
@@ -240,9 +254,7 @@ const defaultValues: AppState = {
interactiveSegState: {
isInteractiveSeg: false,
interactiveSegMask: null,
tmpInteractiveSegMask: null,
prevInteractiveSegMask: null,
clicks: [],
},
@@ -323,6 +335,7 @@ const defaultValues: AppState = {
enableFreeu: false,
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
powerpaintTask: PowerPaintTask.text_guided,
adjustMaskKernelSize: 12,
},
}
@@ -335,10 +348,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
if (get().settings.showExtender) {
return
}
const { lastLineGroup, curLineGroup } = get().editorState
const { prevInteractiveSegMask, interactiveSegMask } =
get().interactiveSegState
if (curLineGroup.length !== 0 || interactiveSegMask !== null) {
const { lastLineGroup, curLineGroup, prevExtraMasks, extraMasks } =
get().editorState
if (curLineGroup.length !== 0 || extraMasks.length !== 0) {
return
}
const { imageWidth, imageHeight } = get()
@@ -347,13 +359,13 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
imageWidth,
imageHeight,
[lastLineGroup],
prevInteractiveSegMask ? [prevInteractiveSegMask] : [],
prevExtraMasks,
BRUSH_COLOR
)
try {
const maskImage = await canvasToImage(maskCanvas)
set((state) => {
state.editorState.extraMasks.push(castDraft(maskImage))
state.editorState.temporaryMasks.push(castDraft(maskImage))
})
} catch (e) {
console.error(e)
@@ -362,7 +374,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
},
hidePrevMask: () => {
set((state) => {
state.editorState.extraMasks = []
state.editorState.temporaryMasks = []
})
},
@@ -408,33 +420,36 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
return
}
const { lastLineGroup, curLineGroup, lineGroups, renders } =
get().editorState
const { interactiveSegMask, prevInteractiveSegMask } =
get().interactiveSegState
const {
lastLineGroup,
curLineGroup,
lineGroups,
renders,
prevExtraMasks,
extraMasks,
} = get().editorState
const useLastLineGroup =
curLineGroup.length === 0 &&
interactiveSegMask === null &&
extraMasks.length === 0 &&
!settings.showExtender
// useLastLineGroup 的影响
// 1. 使用上一次的 mask
// 2. 结果替换当前 render
let maskImage = null
let maskImages: HTMLImageElement[] = []
let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) {
maskLineGroup = lastLineGroup
maskImage = prevInteractiveSegMask
maskImages = prevExtraMasks
} else {
maskLineGroup = curLineGroup
maskImage = interactiveSegMask
maskImages = extraMasks
}
if (
maskLineGroup.length === 0 &&
maskImage === null &&
maskImages === null &&
!settings.showExtender
) {
toast({
@@ -474,7 +489,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
imageWidth,
imageHeight,
[maskLineGroup],
maskImage ? [maskImage] : []
maskImages
)
try {
@@ -500,6 +515,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
lineGroups: newLineGroups,
lastLineGroup: maskLineGroup,
curLineGroup: [],
extraMasks: [],
prevExtraMasks: maskImages,
})
} catch (e: any) {
toast({
@@ -512,15 +529,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
set((state) => {
state.isInpainting = false
})
const newInteractiveSegState = {
...defaultValues.interactiveSegState,
prevInteractiveSegMask: maskImage,
}
set((state) => {
state.interactiveSegState = castDraft(newInteractiveSegState)
})
},
runRenderablePlugin: async (
@@ -557,8 +565,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
} else {
const newMask = new Image()
await loadImage(newMask, blob)
get().updateInteractiveSegState({
interactiveSegMask: newMask,
set((state) => {
state.editorState.extraMasks.push(castDraft(newMask))
})
}
const end = new Date()
@@ -618,7 +626,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
},
getIsProcessing: (): boolean => {
return get().isInpainting || get().isPluginRunning
return (
get().isInpainting || get().isPluginRunning || get().isAdjustingMask
)
},
isSD: (): boolean => {
@@ -809,14 +819,14 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
handleInteractiveSegAccept: () => {
set((state) => {
return {
...state,
interactiveSegState: {
...defaultValues.interactiveSegState,
interactiveSegMask:
state.interactiveSegState.tmpInteractiveSegMask,
},
if (state.interactiveSegState.tmpInteractiveSegMask) {
state.editorState.extraMasks.push(
castDraft(state.interactiveSegState.tmpInteractiveSegMask)
)
}
state.interactiveSegState = castDraft({
...defaultValues.interactiveSegState,
})
})
},
@@ -986,10 +996,54 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
set((state) => {
state.settings.seed = newValue
}),
adjustMask: async (operate: AdjustMaskOperate) => {
const { imageWidth, imageHeight } = get()
const { curLineGroup, extraMasks } = get().editorState
const { adjustMaskKernelSize } = get().settings
if (curLineGroup.length === 0 && extraMasks.length === 0) {
return
}
set((state) => {
state.isAdjustingMask = true
})
const maskCanvas = generateMask(
imageWidth,
imageHeight,
[curLineGroup],
extraMasks,
BRUSH_COLOR
)
const maskBlob = dataURItoBlob(maskCanvas.toDataURL())
const newMaskBlob = await postAdjustMask(
maskBlob,
operate,
adjustMaskKernelSize
)
const newMask = await blobToImage(newMaskBlob)
// TODO: currently ignore stroke undo/redo
set((state) => {
state.editorState.extraMasks = [castDraft(newMask)]
state.editorState.curLineGroup = []
})
set((state) => {
state.isAdjustingMask = false
})
},
clearMask: () => {
set((state) => {
state.editorState.extraMasks = []
state.editorState.curLineGroup = []
})
},
})),
{
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
version: 0,
version: 1,
partialize: (state) =>
Object.fromEntries(
Object.entries(state).filter(([key]) =>

View File

@@ -125,3 +125,5 @@ export enum PowerPaintTask {
object_remove = "object-remove",
outpainting = "outpainting",
}
export type AdjustMaskOperate = "expand" | "shrink"

View File

@@ -53,6 +53,13 @@ export function loadImage(image: HTMLImageElement, src: string) {
})
}
export async function blobToImage(blob: Blob) {
const dataURL = URL.createObjectURL(blob)
const newImage = new Image()
await loadImage(newImage, dataURL)
return newImage
}
export function canvasToImage(
canvas: HTMLCanvasElement
): Promise<HTMLImageElement> {