big update

Sanster
2022-04-16 00:11:51 +08:00
parent 2b031603ed
commit 205286a414
40 changed files with 539 additions and 376 deletions

View File

@@ -1,10 +1,12 @@
import { Setting } from '../store/Atoms'
import { dataURItoBlob } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
export default async function inpaint(
imageFile: File,
maskBase64: string,
settings: Setting,
sizeLimit?: string
) {
// 1080, 2000, Original
@@ -13,13 +15,22 @@ export default async function inpaint(
const mask = dataURItoBlob(maskBase64)
fd.append('mask', mask)
fd.append('ldmSteps', settings.ldmSteps.toString())
fd.append('hdStrategy', settings.hdStrategy)
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
fd.append(
'hdStrategyCropTrigerSize',
settings.hdStrategyCropTrigerSize.toString()
)
fd.append('hdStrategyResizeLimit', settings.hdStrategyResizeLimit.toString())
if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080')
} else {
fd.append('sizeLimit', sizeLimit)
}
const res = await fetch(API_ENDPOINT, {
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
method: 'POST',
body: fd,
}).then(async r => {
@@ -28,3 +39,12 @@ export default async function inpaint(
return URL.createObjectURL(res)
}
export function switchModel(name: string) {
const fd = new FormData()
fd.append('name', name)
return fetch(`${API_ENDPOINT}/switch_model`, {
method: 'POST',
body: fd,
})
}

View File

@@ -15,12 +15,14 @@ import {
TransformComponent,
TransformWrapper,
} from 'react-zoom-pan-pinch'
import { useRecoilValue } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint from '../../adapters/inpainting'
import Button from '../shared/Button'
import Slider from './Slider'
import SizeSelector from './SizeSelector'
import { downloadImage, loadImage, useImage } from '../../utils'
import { settingState } from '../../store/Atoms'
const TOOLBAR_SIZE = 200
const BRUSH_COLOR = '#ffcc00bb'
@@ -57,6 +59,7 @@ function drawLines(
export default function Editor(props: EditorProps) {
const { file } = props
const settings = useRecoilValue(settingState)
const [brushSize, setBrushSize] = useState(40)
const [original, isOriginalLoaded] = useImage(file)
const [renders, setRenders] = useState<HTMLImageElement[]>([])
@@ -125,6 +128,7 @@ export default function Editor(props: EditorProps) {
const res = await inpaint(
file,
maskCanvas.toDataURL(),
settings,
sizeLimit.toString()
)
if (!res) {
@@ -157,6 +161,7 @@ export default function Editor(props: EditorProps) {
renders,
sizeLimit,
historyLineCount,
settings,
])
const hadDrawSomething = () => {

View File

@@ -6,7 +6,7 @@ import Button from '../shared/Button'
import Shortcuts from '../Shortcuts/Shortcuts'
import useResolution from '../../hooks/useResolution'
import { ThemeChanger } from './ThemeChanger'
import SettingIcon from '../Setting/SettingIcon'
import SettingIcon from '../Settings/SettingIcon'
const Header = () => {
const [file, setFile] = useRecoilState(fileState)

View File

@@ -1,50 +1,16 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import NumberInput from '../shared/NumberInput'
import Selector from '../shared/Selector'
import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'
export enum HDStrategy {
ORIGINAL = 'Original',
REISIZE = 'Resize',
RESIZE = 'Resize',
CROP = 'Crop',
}
interface PixelSizeInputProps {
title: string
value: string
onValue: (val: string) => void
}
function PixelSizeInputSetting(props: PixelSizeInputProps) {
const { title, value, onValue } = props
return (
<SettingBlock
className="sub-setting-block"
title={title}
input={
<div
style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
gap: '8px',
}}
>
<NumberInput
style={{ width: '80px' }}
value={`${value}`}
onValue={onValue}
/>
<span>pixel</span>
</div>
}
/>
)
}
function HDSettingBlock() {
const [setting, setSettingState] = useRecoilState(settingState)
@@ -84,7 +50,7 @@ function HDSettingBlock() {
tabIndex={0}
role="button"
className="inline-tip"
onClick={() => onStrategyChange(HDStrategy.REISIZE)}
onClick={() => onStrategyChange(HDStrategy.RESIZE)}
>
Resize Strategy
</div>{' '}
@@ -100,9 +66,10 @@ function HDSettingBlock() {
Resize the longer side of the image to a specific size (keep ratio),
then do inpainting on the resized image.
</div>
<PixelSizeInputSetting
<NumberInputSetting
title="Size limit"
value={`${setting.hdStrategyResizeLimit}`}
suffix="pixel"
onValue={onResizeLimitChange}
/>
</div>
@@ -117,14 +84,16 @@ function HDSettingBlock() {
the result back. Mainly for performance and memory reasons on high
resolution images.
</div>
<PixelSizeInputSetting
<NumberInputSetting
title="Trigger size"
value={`${setting.hdStrategyCropTrigerSize}`}
suffix="pixel"
onValue={onCropTriggerSizeChange}
/>
<PixelSizeInputSetting
<NumberInputSetting
title="Crop margin"
value={`${setting.hdStrategyCropMargin}`}
suffix="pixel"
onValue={onCropMarginChange}
/>
</div>
@@ -137,7 +106,7 @@ function HDSettingBlock() {
return renderOriginalOptionDesc()
case HDStrategy.CROP:
return renderCropOptionDesc()
case HDStrategy.REISIZE:
case HDStrategy.RESIZE:
return renderResizeOptionDesc()
default:
return renderOriginalOptionDesc()

View File

@@ -2,11 +2,12 @@ import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import Selector from '../shared/Selector'
import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'
export enum AIModel {
LAMA = 'LaMa',
LDM = 'LDM',
LAMA = 'lama',
LDM = 'ldm',
}
function ModelSettingBlock() {
@@ -24,7 +25,7 @@ function ModelSettingBlock() {
githubUrl: string
) => {
return (
<div style={{ display: 'flex', flexDirection: 'column' }}>
<div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
<a
className="model-desc-link"
href={paperUrl}
@@ -34,14 +35,11 @@ function ModelSettingBlock() {
{name}
</a>
<br />
<a
className="model-desc-link"
href={githubUrl}
target="_blank"
rel="noreferrer noopener"
style={{ marginTop: '8px' }}
>
{githubUrl}
</a>
@@ -49,6 +47,28 @@ function ModelSettingBlock() {
)
}
const renderLDMModelDesc = () => {
return (
<div>
{renderModelDesc(
'High-Resolution Image Synthesis with Latent Diffusion Models',
'https://arxiv.org/abs/2112.10752',
'https://github.com/CompVis/latent-diffusion'
)}
<NumberInputSetting
title="Steps"
value={`${setting.ldmSteps}`}
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, ldmSteps: val }
})
}}
/>
</div>
)
}
const renderOptionDesc = (): ReactNode => {
switch (setting.model) {
case AIModel.LAMA:
@@ -58,11 +78,7 @@ function ModelSettingBlock() {
'https://github.com/saic-mdal/lama'
)
case AIModel.LDM:
return renderModelDesc(
'High-Resolution Image Synthesis with Latent Diffusion Models',
'https://arxiv.org/abs/2112.10752',
'https://github.com/CompVis/latent-diffusion'
)
return renderLDMModelDesc()
default:
return <></>
}

View File

@@ -0,0 +1,40 @@
import React from 'react'
import NumberInput from '../shared/NumberInput'
import SettingBlock from './SettingBlock'
interface NumberInputSettingProps {
title: string
value: string
suffix?: string
onValue: (val: string) => void
}
function NumberInputSetting(props: NumberInputSettingProps) {
const { title, value, suffix, onValue } = props
return (
<SettingBlock
className="sub-setting-block"
title={title}
input={
<div
style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
gap: '8px',
}}
>
<NumberInput
style={{ width: '80px' }}
value={`${value}`}
onValue={onValue}
/>
{suffix && <span>{suffix}</span>}
</div>
}
/>
)
}
export default NumberInputSetting

View File

@@ -1,14 +1,26 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import { Switch, SwitchThumb } from '../shared/Switch'
import SettingBlock from './SettingBlock'
function SavePathSettingBlock() {
const [setting, setSettingState] = useRecoilState(settingState)
const onCheckChange = (checked: boolean) => {
setSettingState(old => {
return { ...old, saveImageBesideOrigin: checked }
})
}
return (
<SettingBlock
title="Download image beside origin image"
input={
<Switch defaultChecked>
<Switch
checked={setting.saveImageBesideOrigin}
onCheckedChange={onCheckChange}
>
<SwitchThumb />
</Switch>
}

View File

@@ -7,7 +7,7 @@
margin-top: 12px;
border: 1px solid var(--border-color);
border-radius: 0.3rem;
padding: 2rem;
padding: 1rem;
.sub-setting-block {
margin-top: 8px;

View File

@@ -8,7 +8,6 @@
background-color: var(--modal-bg);
color: var(--modal-text-color);
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
min-height: 600px;
width: 700px;
@include mobile {

View File

@@ -1,11 +1,11 @@
import React from 'react'
import { useRecoilState } from 'recoil'
import { switchModel } from '../../adapters/inpainting'
import { settingState } from '../../store/Atoms'
import Modal from '../shared/Modal'
import HDSettingBlock from './HDSettingBlock'
import ModelSettingBlock from './ModelSettingBlock'
import SavePathSettingBlock from './SavePathSettingBlock'
export default function SettingModal() {
const [setting, setSettingState] = useRecoilState(settingState)
@@ -14,6 +14,8 @@ export default function SettingModal() {
setSettingState(old => {
return { ...old, show: false }
})
switchModel(setting.model)
}
return (
@@ -23,7 +25,9 @@ export default function SettingModal() {
className="modal-setting"
show={setting.show}
>
<SavePathSettingBlock />
{/* Not possible in the browser: letting the page choose the save location poses a security risk */}
{/* https://stackoverflow.com/questions/34870711/download-a-file-at-different-location-using-html5 */}
{/* <SavePathSettingBlock /> */}
<ModelSettingBlock />
<HDSettingBlock />
</Modal>

View File

@@ -1,9 +1,7 @@
import React from 'react'
import { useRecoilValue } from 'recoil'
import Editor from './Editor/Editor'
import { shortcutsState } from '../store/Atoms'
import ShortcutsModal from './Shortcuts/ShortcutsModal'
import SettingModal from './Setting/SettingModal'
import SettingModal from './Settings/SettingsModal'
interface WorkspaceProps {
file: File

View File

@@ -16,11 +16,15 @@ export default function Modal(props: ModalProps) {
const ref = useRef(null)
useClickAway(ref, () => {
onClose?.()
if (show) {
onClose?.()
}
})
useKeyPressEvent('Escape', e => {
onClose?.()
if (show) {
onClose?.()
}
})
return (

View File

@@ -1,6 +1,6 @@
import { atom } from 'recoil'
import { HDStrategy } from '../components/Setting/HDSettingBlock'
import { AIModel } from '../components/Setting/ModelSettingBlock'
import { HDStrategy } from '../components/Settings/HDSettingBlock'
import { AIModel } from '../components/Settings/ModelSettingBlock'
export const fileState = atom<File | undefined>({
key: 'fileState',
@@ -16,10 +16,15 @@ export interface Setting {
show: boolean
saveImageBesideOrigin: boolean
model: AIModel
// For LaMa
hdStrategy: HDStrategy
hdStrategyResizeLimit: number
hdStrategyCropTrigerSize: number
hdStrategyCropMargin: number
// For LDM
ldmSteps: number
}
export const settingState = atom<Setting>({
@@ -28,9 +33,10 @@ export const settingState = atom<Setting>({
show: false,
saveImageBesideOrigin: false,
model: AIModel.LAMA,
hdStrategy: HDStrategy.ORIGINAL,
hdStrategy: HDStrategy.RESIZE,
hdStrategyResizeLimit: 2048,
hdStrategyCropTrigerSize: 2048,
hdStrategyCropMargin: 128,
ldmSteps: 50,
},
})

View File

@@ -11,7 +11,7 @@
@use '../components/Header/Header';
@use '../components/Header/ThemeChanger';
@use '../components/Shortcuts/Shortcuts';
@use '../components/Setting/Setting.scss';
@use '../components/Settings/Settings.scss';
// Shared
@use '../components/FileSelect/FileSelect';

View File

@@ -31,7 +31,11 @@ def ceil_modulo(x, mod):
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
data = cv2.imencode(f".{ext}", image_numpy)[1]
data = cv2.imencode(f".{ext}", image_numpy,
[
int(cv2.IMWRITE_JPEG_QUALITY), 100,
int(cv2.IMWRITE_PNG_COMPRESSION), 0
])[1]
image_bytes = data.tobytes()
return image_bytes
@@ -74,13 +78,24 @@ def resize_max_size(
return np_img
def pad_img_to_modulo(img, mod):
channels, height, width = img.shape
def pad_img_to_modulo(img: np.ndarray, mod: int):
"""
Args:
img: [H, W, C]
mod: pad height and width up to a multiple of this value
Returns:
padded image, [H, W, C]
"""
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
height, width = img.shape[:2]
out_height = ceil_modulo(height, mod)
out_width = ceil_modulo(width, mod)
return np.pad(
img,
((0, 0), (0, out_height - height), (0, out_width - width)),
((0, out_height - height), (0, out_width - width), (0, 0)),
mode="symmetric",
)
@@ -88,15 +103,13 @@ def pad_img_to_modulo(img, mod):
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
"""
Args:
mask: (1, h, w) 0~1
mask: (h, w, 1) 0~255
Returns:
boxes: list of boxes, each [left, top, right, bottom]
"""
height, width = mask.shape[1:]
_, thresh = cv2.threshold(
(mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
)
height, width = mask.shape[:2]
_, thresh = cv2.threshold(mask, 127, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = []
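
For reference, a minimal usage sketch of the reworked padding convention (assumes lama_cleaner is importable; the shapes are only illustrative):

import numpy as np
from lama_cleaner.helper import pad_img_to_modulo

# A 2D mask gains a channel axis, and both sides are rounded up to the next multiple of mod.
mask = np.zeros((500, 333), dtype=np.uint8)
padded = pad_img_to_modulo(mask, mod=8)
print(padded.shape)  # (504, 336, 1)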

View File

@@ -1,121 +0,0 @@
import os
from typing import List
import cv2
import torch
import numpy as np
from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
LAMA_MODEL_URL = os.environ.get(
"LAMA_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
class LaMa:
def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
"""
Args:
crop_trigger_size: h, w
crop_margin:
device:
"""
self.crop_trigger_size = crop_trigger_size
self.crop_margin = crop_margin
self.device = device
if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):
raise FileNotFoundError(
f"lama torchscript model not found: {model_path}"
)
else:
model_path = download_model(LAMA_MODEL_URL)
print(f"Load LaMa model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()
self.model = model
@torch.no_grad()
def __call__(self, image, mask):
"""
image: [C, H, W] RGB
mask: [1, H, W]
return: BGR IMAGE
"""
area = image.shape[1] * image.shape[2]
if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
return self._run(image, mask)
print("Trigger crop image")
boxes = boxes_from_mask(mask)
crop_result = []
for box in boxes:
crop_image, crop_box = self._run_box(image, mask, box)
crop_result.append((crop_image, crop_box))
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
for crop_image, crop_box in crop_result:
x1, y1, x2, y2 = crop_box
image[y1:y2, x1:x2, :] = crop_image
return image
def _run_box(self, image, mask, box):
"""
Args:
image: [C, H, W] RGB
mask: [1, H, W]
box: [left,top,right,bottom]
Returns:
BGR IMAGE
"""
box_h = box[3] - box[1]
box_w = box[2] - box[0]
cx = (box[0] + box[2]) // 2
cy = (box[1] + box[3]) // 2
img_h, img_w = image.shape[1:]
w = box_w + self.crop_margin * 2
h = box_h + self.crop_margin * 2
l = max(cx - w // 2, 0)
t = max(cy - h // 2, 0)
r = min(cx + w // 2, img_w)
b = min(cy + h // 2, img_h)
crop_img = image[:, t:b, l:r]
crop_mask = mask[:, t:b, l:r]
print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
return self._run(crop_img, crop_mask), [l, t, r, b]
def _run(self, image, mask):
"""
image: [C, H, W] RGB
mask: [1, H, W]
return: BGR IMAGE
"""
device = self.device
origin_height, origin_width = image.shape[1:]
image = pad_img_to_modulo(image, mod=8)
mask = pad_img_to_modulo(mask, mod=8)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
inpainted_image = self.model(image, mask)
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = cur_res[0:origin_height, 0:origin_width, :]
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
return cur_res

lama_cleaner/model/base.py (new file, 122 lines)
View File

@@ -0,0 +1,122 @@
import abc
import cv2
import torch
from loguru import logger
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
from lama_cleaner.schema import Config, HDStrategy
class InpaintModel:
pad_mod = 8
def __init__(self, device):
"""
Args:
device:
"""
self.device = device
self.init_model(device)
@abc.abstractmethod
def init_model(self, device):
...
@abc.abstractmethod
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]
return: BGR IMAGE
"""
...
def _pad_forward(self, image, mask, config: Config):
origin_height, origin_width = image.shape[:2]
padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
result = self.forward(padd_image, padd_mask, config)
result = result[0:origin_height, 0:origin_width, :]
original_pixel_indices = mask != 255
result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
return result
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
image: [H, W, C] RGB, not normalized
mask: [H, W]
return: BGR IMAGE
"""
inpaint_result = None
logger.info(f"hd_strategy: {config.hd_strategy}")
if config.hd_strategy == HDStrategy.CROP:
if max(image.shape) > config.hd_strategy_crop_trigger_size:
logger.info(f"Run crop strategy")
boxes = boxes_from_mask(mask)
crop_result = []
for box in boxes:
crop_image, crop_box = self._run_box(image, mask, box, config)
crop_result.append((crop_image, crop_box))
inpaint_result = image[:, :, ::-1]
for crop_image, crop_box in crop_result:
x1, y1, x2, y2 = crop_box
inpaint_result[y1:y2, x1:x2, :] = crop_image
elif config.hd_strategy == HDStrategy.RESIZE:
if max(image.shape) > config.hd_strategy_resize_limit:
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC)
original_pixel_indices = mask != 255
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
if inpaint_result is None:
inpaint_result = self._pad_forward(image, mask, config)
return inpaint_result
def _run_box(self, image, mask, box, config: Config):
"""
Args:
image: [H, W, C] RGB
mask: [H, W, 1]
box: [left,top,right,bottom]
Returns:
BGR IMAGE
"""
box_h = box[3] - box[1]
box_w = box[2] - box[0]
cx = (box[0] + box[2]) // 2
cy = (box[1] + box[3]) // 2
img_h, img_w = image.shape[:2]
w = box_w + config.hd_strategy_crop_margin * 2
h = box_h + config.hd_strategy_crop_margin * 2
l = max(cx - w // 2, 0)
t = max(cy - h // 2, 0)
r = min(cx + w // 2, img_w)
b = min(cy + h // 2, img_h)
crop_img = image[t:b, l:r, :]
crop_mask = mask[t:b, l:r]
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
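
For illustration only, a hypothetical subclass shows the init_model/forward contract the base class expects; the OpenCV call is a stand-in and is not part of this commit:

import cv2
import numpy as np
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config

class OpenCVInpaint(InpaintModel):
    pad_mod = 8

    def init_model(self, device):
        pass  # the classical method has no weights to load

    def forward(self, image, mask, config: Config):
        # image: [H, W, C] RGB uint8, mask: [H, W] (or [H, W, 1] after padding); returns BGR
        mask = np.squeeze(mask).astype("uint8")
        bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return cv2.inpaint(bgr, mask, 3, cv2.INPAINT_TELEA)

The crop/resize HD strategies in __call__ and the unmasked-pixel restore in _pad_forward then apply to such a model unchanged.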

View File

@@ -0,0 +1,64 @@
import os
import cv2
import numpy as np
import torch
from loguru import logger
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
LAMA_MODEL_URL = os.environ.get(
"LAMA_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
class LaMa(InpaintModel):
pad_mod = 8
def __init__(self, device):
"""
Args:
device:
"""
super().__init__(device)
self.device = device
def init_model(self, device):
if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):
raise FileNotFoundError(
f"lama torchscript model not found: {model_path}"
)
else:
model_path = download_model(LAMA_MODEL_URL)
logger.info(f"Load LaMa model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()
self.model = model
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]
return: BGR IMAGE
"""
image = norm_img(image)
mask = norm_img(mask)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
inpainted_image = self.model(image, mask)
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
return cur_res

View File

@@ -2,13 +2,16 @@ import os
import numpy as np
import torch
from loguru import logger
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
torch.manual_seed(42)
import torch.nn as nn
from tqdm import tqdm
import cv2
from lama_cleaner.helper import pad_img_to_modulo, download_model
from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
from lama_cleaner.helper import download_model, norm_img
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
timestep_embedding
LDM_ENCODE_MODEL_URL = os.environ.get(
@@ -217,7 +220,7 @@ class DDIMSampler(object):
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
@@ -268,97 +271,39 @@ def load_jit_model(url, device):
return model
class LDM:
def __init__(self, device, steps=50):
class LDM(InpaintModel):
pad_mod = 32
def __init__(self, device):
super().__init__(device)
self.device = device
def init_model(self, device):
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
model = LatentDiffusion(self.diffusion_model, device)
self.sampler = DDIMSampler(model)
self.steps = steps
def _norm(self, tensor):
return tensor * 2.0 - 1.0
@torch.no_grad()
def __call__(self, image, mask):
def forward(self, image, mask, config: Config):
"""
image: [C, H, W] RGB
mask: [1, H, W]
image: [H, W, C] RGB
mask: [H, W, 1]
return: BGR IMAGE
"""
# image [1,3,512,512] float32
# mask: [1,1,512,512] float32
# masked_image: [1,3,512,512] float32
origin_height, origin_width = image.shape[1:]
image = pad_img_to_modulo(image, mod=32)
mask = pad_img_to_modulo(mask, mod=32)
padded_height, padded_width = image.shape[1:]
steps = config.ldm_steps
image = norm_img(image)
mask = norm_img(mask)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# crop 512 x 512
if padded_width <= 512 or padded_height <= 512:
np_img = self._forward(image, mask, self.device)
else:
print("Try to zoom in")
# zoom in
# x,y,w,h
# box = self.box_from_bitmap(mask)
box = self.find_main_content(mask)
if box is None:
print("No bbox found")
np_img = self._forward(image, mask, self.device)
else:
print(f"box: {box}")
box_x, box_y, box_w, box_h = box
cx = box_x + box_w // 2
cy = box_y + box_h // 2
# w = max(512, box_w)
# h = max(512, box_h)
w = box_w + 512
h = box_h + 512
left = max(cx - w // 2, 0)
top = max(cy - h // 2, 0)
right = min(cx + w // 2, origin_width)
bottom = min(cy + h // 2, origin_height)
x = left
y = top
w = right - left
h = bottom - top
crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
print(f"Apply zoom in size width x height: {crop_img.shape}")
crop_img_height, crop_img_width = crop_img.shape[1:]
crop_img = pad_img_to_modulo(crop_img, mod=32)
crop_mask = pad_img_to_modulo(crop_mask, mod=32)
# RGB
np_img = self._forward(crop_img, crop_mask, self.device)
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
np_img = image
# BGR to RGB
# np_img = image[:, :, ::-1]
np_img = np_img[0:origin_height, 0:origin_width, :]
np_img = np_img[:, :, ::-1]
return np_img
def _forward(self, image, mask, device):
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
masked_image = (1 - mask) * image
image = self._norm(image)
@@ -371,47 +316,20 @@ class LDM:
c = torch.cat((c, cc), dim=1) # 1,4,128,128
shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim = self.sampler.sample(steps=self.steps,
samples_ddim = self.sampler.sample(steps=steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape)
x_samples_ddim = self.cond_stage_model_decode(samples_ddim) # samples_ddim: 1, 3, 128, 128 float32
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
inpainted = (1 - mask) * image + mask * predicted_image
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
np_img = inpainted.astype(np.uint8)
return np_img
# inpainted = (1 - mask) * image + mask * predicted_image
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
return inpainted_image
def find_main_content(self, bitmap: np.ndarray):
th2 = bitmap[0].astype(np.uint8)
row_sum = th2.sum(1)
col_sum = th2.sum(0)
xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
return left, top, right - left, bottom - top
def box_from_bitmap(self, bitmap):
"""
bitmap: single map with shape (NUM_CLASSES, H, W),
whose values are binarized as {0, 1}
"""
contours, _ = cv2.findContours(
(bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
)
contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
num_contours = len(contours)
print(f"contours size: {num_contours}")
if num_contours != 1:
return None
# x,y,w,h
return cv2.boundingRect(contours[0])
def _norm(self, tensor):
return tensor * 2.0 - 1.0

View File

@@ -0,0 +1,34 @@
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.schema import Config
class ModelManager:
LAMA = 'lama'
LDM = 'ldm'
def __init__(self, name: str, device):
self.name = name
self.device = device
self.model = self.init_model(name, device)
def init_model(self, name: str, device):
if name == self.LAMA:
model = LaMa(device)
elif name == self.LDM:
model = LDM(device)
else:
raise NotImplementedError(f"Not supported model: {name}")
return model
def __call__(self, image, mask, config: Config):
return self.model(image, mask, config)
def switch(self, new_name: str):
if new_name == self.name:
return
try:
self.model = self.init_model(new_name, self.device)
self.name = new_name
except NotImplementedError as e:
raise e

lama_cleaner/schema.py (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
from enum import Enum
from pydantic import BaseModel
class HDStrategy(str, Enum):
ORIGINAL = 'Original'
RESIZE = 'Resize'
CROP = 'Crop'
class Config(BaseModel):
ldm_steps: int
hd_strategy: str
hd_strategy_crop_margin: int
hd_strategy_crop_trigger_size: int
hd_strategy_resize_limit: int
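
Taken together with ModelManager above, the intended server-side flow reduces to roughly this sketch (device, file names, and values are illustrative, mirroring the tests below rather than the server code):

import cv2
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy

model = ModelManager(name='lama', device='cpu')
config = Config(
    ldm_steps=50,
    hd_strategy=HDStrategy.CROP,
    hd_strategy_crop_margin=128,
    hd_strategy_crop_trigger_size=2048,
    hd_strategy_resize_limit=2048,
)

image = cv2.cvtColor(cv2.imread('image.png'), cv2.COLOR_BGR2RGB)  # [H, W, C] RGB
mask = cv2.imread('mask.png', cv2.IMREAD_GRAYSCALE)               # [H, W], 0 or 255
result_bgr = model(image, mask, config)  # InpaintModel.__call__ applies the HD strategy
model.switch('ldm')                      # hot-swap models, as the frontend's /switch_model request suggests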

View File

Binary files (not shown): seven new images added (one 129 KiB, six 193 KiB), one existing 11 KiB image removed, and lama_cleaner/tests/mask.png (7.7 KiB) added as a new file.

View File

@@ -1,15 +0,0 @@
import cv2
import numpy as np
from lama_cleaner.helper import boxes_from_mask
def test_boxes_from_mask():
mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
mask = mask[:, :, np.newaxis]
mask = (mask / 255).transpose(2, 0, 1)
boxes = boxes_from_mask(mask)
print(boxes)
test_boxes_from_mask()

View File

@@ -0,0 +1,55 @@
import os
from pathlib import Path
import cv2
import numpy as np
import pytest
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy
current_dir = Path(__file__).parent.absolute().resolve()
def get_data():
img = cv2.imread(str(current_dir / 'image.png'))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
return img, mask
def get_config(strategy):
return Config(
ldm_steps=1,
hd_strategy=strategy,
hd_strategy_crop_margin=32,
hd_strategy_crop_trigger_size=200,
hd_strategy_resize_limit=200,
)
def assert_equal(model, config, gt_name):
img, mask = get_data()
res = model(img, mask, config)
# cv2.imwrite(gt_name, res,
# [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])
"""
Note that JPEG is lossy compression, so even if it is the highest quality 100,
when the saved image is reloaded, a difference occurs with the original pixel value.
If you want to save the original image as it is, save it as PNG or BMP.
"""
gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
assert np.array_equal(res, gt)
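
The note above can be checked in isolation; a small sketch with arbitrary file names and random data:

import cv2
import numpy as np

res = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
cv2.imwrite('gt.jpg', res, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
cv2.imwrite('gt.png', res, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
print(np.array_equal(cv2.imread('gt.jpg'), res))  # almost always False: JPEG is lossy even at quality 100
print(np.array_equal(cv2.imread('gt.png'), res))  # True: PNG round-trips exactly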
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_lama(strategy):
model = ModelManager(name='lama', device='cpu')
assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_ldm(strategy):
model = ModelManager(name='ldm', device='cpu')
assert_equal(model, get_config(strategy), f'ldm_{strategy[0].upper() + strategy[1:]}_result.png')