This commit is contained in:
Qing
2023-12-17 22:15:48 +08:00
parent 7bd29ab290
commit f27fc51e34
9 changed files with 591 additions and 37 deletions

View File

@@ -18,7 +18,6 @@ export default async function inpaint(
mask: File | Blob,
paintByExampleImage: File | null = null
) {
// 1080, 2000, Original
const fd = new FormData()
fd.append("image", imageFile)
fd.append("mask", mask)
@@ -37,7 +36,7 @@ export default async function inpaint(
fd.append("croperY", croperRect.y.toString())
fd.append("croperHeight", croperRect.height.toString())
fd.append("croperWidth", croperRect.width.toString())
fd.append("useCroper", settings.showCroper ? "true" : "false")
fd.append("useCroper", settings.showCropper ? "true" : "false")
fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
fd.append("sdStrength", settings.sdStrength.toString())
@@ -52,6 +51,9 @@ export default async function inpaint(
fd.append("sdMatchHistograms", settings.sdMatchHistograms ? "true" : "false")
fd.append("sdScale", (settings.sdScale / 100).toString())
fd.append("enableFreeu", settings.enableFreeu.toString())
fd.append("freeuConfig", JSON.stringify(settings.freeuConfig))
fd.append("enableLCMLora", settings.enableLCMLora.toString())
fd.append("cv2Radius", settings.cv2Radius.toString())
fd.append("cv2Flag", settings.cv2Flag.toString())

View File

@@ -18,12 +18,19 @@ import {
SortOrder,
} from "./types"
import {
BRUSH_COLOR,
DEFAULT_BRUSH_SIZE,
DEFAULT_NEGATIVE_PROMPT,
MODEL_TYPE_INPAINT,
PAINT_BY_EXAMPLE,
} from "./const"
import { dataURItoBlob, generateMask, loadImage, srcToFile } from "./utils"
import {
canvasToImage,
dataURItoBlob,
generateMask,
loadImage,
srcToFile,
} from "./utils"
import inpaint, { runPlugin } from "./api"
import { toast } from "@/components/ui/use-toast"
@@ -48,7 +55,8 @@ export type Settings = {
enableDownloadMask: boolean
enableManualInpainting: boolean
enableUploadMask: boolean
showCroper: boolean
showCropper: boolean
showExpender: boolean
// For LDM
ldmSteps: number
@@ -134,6 +142,7 @@ type AppState = {
interactiveSegState: InteractiveSegState
fileManagerState: FileManagerState
cropperState: CropperState
extenderState: CropperState
serverConfig: ServerConfig
settings: Settings
@@ -155,9 +164,15 @@ type AppAction = {
setCropperWidth: (newValue: number) => void
setCropperHeight: (newValue: number) => void
setExtenderX: (newValue: number) => void
setExtenderY: (newValue: number) => void
setExtenderWidth: (newValue: number) => void
setExtenderHeight: (newValue: number) => void
setServerConfig: (newValue: ServerConfig) => void
setSeed: (newValue: number) => void
updateSettings: (newSettings: Partial<Settings>) => void
setModel: (newModel: ModelInfo) => void
updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
resetInteractiveSegState: () => void
@@ -166,6 +181,8 @@ type AppAction = {
showSidePanel: () => boolean
runInpainting: () => Promise<void>
showPrevMask: () => Promise<void>
hidePrevMask: () => void
runRenderablePlugin: (
pluginName: string,
params?: PluginParams
@@ -226,6 +243,13 @@ const defaultValues: AppState = {
width: 512,
height: 512,
},
extenderState: {
x: 0,
y: 0,
width: 512,
height: 512,
},
fileManagerState: {
sortBy: SortBy.CTIME,
sortOrder: SortOrder.DESCENDING,
@@ -248,6 +272,7 @@ const defaultValues: AppState = {
model_type: "inpaint",
support_controlnet: false,
support_strength: false,
support_outpainting: false,
controlnets: [],
support_freeu: false,
support_lcm_lora: false,
@@ -255,7 +280,8 @@ const defaultValues: AppState = {
need_prompt: false,
},
enableControlnet: false,
showCroper: false,
showCropper: false,
showExpender: false,
enableDownloadMask: false,
enableManualInpainting: false,
enableUploadMask: false,
@@ -289,6 +315,38 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
immer((set, get) => ({
...defaultValues,
showPrevMask: async () => {
const { lastLineGroup, curLineGroup } = get().editorState
const { prevInteractiveSegMask, interactiveSegMask } =
get().interactiveSegState
if (curLineGroup.length !== 0 || interactiveSegMask !== null) {
return
}
const { imageWidth, imageHeight } = get()
const maskCanvas = generateMask(
imageWidth,
imageHeight,
[lastLineGroup],
prevInteractiveSegMask ? [prevInteractiveSegMask] : [],
BRUSH_COLOR
)
try {
const maskImage = await canvasToImage(maskCanvas)
set((state) => {
state.editorState.extraMasks.push(castDraft(maskImage))
})
} catch (e) {
console.error(e)
return
}
},
hidePrevMask: () => {
set((state) => {
state.editorState.extraMasks = []
})
},
getCurrentTargetFile: async (): Promise<File> => {
const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数
const renders = get().editorState.renders
@@ -415,7 +473,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
lastLineGroup: curLineGroup,
lastLineGroup: maskLineGroup,
curLineGroup: [],
})
} catch (e: any) {
@@ -432,7 +490,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
const newInteractiveSegState = {
...defaultValues.interactiveSegState,
prevInteractiveSegMask: useLastLineGroup ? null : maskImage,
prevInteractiveSegMask: maskImage,
}
set((state) => {
@@ -675,6 +733,19 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
})
},
setModel: (newModel: ModelInfo) => {
set((state) => {
state.settings.model = newModel
if (
newModel.support_controlnet &&
!newModel.controlnets.includes(state.settings.controlnetMethod)
) {
state.settings.controlnetMethod = newModel.controlnets[0]
}
})
},
updateFileManagerState: (newState: Partial<FileManagerState>) => {
set((state) => {
state.fileManagerState = {
@@ -773,6 +844,26 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
state.cropperState.height = newValue
}),
setExtenderX: (newValue: number) =>
set((state) => {
state.extenderState.x = newValue
}),
setExtenderY: (newValue: number) =>
set((state) => {
state.extenderState.y = newValue
}),
setExtenderWidth: (newValue: number) =>
set((state) => {
state.extenderState.width = newValue
}),
setExtenderHeight: (newValue: number) =>
set((state) => {
state.extenderState.height = newValue
}),
setSeed: (newValue: number) =>
set((state) => {
state.settings.seed = newValue

View File

@@ -9,6 +9,7 @@ export interface ModelInfo {
| "diffusers_sdxl_inpaint"
| "diffusers_other"
support_strength: boolean
support_outpainting: boolean
support_controlnet: boolean
controlnets: string[]
support_freeu: boolean
@@ -66,6 +67,7 @@ export enum SDSampler {
kEulerA = "k_euler_a",
dpmPlusPlus = "dpm++",
uni_pc = "uni_pc",
lcm = "lcm",
}
export interface FreeuConfig {

View File

@@ -53,6 +53,24 @@ export function loadImage(image: HTMLImageElement, src: string) {
})
}
export function canvasToImage(
canvas: HTMLCanvasElement
): Promise<HTMLImageElement> {
return new Promise((resolve, reject) => {
const image = new Image()
image.addEventListener("load", () => {
resolve(image)
})
image.addEventListener("error", (error) => {
reject(error)
})
image.src = canvas.toDataURL()
})
}
export function srcToFile(src: string, fileName: string, mimeType: string) {
return fetch(src)
.then(function (res) {
@@ -164,7 +182,8 @@ export const generateMask = (
imageWidth: number,
imageHeight: number,
lineGroups: LineGroup[],
maskImages: HTMLImageElement[] = []
maskImages: HTMLImageElement[] = [],
lineGroupsColor: string = "white"
): HTMLCanvasElement => {
const maskCanvas = document.createElement("canvas")
maskCanvas.width = imageWidth
@@ -179,7 +198,7 @@ export const generateMask = (
})
lineGroups.forEach((lineGroup) => {
drawLines(ctx, lineGroup, "white")
drawLines(ctx, lineGroup, lineGroupsColor)
})
return maskCanvas