This commit is contained in:
Qing
2023-12-11 22:28:07 +08:00
parent fecf4beef0
commit 354a1280a4
13 changed files with 531 additions and 747 deletions

View File

@@ -15,18 +15,13 @@ export default async function inpaint(
imageFile: File,
settings: Settings,
croperRect: Rect,
maskBase64?: string,
customMask?: File,
paintByExampleImage?: File
mask: File | Blob,
paintByExampleImage: File | null = null
) {
// 1080, 2000, Original
const fd = new FormData()
fd.append("image", imageFile)
if (maskBase64 !== undefined) {
fd.append("mask", dataURItoBlob(maskBase64))
} else if (customMask !== undefined) {
fd.append("mask", customMask)
}
fd.append("mask", mask)
fd.append("ldmSteps", settings.ldmSteps.toString())
fd.append("ldmSampler", settings.ldmSampler.toString())
@@ -42,8 +37,7 @@ export default async function inpaint(
fd.append("croperY", croperRect.y.toString())
fd.append("croperHeight", croperRect.height.toString())
fd.append("croperWidth", croperRect.width.toString())
// fd.append("useCroper", settings.showCroper ? "true" : "false")
fd.append("useCroper", "false")
fd.append("useCroper", settings.showCroper ? "true" : "false")
fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
fd.append("sdStrength", settings.sdStrength.toString())
@@ -147,7 +141,6 @@ export async function runPlugin(
name: string,
imageFile: File,
upscale?: number,
maskFile?: File | null,
clicks?: number[][]
) {
const fd = new FormData()
@@ -159,9 +152,6 @@ export async function runPlugin(
if (clicks) {
fd.append("clicks", JSON.stringify(clicks))
}
if (maskFile) {
fd.append("mask", maskFile)
}
try {
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {

View File

@@ -1,7 +1,8 @@
import { create } from "zustand"
import { persist } from "zustand/middleware"
import { shallow } from "zustand/shallow"
import { immer } from "zustand/middleware/immer"
import { castDraft } from "immer"
import { nanoid } from "nanoid"
import { createWithEqualityFn } from "zustand/traditional"
import {
CV2Flag,
@@ -10,6 +11,7 @@ import {
Line,
LineGroup,
ModelInfo,
PluginParams,
Point,
SDSampler,
Size,
@@ -17,6 +19,9 @@ import {
SortOrder,
} from "./types"
import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const"
import { dataURItoBlob, generateMask, loadImage, srcToFile } from "./utils"
import inpaint, { runPlugin } from "./api"
import { toast, useToast } from "@/components/ui/use-toast"
type FileManagerState = {
sortBy: SortBy
@@ -95,7 +100,9 @@ type ServerConfig = {
type InteractiveSegState = {
isInteractiveSeg: boolean
isInteractiveSegRunning: boolean
interactiveSegMask: HTMLImageElement | null
tmpInteractiveSegMask: HTMLImageElement | null
prevInteractiveSegMask: HTMLImageElement | null
clicks: number[][]
}
@@ -103,9 +110,11 @@ type EditorState = {
baseBrushSize: number
brushSizeScale: number
renders: HTMLImageElement[]
paintByExampleImage: File | null
lineGroups: LineGroup[]
lastLineGroup: LineGroup
curLineGroup: LineGroup
extraMasks: HTMLImageElement[]
// redo 相关
redoRenders: HTMLImageElement[]
redoCurLines: Line[]
@@ -113,6 +122,8 @@ type EditorState = {
}
type AppState = {
idForUpdateView: string
file: File | null
customMask: File | null
imageHeight: number
@@ -136,6 +147,7 @@ type AppAction = {
setCustomFile: (file: File) => void
setIsInpainting: (newValue: boolean) => void
setIsPluginRunning: (newValue: boolean) => void
getIsProcessing: () => boolean
setBaseBrushSize: (newValue: number) => void
getBrushSize: () => number
setImageSize: (width: number, height: number) => void
@@ -151,10 +163,18 @@ type AppAction = {
updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
resetInteractiveSegState: () => void
handleInteractiveSegAccept: () => void
showPromptInput: () => boolean
showSidePanel: () => boolean
runInpainting: () => Promise<void>
runRenderablePlugin: (
pluginName: string,
params?: PluginParams
) => Promise<void>
// EditorState
getCurrentTargetFile: () => Promise<File>
updateEditorState: (newState: Partial<EditorState>) => void
runMannually: () => boolean
handleCanvasMouseDown: (point: Point) => void
@@ -168,12 +188,15 @@ type AppAction = {
}
const defaultValues: AppState = {
idForUpdateView: nanoid(),
file: null,
customMask: null,
imageHeight: 0,
imageWidth: 0,
isInpainting: false,
isPluginRunning: false,
windowSize: {
height: 600,
width: 800,
@@ -182,6 +205,8 @@ const defaultValues: AppState = {
baseBrushSize: DEFAULT_BRUSH_SIZE,
brushSizeScale: 1,
renders: [],
paintByExampleImage: null,
extraMasks: [],
lineGroups: [],
lastLineGroup: [],
curLineGroup: [],
@@ -192,7 +217,9 @@ const defaultValues: AppState = {
interactiveSegState: {
isInteractiveSeg: false,
isInteractiveSegRunning: false,
interactiveSegMask: null,
tmpInteractiveSegMask: null,
prevInteractiveSegMask: null,
clicks: [],
},
@@ -267,16 +294,208 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
immer((set, get) => ({
...defaultValues,
getCurrentTargetFile: async (): Promise<File> => {
const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数
const renders = get().editorState.renders
let targetFile = file
if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
return targetFile
},
runInpainting: async () => {
const { file, imageWidth, imageHeight, settings, cropperState } = get()
if (file === null) {
return
}
const {
lastLineGroup,
curLineGroup,
lineGroups,
renders,
paintByExampleImage,
} = get().editorState
const { interactiveSegMask, prevInteractiveSegMask } =
get().interactiveSegState
const useLastLineGroup =
curLineGroup.length === 0 && interactiveSegMask === null
const maskImage = useLastLineGroup
? prevInteractiveSegMask
: interactiveSegMask
// useLastLineGroup 的影响
// 1. 使用上一次的 mask
// 2. 结果替换当前 render
let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) {
if (lastLineGroup.length === 0 && maskImage === null) {
toast({
variant: "destructive",
description: "Please draw mask on picture",
})
return
}
maskLineGroup = lastLineGroup
} else {
if (curLineGroup.length === 0 && maskImage === null) {
toast({
variant: "destructive",
description: "Please draw mask on picture",
})
return
}
maskLineGroup = curLineGroup
}
const newLineGroups = [...lineGroups, maskLineGroup]
set((state) => {
state.isInpainting = true
})
let targetFile = file
if (useLastLineGroup === true) {
// renders.length == 1 还是用原来的
if (renders.length > 1) {
const lastRender = renders[renders.length - 2]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
} else if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
const maskCanvas = generateMask(
imageWidth,
imageHeight,
[maskLineGroup],
maskImage ? [maskImage] : []
)
try {
const res = await inpaint(
targetFile,
settings,
cropperState,
dataURItoBlob(maskCanvas.toDataURL()),
paintByExampleImage
)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob, seed } = res
if (seed) {
set((state) => (state.settings.seed = parseInt(seed, 10)))
}
const newRender = new Image()
await loadImage(newRender, blob)
if (useLastLineGroup === true) {
const prevRenders = renders.slice(0, -1)
const newRenders = [...prevRenders, newRender]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
lastLineGroup: curLineGroup,
curLineGroup: [],
})
} else {
const newRenders = [...renders, newRender]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
lastLineGroup: curLineGroup,
curLineGroup: [],
})
}
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
}
get().resetRedoState()
set((state) => {
state.isInpainting = false
})
const newInteractiveSegState = {
...defaultValues.interactiveSegState,
prevInteractiveSegMask: useLastLineGroup ? null : maskImage,
}
set((state) => {
state.interactiveSegState = castDraft(newInteractiveSegState)
})
},
runRenderablePlugin: async (
pluginName: string,
params: PluginParams = { upscale: 1 }
) => {
const { renders, lineGroups } = get().editorState
set((state) => {
state.isInpainting = true
})
try {
const start = new Date()
const targetFile = await get().getCurrentTargetFile()
const res = await runPlugin(pluginName, targetFile, params.upscale)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob } = res
const newRender = new Image()
await loadImage(newRender, blob)
get().setImageSize(newRender.height, newRender.width)
const newRenders = [...renders, newRender]
const newLineGroups = [...lineGroups, []]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
})
const end = new Date()
const time = end.getTime() - start.getTime()
toast({
description: `Run ${pluginName} successfully in ${time / 1000}s`,
})
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
}
set((state) => {
state.isInpainting = false
})
},
// Edirot State //
updateEditorState: (newState: Partial<EditorState>) => {
set((state) => {
return {
...state,
editorState: {
...state.editorState,
...newState,
},
}
state.editorState = castDraft({ ...state.editorState, ...newState })
})
},
@@ -313,6 +532,10 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
)
},
getIsProcessing: (): boolean => {
return get().isInpainting || get().isPluginRunning
},
// undo/redo
undoDisabled: (): boolean => {
@@ -468,15 +691,30 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => {
set((state) => {
state.interactiveSegState = {
...state.interactiveSegState,
...newState,
return {
...state,
interactiveSegState: {
...state.interactiveSegState,
...newState,
},
}
})
},
resetInteractiveSegState: () => {
get().updateInteractiveSegState(defaultValues.interactiveSegState)
},
handleInteractiveSegAccept: () => {
set((state) => {
state.interactiveSegState = defaultValues.interactiveSegState
return {
...state,
interactiveSegState: {
...defaultValues.interactiveSegState,
interactiveSegMask:
state.interactiveSegState.tmpInteractiveSegMask,
},
}
})
},
@@ -492,8 +730,12 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
setFile: (file: File) =>
set((state) => {
// TODO: 清空各种状态
state.file = file
state.interactiveSegState = castDraft(
defaultValues.interactiveSegState
)
state.editorState = castDraft(defaultValues.editorState)
state.cropperState = defaultValues.cropperState
}),
setCustomFile: (file: File) =>

View File

@@ -25,6 +25,10 @@ export enum PluginName {
InteractiveSeg = "InteractiveSeg",
}
export interface PluginParams {
upscale: number
}
export enum SortBy {
NAME = "name",
CTIME = "ctime",

View File

@@ -159,3 +159,28 @@ export function drawLines(
ctx.stroke()
})
}
export const generateMask = (
imageWidth: number,
imageHeight: number,
lineGroups: LineGroup[],
maskImages: HTMLImageElement[] = []
): HTMLCanvasElement => {
const maskCanvas = document.createElement("canvas")
maskCanvas.width = imageWidth
maskCanvas.height = imageHeight
const ctx = maskCanvas.getContext("2d")
if (!ctx) {
throw new Error("could not retrieve mask canvas")
}
maskImages.forEach((maskImage) => {
ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
})
lineGroups.forEach((lineGroup) => {
drawLines(ctx, lineGroup, "white")
})
return maskCanvas
}