diff --git a/lama_cleaner/app/src/components/Editor/Editor.tsx b/lama_cleaner/app/src/components/Editor/Editor.tsx index 5d0b0ee..2c2a697 100644 --- a/lama_cleaner/app/src/components/Editor/Editor.tsx +++ b/lama_cleaner/app/src/components/Editor/Editor.tsx @@ -598,6 +598,15 @@ export default function Editor() { } }, [runRenderablePlugin]) + useEffect(() => { + emitter.on(PluginName.GFPGAN, () => { + runRenderablePlugin(PluginName.GFPGAN) + }) + return () => { + emitter.off(PluginName.GFPGAN) + } + }, [runRenderablePlugin]) + useEffect(() => { emitter.on(PluginName.RealESRGAN, (data: any) => { runRenderablePlugin(PluginName.RealESRGAN, data) diff --git a/lama_cleaner/app/src/components/Plugins/Plugins.tsx b/lama_cleaner/app/src/components/Plugins/Plugins.tsx index 289ae19..087c4ac 100644 --- a/lama_cleaner/app/src/components/Plugins/Plugins.tsx +++ b/lama_cleaner/app/src/components/Plugins/Plugins.tsx @@ -4,6 +4,7 @@ import { CursorArrowRaysIcon, GifIcon } from '@heroicons/react/24/outline' import { BoxModelIcon, ChevronRightIcon, + FaceIcon, HobbyKnifeIcon, MixIcon, } from '@radix-ui/react-icons' @@ -20,6 +21,7 @@ import Button from '../shared/Button' export enum PluginName { RemoveBG = 'RemoveBG', RealESRGAN = 'RealESRGAN', + GFPGAN = 'GFPGAN', InteractiveSeg = 'InteractiveSeg', MakeGIF = 'MakeGIF', } @@ -33,6 +35,10 @@ const pluginMap = { IconClass: BoxModelIcon, showName: 'RealESRGAN 4x', }, + [PluginName.GFPGAN]: { + IconClass: FaceIcon, + showName: 'GFPGAN', + }, [PluginName.InteractiveSeg]: { IconClass: CursorArrowRaysIcon, showName: 'Interactive Seg', diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 99b5958..cb79b8e 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -327,28 +327,46 @@ def run_plugin(): return "Plugin not found", 500 origin_image_bytes = files["image"].read() # RGB - rgb_np_img, _ = load_img(origin_image_bytes) + rgb_np_img, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True) start = time.time() - res = plugins[name](rgb_np_img, files, form) + bgr_res = plugins[name](rgb_np_img, files, form) logger.info(f"{name} process time: {(time.time() - start) * 1000}ms") torch_gc() if name == MakeGIF.name: - filename = form["filename"] return send_file( - io.BytesIO(res), + io.BytesIO(bgr_res), mimetype="image/gif", as_attachment=True, - attachment_filename=filename, + attachment_filename=form["filename"], ) + + if name == RemoveBG.name: + rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGRA2RGBA) + ext = "png" else: - response = make_response( - send_file( - io.BytesIO(numpy_to_bytes(res, "png")), - mimetype=f"image/png", + rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB) + ext = get_image_ext(origin_image_bytes) + if alpha_channel is not None: + if alpha_channel.shape[:2] != rgb_res.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(rgb_res.shape[1], rgb_res.shape[0]) + ) + rgb_res = np.concatenate( + (rgb_res, alpha_channel[:, :, np.newaxis]), axis=-1 ) + + response = make_response( + send_file( + io.BytesIO( + pil_to_bytes( + Image.fromarray(rgb_res), ext, quality=image_quality, exif=exif + ) + ), + mimetype=f"image/{ext}", ) + ) return response diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py index d1af506..a76f3c6 100644 --- a/lama_cleaner/tests/test_model.py +++ b/lama_cleaner/tests/test_model.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import cv2 diff --git a/lama_cleaner/tests/test_plugins.py b/lama_cleaner/tests/test_plugins.py index 0634a5d..6d68ec9 100644 --- a/lama_cleaner/tests/test_plugins.py +++ b/lama_cleaner/tests/test_plugins.py @@ -1,27 +1,46 @@ from pathlib import Path import cv2 +import pytest +import torch.cuda -from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler +from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler, GFPGANPlugin current_dir = Path(__file__).parent.absolute().resolve() save_dir = current_dir / "result" save_dir.mkdir(exist_ok=True, parents=True) img_p = current_dir / "bunny.jpeg" +bgr_img = cv2.imread(str(img_p)) +rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + + +def _save(img, name): + cv2.imwrite(str(save_dir / name), img) def test_remove_bg(): model = RemoveBG() - img = cv2.imread(str(img_p)) - res = model.forward(img) - cv2.imwrite(str(save_dir / "test_remove_bg.png"), res) + res = model.forward(bgr_img) + _save(res, "test_remove_bg.png") -def test_upscale(): - model = RealESRGANUpscaler("cpu") - img = cv2.imread(str(img_p)) - res = model.forward(img, 2) - cv2.imwrite(str(save_dir / "test_upscale_x2.png"), res) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +def test_upscale(device): + if device == "cuda" and not torch.cuda.is_available(): + return - res = model.forward(img, 4) - cv2.imwrite(str(save_dir / "test_upscale_x4.png"), res) + model = RealESRGANUpscaler("realesr-general-x4v3", device) + res = model.forward(bgr_img, 2) + _save(res, "test_upscale_x2.png") + + res = model.forward(bgr_img, 4) + _save(res, "test_upscale_x4.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +def test_gfpgan(device): + if device == "cuda" and not torch.cuda.is_available(): + return + model = GFPGANPlugin(device) + res = model(rgb_img, None, None) + _save(res, "test_gfpgan.png")