make plugin work

This commit is contained in:
Qing
2023-03-25 09:53:22 +08:00
parent 996a264797
commit 6e54f77ed6
16 changed files with 528 additions and 284 deletions

View File

@@ -4,24 +4,11 @@ import { nanoid } from 'nanoid'
import useInputImage from './hooks/useInputImage' import useInputImage from './hooks/useInputImage'
import { themeState } from './components/Header/ThemeChanger' import { themeState } from './components/Header/ThemeChanger'
import Workspace from './components/Workspace' import Workspace from './components/Workspace'
import { import { fileState, serverConfigState, toastState } from './store/Atoms'
enableFileManagerState,
fileState,
isControlNetState,
isDisableModelSwitchState,
isEnableAutoSavingState,
toastState,
} from './store/Atoms'
import { keepGUIAlive } from './utils' import { keepGUIAlive } from './utils'
import Header from './components/Header/Header' import Header from './components/Header/Header'
import useHotKey from './hooks/useHotkey' import useHotKey from './hooks/useHotkey'
import { import { getServerConfig, isDesktop } from './adapters/inpainting'
getEnableAutoSaving,
getEnableFileManager,
getIsControlNet,
getIsDisableModelSwitch,
isDesktop,
} from './adapters/inpainting'
const SUPPORTED_FILE_TYPE = [ const SUPPORTED_FILE_TYPE = [
'image/jpeg', 'image/jpeg',
@@ -36,10 +23,7 @@ function App() {
const [theme, setTheme] = useRecoilState(themeState) const [theme, setTheme] = useRecoilState(themeState)
const setToastState = useSetRecoilState(toastState) const setToastState = useSetRecoilState(toastState)
const userInputImage = useInputImage() const userInputImage = useInputImage()
const setIsDisableModelSwitch = useSetRecoilState(isDisableModelSwitchState) const setServerConfigState = useSetRecoilState(serverConfigState)
const setEnableFileManager = useSetRecoilState(enableFileManagerState)
const setIsEnableAutoSavingState = useSetRecoilState(isEnableAutoSavingState)
const setIsControlNet = useSetRecoilState(isControlNetState)
// Set Input Image // Set Input Image
useEffect(() => { useEffect(() => {
@@ -58,38 +42,13 @@ function App() {
}, []) }, [])
useEffect(() => { useEffect(() => {
const fetchData = async () => { const fetchServerConfig = async () => {
const isDisable: string = await getIsDisableModelSwitch().then(res => const serverConfig = await getServerConfig().then(res => res.json())
res.text() console.log(serverConfig)
) setServerConfigState(serverConfig)
setIsDisableModelSwitch(isDisable === 'true')
} }
fetchServerConfig()
fetchData() }, [])
const fetchData2 = async () => {
const isEnabled = await getEnableFileManager().then(res => res.text())
setEnableFileManager(isEnabled === 'true')
}
fetchData2()
const fetchData3 = async () => {
const isEnabled = await getEnableAutoSaving().then(res => res.text())
setIsEnableAutoSavingState(isEnabled === 'true')
}
fetchData3()
const fetchData4 = async () => {
const isEnabled = await getIsControlNet().then(res => res.text())
setIsControlNet(isEnabled === 'true')
}
fetchData4()
}, [
setEnableFileManager,
setIsDisableModelSwitch,
setIsEnableAutoSavingState,
setIsControlNet,
])
// Dark Mode Hotkey // Dark Mode Hotkey
useHotKey( useHotKey(

View File

@@ -1,3 +1,4 @@
import { PluginName } from '../components/Plugins/Plugins'
import { Rect, Settings } from '../store/Atoms' import { Rect, Settings } from '../store/Atoms'
import { dataURItoBlob, loadImage, srcToFile } from '../utils' import { dataURItoBlob, loadImage, srcToFile } from '../utils'
@@ -116,26 +117,8 @@ export default async function inpaint(
} }
} }
export function getIsDisableModelSwitch() { export function getServerConfig() {
return fetch(`${API_ENDPOINT}/is_disable_model_switch`, { return fetch(`${API_ENDPOINT}/server_config`, {
method: 'GET',
})
}
export function getIsControlNet() {
return fetch(`${API_ENDPOINT}/is_controlnet`, {
method: 'GET',
})
}
export function getEnableFileManager() {
return fetch(`${API_ENDPOINT}/is_enable_file_manager`, {
method: 'GET',
})
}
export function getEnableAutoSaving() {
return fetch(`${API_ENDPOINT}/is_enable_auto_saving`, {
method: 'GET', method: 'GET',
}) })
} }
@@ -167,20 +150,24 @@ export function modelDownloaded(name: string) {
}) })
} }
export async function postInteractiveSeg( export async function runPlugin(
name: string,
imageFile: File, imageFile: File,
maskFile: File | null, maskFile?: File | null,
clicks: number[][] clicks?: number[][]
) { ) {
const fd = new FormData() const fd = new FormData()
fd.append('name', name)
fd.append('image', imageFile) fd.append('image', imageFile)
fd.append('clicks', JSON.stringify(clicks)) if (clicks) {
if (maskFile !== null) { fd.append('clicks', JSON.stringify(clicks))
}
if (maskFile) {
fd.append('mask', maskFile) fd.append('mask', maskFile)
} }
try { try {
const res = await fetch(`${API_ENDPOINT}/interactive_seg`, { const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
method: 'POST', method: 'POST',
body: fd, body: fd,
}) })
@@ -255,12 +242,13 @@ export async function makeGif(
) { ) {
const cleanFile = await srcToFile(cleanImage.src, filename, mimeType) const cleanFile = await srcToFile(cleanImage.src, filename, mimeType)
const fd = new FormData() const fd = new FormData()
fd.append('origin_img', originFile) fd.append('name', PluginName.MakeGIF)
fd.append('image', originFile)
fd.append('clean_img', cleanFile) fd.append('clean_img', cleanFile)
fd.append('filename', filename) fd.append('filename', filename)
try { try {
const res = await fetch(`${API_ENDPOINT}/make_gif`, { const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
method: 'POST', method: 'POST',
body: fd, body: fd,
}) })

View File

@@ -18,10 +18,7 @@ import {
} from 'react-zoom-pan-pinch' } from 'react-zoom-pan-pinch'
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil' import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use' import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint, { import inpaint, { downloadToOutput, runPlugin } from '../../adapters/inpainting'
downloadToOutput,
postInteractiveSeg,
} from '../../adapters/inpainting'
import Button from '../shared/Button' import Button from '../shared/Button'
import Slider from './Slider' import Slider from './Slider'
import SizeSelector from './SizeSelector' import SizeSelector from './SizeSelector'
@@ -50,6 +47,8 @@ import {
isInteractiveSegRunningState, isInteractiveSegRunningState,
isInteractiveSegState, isInteractiveSegState,
isPix2PixState, isPix2PixState,
isPluginRunningState,
isProcessingState,
negativePropmtState, negativePropmtState,
propmtState, propmtState,
runManuallyState, runManuallyState,
@@ -69,6 +68,7 @@ import FileSelect from '../FileSelect/FileSelect'
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg' import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions' import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions'
import InteractiveSegReplaceModal from '../InteractiveSeg/ReplaceModal' import InteractiveSegReplaceModal from '../InteractiveSeg/ReplaceModal'
import { PluginName } from '../Plugins/Plugins'
import MakeGIF from './MakeGIF' import MakeGIF from './MakeGIF'
const TOOLBAR_SIZE = 200 const TOOLBAR_SIZE = 200
@@ -118,13 +118,15 @@ export default function Editor() {
const croperRect = useRecoilValue(croperState) const croperRect = useRecoilValue(croperState)
const setToastState = useSetRecoilState(toastState) const setToastState = useSetRecoilState(toastState)
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState) const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
const setIsPluginRunning = useSetRecoilState(isPluginRunningState)
const isProcessing = useRecoilValue(isProcessingState)
const runMannually = useRecoilValue(runManuallyState) const runMannually = useRecoilValue(runManuallyState)
const isDiffusionModels = useRecoilValue(isDiffusionModelsState) const isDiffusionModels = useRecoilValue(isDiffusionModelsState)
const isPix2Pix = useRecoilValue(isPix2PixState) const isPix2Pix = useRecoilValue(isPix2PixState)
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState( const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
isInteractiveSegState isInteractiveSegState
) )
const [isInteractiveSegRunning, setIsInteractiveSegRunning] = useRecoilState( const setIsInteractiveSegRunning = useSetRecoilState(
isInteractiveSegRunningState isInteractiveSegRunningState
) )
@@ -538,6 +540,77 @@ export default function Editor() {
} }
}, [runInpainting]) }, [runInpainting])
const getCurrentRender = useCallback(async () => {
let targetFile = file
if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type)
}
return targetFile
}, [file, renders])
useEffect(() => {
emitter.on(PluginName.InteractiveSeg, () => {
setIsInteractiveSeg(true)
if (interactiveSegMask !== null) {
setShowInteractiveSegModal(true)
}
})
return () => {
emitter.off(PluginName.InteractiveSeg)
}
})
const runRenderablePlugin = useCallback(
async (name: string) => {
if (isProcessing) {
return
}
try {
// TODO 要不要加 undoCurrentLine
setIsPluginRunning(true)
const targetFile = await getCurrentRender()
const res = await runPlugin(name, targetFile)
if (!res) {
throw new Error('Something went wrong on server side.')
}
const { blob } = res
const newRender = new Image()
await loadImage(newRender, blob)
const newRenders = [...renders, newRender]
setRenders(newRenders)
} catch (e: any) {
setToastState({
open: true,
desc: e.message ? e.message : e.toString(),
state: 'error',
duration: 3000,
})
} finally {
setIsPluginRunning(false)
}
},
[renders, setRenders, getCurrentRender, setIsPluginRunning, isProcessing]
)
useEffect(() => {
emitter.on(PluginName.RemoveBG, () => {
runRenderablePlugin(PluginName.RemoveBG)
})
return () => {
emitter.off(PluginName.RemoveBG)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.RealESRGAN, () => {
runRenderablePlugin(PluginName.RealESRGAN)
})
return () => {
emitter.off(PluginName.RealESRGAN)
}
}, [runRenderablePlugin])
const hadRunInpainting = () => { const hadRunInpainting = () => {
return renders.length !== 0 return renders.length !== 0
} }
@@ -759,13 +832,7 @@ export default function Editor() {
} }
setIsInteractiveSegRunning(true) setIsInteractiveSegRunning(true)
const targetFile = await getCurrentRender()
let targetFile = file
if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type)
}
const prevMask = null const prevMask = null
// prev_mask seems to be not working better // prev_mask seems to be not working better
// if (tmpInteractiveSegMask !== null) { // if (tmpInteractiveSegMask !== null) {
@@ -777,7 +844,12 @@ export default function Editor() {
// } // }
try { try {
const res = await postInteractiveSeg(targetFile, prevMask, newClicks) const res = await runPlugin(
PluginName.InteractiveSeg.toString(),
targetFile,
prevMask,
newClicks
)
if (!res) { if (!res) {
throw new Error('Something went wrong on server side.') throw new Error('Something went wrong on server side.')
} }
@@ -990,10 +1062,7 @@ export default function Editor() {
]) ])
const disableUndo = () => { const disableUndo = () => {
if (isInteractiveSeg) { if (isProcessing) {
return true
}
if (isInpainting) {
return true return true
} }
if (renders.length > 0) { if (renders.length > 0) {
@@ -1074,10 +1143,7 @@ export default function Editor() {
]) ])
const disableRedo = () => { const disableRedo = () => {
if (isInteractiveSeg) { if (isProcessing) {
return true
}
if (isInpainting) {
return true return true
} }
if (redoRenders.length > 0) { if (redoRenders.length > 0) {
@@ -1185,20 +1251,6 @@ export default function Editor() {
return undefined return undefined
}, [showBrush, isPanning]) }, [showBrush, isPanning])
useHotKey(
'i',
() => {
if (!isInteractiveSeg && isOriginalLoaded) {
setIsInteractiveSeg(true)
if (interactiveSegMask !== null) {
setShowInteractiveSegModal(true)
}
}
},
{},
[isInteractiveSeg, interactiveSegMask, isOriginalLoaded]
)
// Standard Hotkeys for Brush Size // Standard Hotkeys for Brush Size
useHotKey('[', () => { useHotKey('[', () => {
setBrushSize((currentBrushSize: number) => { setBrushSize((currentBrushSize: number) => {
@@ -1370,11 +1422,7 @@ export default function Editor() {
}} }}
> >
<TransformComponent <TransformComponent
contentClass={ contentClass={isProcessing ? 'editor-canvas-loading' : ''}
isInpainting || isInteractiveSegRunning
? 'editor-canvas-loading'
: ''
}
contentStyle={{ contentStyle={{
visibility: initialCentered ? 'visible' : 'hidden', visibility: initialCentered ? 'visible' : 'hidden',
}} }}
@@ -1416,23 +1464,24 @@ export default function Editor() {
}} }}
> >
{showOriginal && ( {showOriginal && (
<div <>
className="editor-slider" <div
style={{ className="editor-slider"
marginRight: `${sliderPos}%`, style={{
}} marginRight: `${sliderPos}%`,
/> }}
/>
<img
className="original-image"
src={original.src}
alt="original"
style={{
width: `${original.naturalWidth}px`,
height: `${original.naturalHeight}px`,
}}
/>
</>
)} )}
<img
className="original-image"
src={original.src}
alt="original"
style={{
width: `${original.naturalWidth}px`,
height: `${original.naturalHeight}px`,
}}
/>
</div> </div>
</div> </div>
@@ -1467,6 +1516,7 @@ export default function Editor() {
onMouseMove={onMouseMove} onMouseMove={onMouseMove}
onMouseUp={onPointerUp} onMouseUp={onPointerUp}
> >
<MakeGIF renders={renders} />
<InteractiveSegConfirmActions <InteractiveSegConfirmActions
onAcceptClick={onInteractiveAccept} onAcceptClick={onInteractiveAccept}
onCancelClick={onInteractiveCancel} onCancelClick={onInteractiveCancel}
@@ -1514,17 +1564,6 @@ export default function Editor() {
onClick={() => setShowRefBrush(false)} onClick={() => setShowRefBrush(false)}
/> />
<div className="editor-toolkit-btns"> <div className="editor-toolkit-btns">
<Button
toolTip="Interactive Segmentation"
icon={<CursorArrowRaysIcon />}
disabled={isInteractiveSeg || isInpainting || !isOriginalLoaded}
onClick={() => {
setIsInteractiveSeg(true)
if (interactiveSegMask !== null) {
setShowInteractiveSegModal(true)
}
}}
/>
<Button <Button
toolTip="Reset Zoom & Pan" toolTip="Reset Zoom & Pan"
icon={<ArrowsPointingOutIcon />} icon={<ArrowsPointingOutIcon />}
@@ -1591,7 +1630,6 @@ export default function Editor() {
}} }}
disabled={renders.length === 0} disabled={renders.length === 0}
/> />
<MakeGIF renders={renders} />
<Button <Button
toolTip="Save Image" toolTip="Save Image"
icon={<ArrowDownTrayIcon />} icon={<ArrowDownTrayIcon />}
@@ -1617,8 +1655,7 @@ export default function Editor() {
</svg> </svg>
} }
disabled={ disabled={
isInpainting || isProcessing ||
isInteractiveSeg ||
(!hadDrawSomething() && interactiveSegMask === null) (!hadDrawSomething() && interactiveSegMask === null)
} }
onClick={() => { onClick={() => {

View File

@@ -1,5 +1,4 @@
import React, { useState } from 'react' import React, { useEffect, useState } from 'react'
import { GifIcon } from '@heroicons/react/24/outline'
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil' import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'
import Button from '../shared/Button' import Button from '../shared/Button'
import { fileState, gifImageState, toastState } from '../../store/Atoms' import { fileState, gifImageState, toastState } from '../../store/Atoms'
@@ -7,6 +6,8 @@ import { makeGif } from '../../adapters/inpainting'
import Modal from '../shared/Modal' import Modal from '../shared/Modal'
import { LoadingIcon } from '../shared/Toast' import { LoadingIcon } from '../shared/Toast'
import { downloadImage } from '../../utils' import { downloadImage } from '../../utils'
import emitter from '../../event'
import { PluginName } from '../Plugins/Plugins'
interface Props { interface Props {
renders: HTMLImageElement[] renders: HTMLImageElement[]
@@ -30,84 +31,94 @@ const MakeGIF = (props: Props) => {
} }
} }
return ( useEffect(() => {
<div> emitter.on(PluginName.MakeGIF, async () => {
<Button if (renders.length === 0) {
toolTip="Make Gif" setToastState({
icon={<GifIcon />} open: true,
disabled={!renders.length} desc: 'No render found',
onClick={async () => { state: 'error',
setShow(true) duration: 2000,
setGifImg(null) })
try { return
const gif = await makeGif( }
file,
renders[renders.length - 1],
file.name,
file.type
)
if (gif) {
setGifImg(gif)
}
} catch (e: any) {
setToastState({
open: true,
desc: e.message ? e.message : e.toString(),
state: 'error',
duration: 2000,
})
}
}}
/>
<Modal
onClose={handleOnClose}
title="GIF"
className="modal-setting"
show={show}
>
<div
style={{
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
flexDirection: 'column',
gap: 16,
}}
>
{gifImg ? (
<img src={gifImg.src} style={{ borderRadius: 8 }} alt="gif" />
) : (
<div
style={{
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
gap: 8,
}}
>
<LoadingIcon />
Generating GIF...
</div>
)}
{gifImg && ( setShow(true)
<div setGifImg(null)
style={{ try {
display: 'flex', const gif = await makeGif(
width: '100%', file,
justifyContent: 'flex-end', renders[renders.length - 1],
alignItems: 'center', file.name,
gap: '12px', file.type
}} )
> if (gif) {
<Button onClick={handleDownload} border> setGifImg(gif)
Download }
</Button> } catch (e: any) {
</div> setToastState({
)} open: true,
</div> desc: e.message ? e.message : e.toString(),
</Modal> state: 'error',
</div> duration: 2000,
})
setShow(false)
}
})
return () => {
emitter.off(PluginName.MakeGIF)
}
}, [setGifImg, renders, file, setShow])
return (
<Modal
onClose={handleOnClose}
title="GIF"
className="modal-setting"
show={show}
>
<div
style={{
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
flexDirection: 'column',
gap: 16,
}}
>
{gifImg ? (
<img src={gifImg.src} style={{ borderRadius: 8 }} alt="gif" />
) : (
<div
style={{
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
gap: 8,
}}
>
<LoadingIcon />
Generating GIF...
</div>
)}
{gifImg && (
<div
style={{
display: 'flex',
width: '100%',
justifyContent: 'flex-end',
alignItems: 'center',
gap: '12px',
}}
>
<Button onClick={handleDownload} border>
Download
</Button>
</div>
)}
</div>
</Modal>
) )
} }

View File

@@ -0,0 +1,92 @@
@use '../../styles/Mixins/' as *;
.plugins {
position: absolute;
top: 68px;
left: 1rem;
padding: 0.1rem 0.3rem;
z-index: 4;
border-radius: 0.8rem;
border-style: solid;
border-color: var(--border-color);
border-width: 1px;
}
.plugins-trigger {
font-family: 'WorkSans', sans-serif;
font-size: 16px;
border: 0px;
}
.plugins-content {
outline: none;
position: relative;
font-family: 'WorkSans', sans-serif;
font-size: 14px;
top: 8px;
left: 1rem;
padding: 0.8rem 0.5rem;
z-index: 9;
// backdrop-filter: blur(12px);
color: var(--text-color);
background-color: var(--page-bg);
border-radius: 0.8rem;
border-style: solid;
border-color: var(--border-color);
border-width: 1px;
display: flex;
flex-direction: column;
gap: 12px;
.setting-block-content {
gap: 1rem;
}
// input {
// height: 24px;
// // border-radius: 4px;
// }
// button {
// height: 28px;
// // border-radius: 4px;
// }
}
.negative-prompt {
all: unset;
border-width: 0;
border-radius: 0.5rem;
min-height: 150px;
max-width: 200px;
width: 100%;
padding: 12px 0.8rem;
outline: 1px solid var(--border-color);
&:focus-visible {
border-width: 0;
outline: 1px solid var(--yellow-accent);
}
&:-webkit-input-placeholder {
padding-top: 10px;
}
&:-moz-input-placeholder {
padding-top: 10px;
}
&:-ms-input-placeholder {
padding-top: 10px;
}
}
.resize-title-tile {
width: 86px;
font-size: 0.5rem;
color: var(--text-color-gray);
}

View File

@@ -0,0 +1,91 @@
import React, { FormEvent } from 'react'
import { useRecoilValue } from 'recoil'
import { CursorArrowRaysIcon, GifIcon } from '@heroicons/react/24/outline'
import { BoxModelIcon, MarginIcon, HobbyKnifeIcon } from '@radix-ui/react-icons'
import { useToggle } from 'react-use'
import * as PopoverPrimitive from '@radix-ui/react-popover'
import {
fileState,
isInpaintingState,
isPluginRunningState,
isProcessingState,
serverConfigState,
} from '../../store/Atoms'
import emitter from '../../event'
import Button from '../shared/Button'
export enum PluginName {
RemoveBG = 'RemoveBG',
RealESRGAN = 'RealESRGAN',
InteractiveSeg = 'InteractiveSeg',
MakeGIF = 'MakeGIF',
}
const pluginMap = {
[PluginName.RemoveBG]: {
IconClass: HobbyKnifeIcon,
showName: 'RemoveBG',
},
[PluginName.RealESRGAN]: {
IconClass: BoxModelIcon,
showName: 'RealESRGAN 4x',
},
[PluginName.InteractiveSeg]: {
IconClass: CursorArrowRaysIcon,
showName: 'Interactive Seg',
},
[PluginName.MakeGIF]: {
IconClass: GifIcon,
showName: 'Make GIF',
},
}
const Plugins = () => {
const [open, toggleOpen] = useToggle(true)
const serverConfig = useRecoilValue(serverConfigState)
const file = useRecoilValue(fileState)
const isProcessing = useRecoilValue(isProcessingState)
const onPluginClick = (pluginName: string) => {
if (isProcessing) {
return
}
emitter.emit(pluginName)
}
const renderPlugins = () => {
return serverConfig.plugins.map((plugin: string) => {
const { IconClass } = pluginMap[plugin as PluginName]
return (
<Button
style={{ gap: 6 }}
icon={<IconClass style={{ width: 15 }} />}
onClick={() => onPluginClick(plugin)}
disabled={!file || isProcessing}
>
{pluginMap[plugin as PluginName].showName}
</Button>
)
})
}
return (
<div className="plugins">
<PopoverPrimitive.Root open={open}>
<PopoverPrimitive.Trigger
className="btn-primary plugins-trigger"
onClick={() => toggleOpen()}
>
Plugins
</PopoverPrimitive.Trigger>
<PopoverPrimitive.Portal>
<PopoverPrimitive.Content className="plugins-content">
{renderPlugins()}
</PopoverPrimitive.Content>
</PopoverPrimitive.Portal>
</PopoverPrimitive.Root>
</div>
)
}
export default Plugins

View File

@@ -1,11 +1,9 @@
import React, { ReactNode, useEffect, useState } from 'react' import React, { ReactNode, useEffect, useState } from 'react'
import { useRecoilState, useRecoilValue } from 'recoil' import { useRecoilState, useRecoilValue } from 'recoil'
import { getIsDisableModelSwitch } from '../../adapters/inpainting'
import { import {
AIModel, AIModel,
CV2Flag, CV2Flag,
isDisableModelSwitchState, isDisableModelSwitchState,
SDSampler,
settingState, settingState,
} from '../../store/Atoms' } from '../../store/Atoms'
import Selector from '../shared/Selector' import Selector from '../shared/Selector'

View File

@@ -4,7 +4,7 @@
position: absolute; position: absolute;
top: 68px; top: 68px;
right: 1.5rem; right: 1.5rem;
padding: 0.3rem 0.3rem; padding: 0.1rem 0.3rem;
z-index: 4; z-index: 4;
border-radius: 0.8rem; border-radius: 0.8rem;
@@ -20,10 +20,11 @@
} }
.side-panel-content { .side-panel-content {
outline: none;
position: relative; position: relative;
font-family: 'WorkSans', sans-serif; font-family: 'WorkSans', sans-serif;
font-size: 14px; font-size: 14px;
top: 1rem; top: 8px;
right: 1.5rem; right: 1.5rem;
padding: 1rem 1rem; padding: 1rem 1rem;
z-index: 9; z-index: 9;

View File

@@ -24,6 +24,7 @@ import SidePanel from './SidePanel/SidePanel'
import PESidePanel from './SidePanel/PESidePanel' import PESidePanel from './SidePanel/PESidePanel'
import FileManager from './FileManager/FileManager' import FileManager from './FileManager/FileManager'
import P2PSidePanel from './SidePanel/P2PSidePanel' import P2PSidePanel from './SidePanel/P2PSidePanel'
import Plugins from './Plugins/Plugins'
const Workspace = () => { const Workspace = () => {
const setFile = useSetRecoilState(fileState) const setFile = useSetRecoilState(fileState)
@@ -102,6 +103,7 @@ const Workspace = () => {
{isSD ? <SidePanel /> : <></>} {isSD ? <SidePanel /> : <></>}
{isPaintByExample ? <PESidePanel /> : <></>} {isPaintByExample ? <PESidePanel /> : <></>}
{isPix2Pix ? <P2PSidePanel /> : <></>} {isPix2Pix ? <P2PSidePanel /> : <></>}
<Plugins />
<FileManager <FileManager
photoWidth={256} photoWidth={256}
show={showFileManager} show={showFileManager}

View File

@@ -52,6 +52,8 @@ interface AppState {
gifImage: HTMLImageElement | undefined gifImage: HTMLImageElement | undefined
brushSize: number brushSize: number
isControlNet: boolean isControlNet: boolean
plugins: string[]
isPluginRunning: boolean
} }
export const appState = atom<AppState>({ export const appState = atom<AppState>({
@@ -72,6 +74,8 @@ export const appState = atom<AppState>({
gifImage: undefined, gifImage: undefined,
brushSize: 40, brushSize: 40,
isControlNet: false, isControlNet: false,
plugins: [],
isPluginRunning: false,
}, },
}) })
@@ -97,6 +101,36 @@ export const isInpaintingState = selector({
}, },
}) })
export const isPluginRunningState = selector({
key: 'isPluginRunningState',
get: ({ get }) => {
const app = get(appState)
return app.isPluginRunning
},
set: ({ get, set }, newValue: any) => {
const app = get(appState)
set(appState, { ...app, isPluginRunning: newValue })
},
})
export const serverConfigState = selector({
key: 'serverConfigState',
get: ({ get }) => {
const app = get(appState)
return {
isControlNet: app.isControlNet,
isDisableModelSwitchState: app.isDisableModelSwitch,
isEnableAutoSaving: app.isEnableAutoSaving,
enableFileManager: app.enableFileManager,
plugins: app.plugins,
}
},
set: ({ get, set }, newValue: any) => {
const app = get(appState)
set(appState, { ...app, ...newValue })
},
})
export const brushSizeState = selector({ export const brushSizeState = selector({
key: 'brushSizeState', key: 'brushSizeState',
get: ({ get }) => { get: ({ get }) => {
@@ -217,6 +251,16 @@ export const isInteractiveSegRunningState = selector({
}, },
}) })
export const isProcessingState = selector({
key: 'isProcessingState',
get: ({ get }) => {
const app = get(appState)
return (
app.isInteractiveSegRunning || app.isPluginRunning || app.isInpainting
)
},
})
export const interactiveSegClicksState = selector({ export const interactiveSegClicksState = selector({
key: 'interactiveSegClicksState', key: 'interactiveSegClicksState',
get: ({ get }) => { get: ({ get }) => {

View File

@@ -15,6 +15,7 @@
@use '../components/Shortcuts/Shortcuts'; @use '../components/Shortcuts/Shortcuts';
@use '../components/Settings/Settings.scss'; @use '../components/Settings/Settings.scss';
@use '../components/SidePanel/SidePanel.scss'; @use '../components/SidePanel/SidePanel.scss';
@use '../components/Plugins/Plugins.scss';
@use '../components/Croper/Croper.scss'; @use '../components/Croper/Croper.scss';
@use '../components/InteractiveSeg/InteractiveSeg.scss'; @use '../components/InteractiveSeg/InteractiveSeg.scss';

View File

@@ -94,6 +94,11 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"] "--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
) )
parser.add_argument(
"--enable-gif",
action="store_true",
help="Enable GIF plugin",
)
######### #########
# useless args # useless args

View File

@@ -1,3 +1,3 @@
from .interactive_seg import InteractiveSeg, Click from .interactive_seg import InteractiveSeg, Click
from .remove_bg import RemoveBG from .remove_bg import RemoveBG
from .upscale import RealESRGANUpscaler from .realesrgan import RealESRGANUpscaler

View File

@@ -1,9 +1,10 @@
import io import io
import math import math
from pathlib import Path
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from lama_cleaner.helper import load_img
def keep_ratio_resize(img, size, resample=Image.BILINEAR): def keep_ratio_resize(img, size, resample=Image.BILINEAR):
if img.width > img.height: if img.width > img.height:
@@ -33,16 +34,20 @@ def cubic_bezier(p1, p2, duration: int, frames: int):
x3, y3 = (1, 1) x3, y3 = (1, 1)
def cal_y(t): def cal_y(t):
return math.pow(1 - t, 3) * y0 + \ return (
3 * math.pow(1 - t, 2) * t * y1 + \ math.pow(1 - t, 3) * y0
3 * (1 - t) * math.pow(t, 2) * y2 + \ + 3 * math.pow(1 - t, 2) * t * y1
math.pow(t, 3) * y3 + 3 * (1 - t) * math.pow(t, 2) * y2
+ math.pow(t, 3) * y3
)
def cal_x(t): def cal_x(t):
return math.pow(1 - t, 3) * x0 + \ return (
3 * math.pow(1 - t, 2) * t * x1 + \ math.pow(1 - t, 3) * x0
3 * (1 - t) * math.pow(t, 2) * x2 + \ + 3 * math.pow(1 - t, 2) * t * x1
math.pow(t, 3) * x3 + 3 * (1 - t) * math.pow(t, 2) * x2
+ math.pow(t, 3) * x3
)
res = [] res = []
for t in range(0, 1 * frames, duration): for t in range(0, 1 * frames, duration):
@@ -58,7 +63,7 @@ def make_compare_gif(
src_img: Image.Image, src_img: Image.Image,
max_side_length: int = 600, max_side_length: int = 600,
splitter_width: int = 5, splitter_width: int = 5,
splitter_color=(255, 203, 0, int(255 * 0.73)) splitter_color=(255, 203, 0, int(255 * 0.73)),
): ):
if clean_img.size != src_img.size: if clean_img.size != src_img.size:
clean_img = clean_img.resize(src_img.size, Image.BILINEAR) clean_img = clean_img.resize(src_img.size, Image.BILINEAR)
@@ -79,7 +84,7 @@ def make_compare_gif(
images = [] images = []
for i in range(num_frames): for i in range(num_frames):
new_frame = Image.new('RGB', (width, height)) new_frame = Image.new("RGB", (width, height))
new_frame.paste(clean_img, (0, 0)) new_frame.paste(clean_img, (0, 0))
left = int(cubic_bezier_points[i][0] * width) left = int(cubic_bezier_points[i][0] * width)
@@ -88,7 +93,9 @@ def make_compare_gif(
if i != num_frames - 1: if i != num_frames - 1:
# draw a yellow splitter on the edge of the cropped image # draw a yellow splitter on the edge of the cropped image
draw = ImageDraw.Draw(new_frame) draw = ImageDraw.Draw(new_frame)
draw.line([(left, 0), (left, height)], width=splitter_width, fill=splitter_color) draw.line(
[(left, 0), (left, height)], width=splitter_width, fill=splitter_color
)
images.append(new_frame) images.append(new_frame)
for i in range(10): for i in range(10):
@@ -97,7 +104,7 @@ def make_compare_gif(
cubic_bezier_points.reverse() cubic_bezier_points.reverse()
# Generate images to make Gif from left to right # Generate images to make Gif from left to right
for i in range(num_frames): for i in range(num_frames):
new_frame = Image.new('RGB', (width, height)) new_frame = Image.new("RGB", (width, height))
new_frame.paste(src_img, (0, 0)) new_frame.paste(src_img, (0, 0))
right = int(cubic_bezier_points[i][0] * width) right = int(cubic_bezier_points[i][0] * width)
@@ -106,7 +113,9 @@ def make_compare_gif(
if i != num_frames - 1: if i != num_frames - 1:
# draw a yellow splitter on the edge of the cropped image # draw a yellow splitter on the edge of the cropped image
draw = ImageDraw.Draw(new_frame) draw = ImageDraw.Draw(new_frame)
draw.line([(right, 0), (right, height)], width=splitter_width, fill=splitter_color) draw.line(
[(right, 0), (right, height)], width=splitter_width, fill=splitter_color
)
images.append(new_frame) images.append(new_frame)
images.append(clean_img) images.append(clean_img)
@@ -114,12 +123,25 @@ def make_compare_gif(
img_byte_arr = io.BytesIO() img_byte_arr = io.BytesIO()
clean_img.save( clean_img.save(
img_byte_arr, img_byte_arr,
format='GIF', format="GIF",
save_all=True, save_all=True,
include_color_table=True, include_color_table=True,
append_images=images, append_images=images,
optimize=False, optimize=False,
duration=duration_per_frame, duration=duration_per_frame,
loop=0 loop=0,
) )
return img_byte_arr.getvalue() return img_byte_arr.getvalue()
class MakeGIF:
name = "MakeGIF"
def __call__(self, rgb_np_img, files, form):
origin_image = rgb_np_img
clean_image_bytes = files["clean_img"].read()
clean_image, _ = load_img(clean_image_bytes)
gif_bytes = make_compare_gif(
Image.fromarray(origin_image), Image.fromarray(clean_image)
)
return gif_bytes

View File

@@ -37,7 +37,7 @@ class RealESRGANUpscaler:
def __call__(self, rgb_np_img, files, form): def __call__(self, rgb_np_img, files, form):
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
scale = float(form["scale"]) scale = 4
return self.forward(bgr_np_img, scale) return self.forward(bgr_np_img, scale)
def forward(self, bgr_np_img, scale: float): def forward(self, bgr_np_img, scale: float):

View File

@@ -17,10 +17,10 @@ import numpy as np
from loguru import logger from loguru import logger
from lama_cleaner.const import SD15_MODELS from lama_cleaner.const import SD15_MODELS
from lama_cleaner.make_gif import make_compare_gif
from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
from lama_cleaner.plugins.gif import MakeGIF
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager from lama_cleaner.file_manager import FileManager
@@ -318,11 +318,10 @@ def process():
return response return response
@app.route("/run_plugin/", methods=["POST"]) @app.route("/run_plugin", methods=["POST"])
def run_plugin(): def run_plugin():
form = request.form form = request.form
files = request.files files = request.files
name = form["name"] name = form["name"]
if name not in plugins: if name not in plugins:
return "Plugin not found", 500 return "Plugin not found", 500
@@ -335,18 +334,33 @@ def run_plugin():
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms") logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
torch_gc() torch_gc()
response = make_response( if name == MakeGIF.name:
send_file( filename = form["filename"]
io.BytesIO(numpy_to_bytes(res, "png")), return send_file(
mimetype=f"image/png", io.BytesIO(res),
mimetype="image/gif",
as_attachment=True,
attachment_filename=filename,
)
else:
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(res, "png")),
mimetype=f"image/png",
)
) )
)
return response return response
@app.route("/plugins/", methods=["GET"]) @app.route("/server_config", methods=["GET"])
def get_plugins(): def get_server_config():
return list(plugins.keys()), 200 return {
"isControlNet": is_controlnet,
"isDisableModelSwitchState": is_disable_model_switch,
"isEnableAutoSaving": is_enable_file_manager,
"enableFileManager": is_enable_auto_saving,
"plugins": list(plugins.keys()),
}, 200
@app.route("/model") @app.route("/model")
@@ -354,30 +368,6 @@ def current_model():
return model.name, 200 return model.name, 200
@app.route("/is_controlnet")
def get_is_controlnet():
res = "true" if is_controlnet else "false"
return res, 200
@app.route("/is_disable_model_switch")
def get_is_disable_model_switch():
res = "true" if is_disable_model_switch else "false"
return res, 200
@app.route("/is_enable_file_manager")
def get_is_enable_file_manager():
res = "true" if is_enable_file_manager else "false"
return res, 200
@app.route("/is_enable_auto_saving")
def get_is_enable_auto_saving():
res = "true" if is_enable_auto_saving else "false"
return res, 200
@app.route("/model_downloaded/<name>") @app.route("/model_downloaded/<name>")
def model_downloaded(name): def model_downloaded(name):
return str(model.is_downloaded(name)), 200 return str(model.is_downloaded(name)), 200
@@ -435,6 +425,9 @@ def build_plugins(args):
if args.enable_realesrgan: if args.enable_realesrgan:
logger.info(f"Initialize {RealESRGANUpscaler.name} plugin") logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device) plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)
if args.enable_gif:
logger.info(f"Initialize GIF plugin")
plugins[MakeGIF.name] = MakeGIF()
def main(args): def main(args):