wip: add interactive seg model

This commit is contained in:
Qing
2022-11-27 21:25:27 +08:00
parent af87cca643
commit 023306ae40
20 changed files with 820 additions and 46 deletions

View File

@@ -1,8 +1,3 @@
import {
ArrowsExpandIcon,
DownloadIcon,
EyeIcon,
} from '@heroicons/react/outline'
import React, {
SyntheticEvent,
useCallback,
@@ -10,6 +5,12 @@ import React, {
useRef,
useState,
} from 'react'
import {
CursorArrowRaysIcon,
EyeIcon,
ArrowsPointingOutIcon,
ArrowDownTrayIcon,
} from '@heroicons/react/24/outline'
import {
ReactZoomPanPinchRef,
TransformComponent,
@@ -17,7 +18,7 @@ import {
} from 'react-zoom-pan-pinch'
import { useRecoilState, useRecoilValue } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint from '../../adapters/inpainting'
import inpaint, { postInteractiveSeg } from '../../adapters/inpainting'
import Button from '../shared/Button'
import Slider from './Slider'
import SizeSelector from './SizeSelector'
@@ -34,7 +35,10 @@ import {
import {
croperState,
fileState,
interactiveSegClicksState,
isInpaintingState,
isInteractiveSegRunningState,
isInteractiveSegState,
isSDState,
negativePropmtState,
propmtState,
@@ -51,6 +55,8 @@ import emitter, {
CustomMaskEventData,
} from '../../event'
import FileSelect from '../FileSelect/FileSelect'
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions'
const TOOLBAR_SIZE = 200
const MIN_BRUSH_SIZE = 10
@@ -101,6 +107,20 @@ export default function Editor() {
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
const runMannually = useRecoilValue(runManuallyState)
const isSD = useRecoilValue(isSDState)
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
isInteractiveSegState
)
const [isInteractiveSegRunning, setIsInteractiveSegRunning] = useRecoilState(
isInteractiveSegRunningState
)
const [interactiveSegMask, setInteractiveSegMask] =
useState<HTMLImageElement | null>(null)
// only used while interactive segmentation is on
const [tmpInteractiveSegMask, setTmpInteractiveSegMask] =
useState<HTMLImageElement | null>(null)
const [clicks, setClicks] = useRecoilState(interactiveSegClicksState)
const [brushSize, setBrushSize] = useState(40)
const [original, isOriginalLoaded] = useImage(file)
@@ -159,13 +179,37 @@ export default function Editor() {
original.naturalWidth,
original.naturalHeight
)
if (isInteractiveSeg && tmpInteractiveSegMask !== null) {
context.drawImage(
tmpInteractiveSegMask,
0,
0,
original.naturalWidth,
original.naturalHeight
)
}
if (!isInteractiveSeg && interactiveSegMask !== null) {
context.drawImage(
interactiveSegMask,
0,
0,
original.naturalWidth,
original.naturalHeight
)
}
drawLines(context, lineGroup)
},
[context, original]
[
context,
original,
isInteractiveSeg,
tmpInteractiveSegMask,
interactiveSegMask,
]
)
const drawLinesOnMask = useCallback(
(_lineGroups: LineGroup[]) => {
(_lineGroups: LineGroup[], maskImage?: HTMLImageElement | null) => {
if (!context?.canvas.width || !context?.canvas.height) {
throw new Error('canvas has invalid size')
}
@@ -176,6 +220,17 @@ export default function Editor() {
throw new Error('could not retrieve mask canvas')
}
if (maskImage !== undefined && maskImage !== null) {
// TODO: check whether draw yellow mask works on backend
ctx.drawImage(
maskImage,
0,
0,
original.naturalWidth,
original.naturalHeight
)
}
_lineGroups.forEach(lineGroup => {
drawLines(ctx, lineGroup, 'white')
})
@@ -199,15 +254,24 @@ export default function Editor() {
)
const runInpainting = useCallback(
async (useLastLineGroup?: boolean, customMask?: File) => {
async (
useLastLineGroup?: boolean,
customMask?: File,
maskImage?: HTMLImageElement | null
) => {
if (file === undefined) {
return
}
const useCustomMask = customMask !== undefined
const useMaskImage = maskImage !== undefined && maskImage !== null
// useLastLineGroup 的影响
// 1. 使用上一次的 mask
// 2. 结果替换当前 render
console.log('runInpainting')
console.log({
useCustomMask,
useMaskImage,
})
let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) {
@@ -216,7 +280,7 @@ export default function Editor() {
}
maskLineGroup = lastLineGroup
} else if (!useCustomMask) {
if (!hadDrawSomething()) {
if (!hadDrawSomething() && !useMaskImage) {
return
}
@@ -230,7 +294,7 @@ export default function Editor() {
setIsDraging(false)
setIsInpainting(true)
if (settings.graduallyInpainting) {
drawLinesOnMask([maskLineGroup])
drawLinesOnMask([maskLineGroup], maskImage)
} else {
drawLinesOnMask(newLineGroups)
}
@@ -309,6 +373,8 @@ export default function Editor() {
drawOnCurrentRender([])
}
setIsInpainting(false)
setTmpInteractiveSegMask(null)
setInteractiveSegMask(null)
},
[
lineGroups,
@@ -493,10 +559,22 @@ export default function Editor() {
}
}, [])
const onInteractiveCancel = useCallback(() => {
setIsInteractiveSeg(false)
setIsInteractiveSegRunning(false)
setClicks([])
setTmpInteractiveSegMask(null)
}, [])
const handleEscPressed = () => {
if (isInpainting) {
return
}
if (isInteractiveSeg) {
onInteractiveCancel()
}
if (isDraging || isMultiStrokeKeyPressed) {
setIsDraging(false)
setCurLineGroup([])
@@ -516,6 +594,8 @@ export default function Editor() {
isDraging,
isInpainting,
isMultiStrokeKeyPressed,
isInteractiveSeg,
onInteractiveCancel,
resetZoom,
drawOnCurrentRender,
]
@@ -536,6 +616,9 @@ export default function Editor() {
}
return
}
if (isInteractiveSeg) {
return
}
if (isPanning) {
return
}
@@ -551,10 +634,58 @@ export default function Editor() {
drawOnCurrentRender(lineGroup)
}
const runInteractiveSeg = async (newClicks: number[][]) => {
if (!file) {
return
}
setIsInteractiveSegRunning(true)
let targetFile = file
if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type)
}
const prevMask = null
// prev_mask seems to be not working better
// if (tmpInteractiveSegMask !== null) {
// prevMask = await srcToFile(
// tmpInteractiveSegMask.currentSrc,
// 'prev_mask.jpg',
// 'image/jpeg'
// )
// }
try {
const res = await postInteractiveSeg(targetFile, prevMask, newClicks)
if (!res) {
throw new Error('Something went wrong on server side.')
}
const { blob } = res
const img = new Image()
img.onload = () => {
setTmpInteractiveSegMask(img)
}
img.src = blob
} catch (e: any) {
setToastState({
open: true,
desc: e.message ? e.message : e.toString(),
state: 'error',
duration: 4000,
})
}
setIsInteractiveSegRunning(false)
}
const onPointerUp = (ev: SyntheticEvent) => {
if (isMidClick(ev)) {
setIsPanning(false)
}
if (isInteractiveSeg) {
return
}
if (isPanning) {
return
@@ -601,7 +732,24 @@ export default function Editor() {
return false
}
const onCanvasMouseUp = (ev: SyntheticEvent) => {
if (isInteractiveSeg) {
const xy = mouseXY(ev)
const newClicks: number[][] = [...clicks]
if (isRightClick(ev)) {
newClicks.push([xy.x, xy.y, 0, newClicks.length])
} else {
newClicks.push([xy.x, xy.y, 1, newClicks.length])
}
runInteractiveSeg(newClicks)
setClicks(newClicks)
}
}
const onMouseDown = (ev: SyntheticEvent) => {
if (isInteractiveSeg) {
return
}
if (isChangingBrushSizeByMouse) {
return
}
@@ -714,6 +862,9 @@ export default function Editor() {
useKey(undoPredicate, undo, undefined, [undoStroke, undoRender, isSD])
const disableUndo = () => {
if (isInteractiveSeg) {
return true
}
if (isInpainting) {
return true
}
@@ -790,6 +941,9 @@ export default function Editor() {
useKey(redoPredicate, redo, undefined, [redoStroke, redoRender, isSD])
const disableRedo = () => {
if (isInteractiveSeg) {
return true
}
if (isInpainting) {
return true
}
@@ -877,6 +1031,16 @@ export default function Editor() {
return undefined
}, [showBrush, isPanning])
useHotKey(
'i',
() => {
if (!isInteractiveSeg) {
setIsInteractiveSeg(true)
}
},
isInteractiveSeg
)
// Standard Hotkeys for Brush Size
useHotKey('[', () => {
setBrushSize(currentBrushSize => {
@@ -1002,6 +1166,20 @@ export default function Editor() {
)
}
const renderInteractiveSegCursor = () => {
return (
<div
className="interactive-seg-cursor"
style={{
left: `${x}px`,
top: `${y}px`,
}}
>
<CursorArrowRaysIcon />
</div>
)
}
const renderCanvas = () => {
return (
<TransformWrapper
@@ -1029,7 +1207,11 @@ export default function Editor() {
}}
>
<TransformComponent
contentClass={isInpainting ? 'editor-canvas-loading' : ''}
contentClass={
isInpainting || isInteractiveSegRunning
? 'editor-canvas-loading'
: ''
}
contentStyle={{
visibility: initialCentered ? 'visible' : 'hidden',
}}
@@ -1052,6 +1234,7 @@ export default function Editor() {
onFocus={() => toggleShowBrush(true)}
onMouseLeave={() => toggleShowBrush(false)}
onMouseDown={onMouseDown}
onMouseUp={onCanvasMouseUp}
onMouseMove={onMouseDrag}
ref={r => {
if (r && !context) {
@@ -1101,11 +1284,22 @@ export default function Editor() {
) : (
<></>
)}
{isInteractiveSeg ? <InteractiveSeg /> : <></>}
</TransformComponent>
</TransformWrapper>
)
}
const onInteractiveAccept = () => {
setInteractiveSegMask(tmpInteractiveSegMask)
setTmpInteractiveSegMask(null)
if (!runMannually && tmpInteractiveSegMask) {
runInpainting(false, undefined, tmpInteractiveSegMask)
}
}
return (
<div
className="editor-container"
@@ -1113,17 +1307,26 @@ export default function Editor() {
onMouseMove={onMouseMove}
onMouseUp={onPointerUp}
>
<InteractiveSegConfirmActions
onAcceptClick={onInteractiveAccept}
onCancelClick={onInteractiveCancel}
/>
{file === undefined ? renderFileSelect() : renderCanvas()}
{showBrush && !isInpainting && !isPanning && (
<div
className="brush-shape"
style={getBrushStyle(
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.x : x,
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.y : y
)}
/>
)}
{showBrush &&
!isInpainting &&
!isPanning &&
(isInteractiveSeg ? (
renderInteractiveSegCursor()
) : (
<div
className="brush-shape"
style={getBrushStyle(
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.x : x,
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.y : y
)}
/>
))}
{showRefBrush && (
<div
@@ -1151,10 +1354,17 @@ export default function Editor() {
onClick={() => setShowRefBrush(false)}
/>
<div className="editor-toolkit-btns">
<Button
toolTip="Interactive Segmentation"
tooltipPosition="top"
icon={<CursorArrowRaysIcon />}
disabled={isInteractiveSeg || isInpainting}
onClick={() => setIsInteractiveSeg(true)}
/>
<Button
toolTip="Reset Zoom & Pan"
tooltipPosition="top"
icon={<ArrowsExpandIcon />}
icon={<ArrowsPointingOutIcon />}
disabled={scale === minScale && panned === false}
onClick={resetZoom}
/>
@@ -1224,7 +1434,7 @@ export default function Editor() {
<Button
toolTip="Save Image"
tooltipPosition="top"
icon={<DownloadIcon />}
icon={<ArrowDownTrayIcon />}
disabled={!renders.length}
onClick={download}
/>
@@ -1247,11 +1457,13 @@ export default function Editor() {
/>
</svg>
}
disabled={!hadDrawSomething() || isInpainting}
disabled={
!interactiveSegMask &&
(!hadDrawSomething() || isInpainting || isInteractiveSeg)
}
onClick={() => {
if (!isInpainting && hadDrawSomething()) {
runInpainting()
}
// ensured by disabled
runInpainting(false, undefined, interactiveSegMask)
}}
/>
)}