wip
This commit is contained in:
@@ -15,18 +15,13 @@ export default async function inpaint(
|
||||
imageFile: File,
|
||||
settings: Settings,
|
||||
croperRect: Rect,
|
||||
maskBase64?: string,
|
||||
customMask?: File,
|
||||
paintByExampleImage?: File
|
||||
mask: File | Blob,
|
||||
paintByExampleImage: File | null = null
|
||||
) {
|
||||
// 1080, 2000, Original
|
||||
const fd = new FormData()
|
||||
fd.append("image", imageFile)
|
||||
if (maskBase64 !== undefined) {
|
||||
fd.append("mask", dataURItoBlob(maskBase64))
|
||||
} else if (customMask !== undefined) {
|
||||
fd.append("mask", customMask)
|
||||
}
|
||||
fd.append("mask", mask)
|
||||
|
||||
fd.append("ldmSteps", settings.ldmSteps.toString())
|
||||
fd.append("ldmSampler", settings.ldmSampler.toString())
|
||||
@@ -42,8 +37,7 @@ export default async function inpaint(
|
||||
fd.append("croperY", croperRect.y.toString())
|
||||
fd.append("croperHeight", croperRect.height.toString())
|
||||
fd.append("croperWidth", croperRect.width.toString())
|
||||
// fd.append("useCroper", settings.showCroper ? "true" : "false")
|
||||
fd.append("useCroper", "false")
|
||||
fd.append("useCroper", settings.showCroper ? "true" : "false")
|
||||
|
||||
fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
|
||||
fd.append("sdStrength", settings.sdStrength.toString())
|
||||
@@ -147,7 +141,6 @@ export async function runPlugin(
|
||||
name: string,
|
||||
imageFile: File,
|
||||
upscale?: number,
|
||||
maskFile?: File | null,
|
||||
clicks?: number[][]
|
||||
) {
|
||||
const fd = new FormData()
|
||||
@@ -159,9 +152,6 @@ export async function runPlugin(
|
||||
if (clicks) {
|
||||
fd.append("clicks", JSON.stringify(clicks))
|
||||
}
|
||||
if (maskFile) {
|
||||
fd.append("mask", maskFile)
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { create } from "zustand"
|
||||
import { persist } from "zustand/middleware"
|
||||
import { shallow } from "zustand/shallow"
|
||||
import { immer } from "zustand/middleware/immer"
|
||||
import { castDraft } from "immer"
|
||||
import { nanoid } from "nanoid"
|
||||
import { createWithEqualityFn } from "zustand/traditional"
|
||||
import {
|
||||
CV2Flag,
|
||||
@@ -10,6 +11,7 @@ import {
|
||||
Line,
|
||||
LineGroup,
|
||||
ModelInfo,
|
||||
PluginParams,
|
||||
Point,
|
||||
SDSampler,
|
||||
Size,
|
||||
@@ -17,6 +19,9 @@ import {
|
||||
SortOrder,
|
||||
} from "./types"
|
||||
import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const"
|
||||
import { dataURItoBlob, generateMask, loadImage, srcToFile } from "./utils"
|
||||
import inpaint, { runPlugin } from "./api"
|
||||
import { toast, useToast } from "@/components/ui/use-toast"
|
||||
|
||||
type FileManagerState = {
|
||||
sortBy: SortBy
|
||||
@@ -95,7 +100,9 @@ type ServerConfig = {
|
||||
|
||||
type InteractiveSegState = {
|
||||
isInteractiveSeg: boolean
|
||||
isInteractiveSegRunning: boolean
|
||||
interactiveSegMask: HTMLImageElement | null
|
||||
tmpInteractiveSegMask: HTMLImageElement | null
|
||||
prevInteractiveSegMask: HTMLImageElement | null
|
||||
clicks: number[][]
|
||||
}
|
||||
|
||||
@@ -103,9 +110,11 @@ type EditorState = {
|
||||
baseBrushSize: number
|
||||
brushSizeScale: number
|
||||
renders: HTMLImageElement[]
|
||||
paintByExampleImage: File | null
|
||||
lineGroups: LineGroup[]
|
||||
lastLineGroup: LineGroup
|
||||
curLineGroup: LineGroup
|
||||
extraMasks: HTMLImageElement[]
|
||||
// redo 相关
|
||||
redoRenders: HTMLImageElement[]
|
||||
redoCurLines: Line[]
|
||||
@@ -113,6 +122,8 @@ type EditorState = {
|
||||
}
|
||||
|
||||
type AppState = {
|
||||
idForUpdateView: string
|
||||
|
||||
file: File | null
|
||||
customMask: File | null
|
||||
imageHeight: number
|
||||
@@ -136,6 +147,7 @@ type AppAction = {
|
||||
setCustomFile: (file: File) => void
|
||||
setIsInpainting: (newValue: boolean) => void
|
||||
setIsPluginRunning: (newValue: boolean) => void
|
||||
getIsProcessing: () => boolean
|
||||
setBaseBrushSize: (newValue: number) => void
|
||||
getBrushSize: () => number
|
||||
setImageSize: (width: number, height: number) => void
|
||||
@@ -151,10 +163,18 @@ type AppAction = {
|
||||
updateFileManagerState: (newState: Partial<FileManagerState>) => void
|
||||
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
|
||||
resetInteractiveSegState: () => void
|
||||
handleInteractiveSegAccept: () => void
|
||||
showPromptInput: () => boolean
|
||||
showSidePanel: () => boolean
|
||||
|
||||
runInpainting: () => Promise<void>
|
||||
runRenderablePlugin: (
|
||||
pluginName: string,
|
||||
params?: PluginParams
|
||||
) => Promise<void>
|
||||
|
||||
// EditorState
|
||||
getCurrentTargetFile: () => Promise<File>
|
||||
updateEditorState: (newState: Partial<EditorState>) => void
|
||||
runMannually: () => boolean
|
||||
handleCanvasMouseDown: (point: Point) => void
|
||||
@@ -168,12 +188,15 @@ type AppAction = {
|
||||
}
|
||||
|
||||
const defaultValues: AppState = {
|
||||
idForUpdateView: nanoid(),
|
||||
|
||||
file: null,
|
||||
customMask: null,
|
||||
imageHeight: 0,
|
||||
imageWidth: 0,
|
||||
isInpainting: false,
|
||||
isPluginRunning: false,
|
||||
|
||||
windowSize: {
|
||||
height: 600,
|
||||
width: 800,
|
||||
@@ -182,6 +205,8 @@ const defaultValues: AppState = {
|
||||
baseBrushSize: DEFAULT_BRUSH_SIZE,
|
||||
brushSizeScale: 1,
|
||||
renders: [],
|
||||
paintByExampleImage: null,
|
||||
extraMasks: [],
|
||||
lineGroups: [],
|
||||
lastLineGroup: [],
|
||||
curLineGroup: [],
|
||||
@@ -192,7 +217,9 @@ const defaultValues: AppState = {
|
||||
|
||||
interactiveSegState: {
|
||||
isInteractiveSeg: false,
|
||||
isInteractiveSegRunning: false,
|
||||
interactiveSegMask: null,
|
||||
tmpInteractiveSegMask: null,
|
||||
prevInteractiveSegMask: null,
|
||||
clicks: [],
|
||||
},
|
||||
|
||||
@@ -267,16 +294,208 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
immer((set, get) => ({
|
||||
...defaultValues,
|
||||
|
||||
getCurrentTargetFile: async (): Promise<File> => {
|
||||
const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数
|
||||
const renders = get().editorState.renders
|
||||
|
||||
let targetFile = file
|
||||
if (renders.length > 0) {
|
||||
const lastRender = renders[renders.length - 1]
|
||||
targetFile = await srcToFile(
|
||||
lastRender.currentSrc,
|
||||
file.name,
|
||||
file.type
|
||||
)
|
||||
}
|
||||
return targetFile
|
||||
},
|
||||
|
||||
runInpainting: async () => {
|
||||
const { file, imageWidth, imageHeight, settings, cropperState } = get()
|
||||
|
||||
if (file === null) {
|
||||
return
|
||||
}
|
||||
const {
|
||||
lastLineGroup,
|
||||
curLineGroup,
|
||||
lineGroups,
|
||||
renders,
|
||||
paintByExampleImage,
|
||||
} = get().editorState
|
||||
|
||||
const { interactiveSegMask, prevInteractiveSegMask } =
|
||||
get().interactiveSegState
|
||||
|
||||
const useLastLineGroup =
|
||||
curLineGroup.length === 0 && interactiveSegMask === null
|
||||
|
||||
const maskImage = useLastLineGroup
|
||||
? prevInteractiveSegMask
|
||||
: interactiveSegMask
|
||||
|
||||
// useLastLineGroup 的影响
|
||||
// 1. 使用上一次的 mask
|
||||
// 2. 结果替换当前 render
|
||||
let maskLineGroup: LineGroup = []
|
||||
if (useLastLineGroup === true) {
|
||||
if (lastLineGroup.length === 0 && maskImage === null) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: "Please draw mask on picture",
|
||||
})
|
||||
return
|
||||
}
|
||||
maskLineGroup = lastLineGroup
|
||||
} else {
|
||||
if (curLineGroup.length === 0 && maskImage === null) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: "Please draw mask on picture",
|
||||
})
|
||||
return
|
||||
}
|
||||
maskLineGroup = curLineGroup
|
||||
}
|
||||
|
||||
const newLineGroups = [...lineGroups, maskLineGroup]
|
||||
|
||||
set((state) => {
|
||||
state.isInpainting = true
|
||||
})
|
||||
|
||||
let targetFile = file
|
||||
if (useLastLineGroup === true) {
|
||||
// renders.length == 1 还是用原来的
|
||||
if (renders.length > 1) {
|
||||
const lastRender = renders[renders.length - 2]
|
||||
targetFile = await srcToFile(
|
||||
lastRender.currentSrc,
|
||||
file.name,
|
||||
file.type
|
||||
)
|
||||
}
|
||||
} else if (renders.length > 0) {
|
||||
const lastRender = renders[renders.length - 1]
|
||||
targetFile = await srcToFile(
|
||||
lastRender.currentSrc,
|
||||
file.name,
|
||||
file.type
|
||||
)
|
||||
}
|
||||
|
||||
const maskCanvas = generateMask(
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
[maskLineGroup],
|
||||
maskImage ? [maskImage] : []
|
||||
)
|
||||
|
||||
try {
|
||||
const res = await inpaint(
|
||||
targetFile,
|
||||
settings,
|
||||
cropperState,
|
||||
dataURItoBlob(maskCanvas.toDataURL()),
|
||||
paintByExampleImage
|
||||
)
|
||||
|
||||
if (!res) {
|
||||
throw new Error("Something went wrong on server side.")
|
||||
}
|
||||
|
||||
const { blob, seed } = res
|
||||
if (seed) {
|
||||
set((state) => (state.settings.seed = parseInt(seed, 10)))
|
||||
}
|
||||
const newRender = new Image()
|
||||
await loadImage(newRender, blob)
|
||||
if (useLastLineGroup === true) {
|
||||
const prevRenders = renders.slice(0, -1)
|
||||
const newRenders = [...prevRenders, newRender]
|
||||
get().updateEditorState({
|
||||
renders: newRenders,
|
||||
lineGroups: newLineGroups,
|
||||
lastLineGroup: curLineGroup,
|
||||
curLineGroup: [],
|
||||
})
|
||||
} else {
|
||||
const newRenders = [...renders, newRender]
|
||||
get().updateEditorState({
|
||||
renders: newRenders,
|
||||
lineGroups: newLineGroups,
|
||||
lastLineGroup: curLineGroup,
|
||||
curLineGroup: [],
|
||||
})
|
||||
}
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
}
|
||||
|
||||
get().resetRedoState()
|
||||
set((state) => {
|
||||
state.isInpainting = false
|
||||
})
|
||||
|
||||
const newInteractiveSegState = {
|
||||
...defaultValues.interactiveSegState,
|
||||
prevInteractiveSegMask: useLastLineGroup ? null : maskImage,
|
||||
}
|
||||
|
||||
set((state) => {
|
||||
state.interactiveSegState = castDraft(newInteractiveSegState)
|
||||
})
|
||||
},
|
||||
|
||||
runRenderablePlugin: async (
|
||||
pluginName: string,
|
||||
params: PluginParams = { upscale: 1 }
|
||||
) => {
|
||||
const { renders, lineGroups } = get().editorState
|
||||
set((state) => {
|
||||
state.isInpainting = true
|
||||
})
|
||||
|
||||
try {
|
||||
const start = new Date()
|
||||
const targetFile = await get().getCurrentTargetFile()
|
||||
const res = await runPlugin(pluginName, targetFile, params.upscale)
|
||||
if (!res) {
|
||||
throw new Error("Something went wrong on server side.")
|
||||
}
|
||||
const { blob } = res
|
||||
const newRender = new Image()
|
||||
await loadImage(newRender, blob)
|
||||
get().setImageSize(newRender.height, newRender.width)
|
||||
const newRenders = [...renders, newRender]
|
||||
const newLineGroups = [...lineGroups, []]
|
||||
get().updateEditorState({
|
||||
renders: newRenders,
|
||||
lineGroups: newLineGroups,
|
||||
})
|
||||
const end = new Date()
|
||||
const time = end.getTime() - start.getTime()
|
||||
toast({
|
||||
description: `Run ${pluginName} successfully in ${time / 1000}s`,
|
||||
})
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
}
|
||||
set((state) => {
|
||||
state.isInpainting = false
|
||||
})
|
||||
},
|
||||
|
||||
// Edirot State //
|
||||
updateEditorState: (newState: Partial<EditorState>) => {
|
||||
set((state) => {
|
||||
return {
|
||||
...state,
|
||||
editorState: {
|
||||
...state.editorState,
|
||||
...newState,
|
||||
},
|
||||
}
|
||||
state.editorState = castDraft({ ...state.editorState, ...newState })
|
||||
})
|
||||
},
|
||||
|
||||
@@ -313,6 +532,10 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
)
|
||||
},
|
||||
|
||||
getIsProcessing: (): boolean => {
|
||||
return get().isInpainting || get().isPluginRunning
|
||||
},
|
||||
|
||||
// undo/redo
|
||||
|
||||
undoDisabled: (): boolean => {
|
||||
@@ -468,15 +691,30 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
|
||||
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => {
|
||||
set((state) => {
|
||||
state.interactiveSegState = {
|
||||
...state.interactiveSegState,
|
||||
...newState,
|
||||
return {
|
||||
...state,
|
||||
interactiveSegState: {
|
||||
...state.interactiveSegState,
|
||||
...newState,
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
|
||||
resetInteractiveSegState: () => {
|
||||
get().updateInteractiveSegState(defaultValues.interactiveSegState)
|
||||
},
|
||||
|
||||
handleInteractiveSegAccept: () => {
|
||||
set((state) => {
|
||||
state.interactiveSegState = defaultValues.interactiveSegState
|
||||
return {
|
||||
...state,
|
||||
interactiveSegState: {
|
||||
...defaultValues.interactiveSegState,
|
||||
interactiveSegMask:
|
||||
state.interactiveSegState.tmpInteractiveSegMask,
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
|
||||
@@ -492,8 +730,12 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
|
||||
setFile: (file: File) =>
|
||||
set((state) => {
|
||||
// TODO: 清空各种状态
|
||||
state.file = file
|
||||
state.interactiveSegState = castDraft(
|
||||
defaultValues.interactiveSegState
|
||||
)
|
||||
state.editorState = castDraft(defaultValues.editorState)
|
||||
state.cropperState = defaultValues.cropperState
|
||||
}),
|
||||
|
||||
setCustomFile: (file: File) =>
|
||||
|
||||
@@ -25,6 +25,10 @@ export enum PluginName {
|
||||
InteractiveSeg = "InteractiveSeg",
|
||||
}
|
||||
|
||||
export interface PluginParams {
|
||||
upscale: number
|
||||
}
|
||||
|
||||
export enum SortBy {
|
||||
NAME = "name",
|
||||
CTIME = "ctime",
|
||||
|
||||
@@ -159,3 +159,28 @@ export function drawLines(
|
||||
ctx.stroke()
|
||||
})
|
||||
}
|
||||
|
||||
export const generateMask = (
|
||||
imageWidth: number,
|
||||
imageHeight: number,
|
||||
lineGroups: LineGroup[],
|
||||
maskImages: HTMLImageElement[] = []
|
||||
): HTMLCanvasElement => {
|
||||
const maskCanvas = document.createElement("canvas")
|
||||
maskCanvas.width = imageWidth
|
||||
maskCanvas.height = imageHeight
|
||||
const ctx = maskCanvas.getContext("2d")
|
||||
if (!ctx) {
|
||||
throw new Error("could not retrieve mask canvas")
|
||||
}
|
||||
|
||||
maskImages.forEach((maskImage) => {
|
||||
ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
||||
})
|
||||
|
||||
lineGroups.forEach((lineGroup) => {
|
||||
drawLines(ctx, lineGroup, "white")
|
||||
})
|
||||
|
||||
return maskCanvas
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user