add realesrGAN selection

This commit is contained in:
Qing
2024-02-08 17:16:57 +08:00
parent 8060e16c70
commit f52dbc1091
6 changed files with 98 additions and 9 deletions

View File

@@ -43,7 +43,7 @@ from iopaint.helper import (
)
from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins
from iopaint.plugins import build_plugins, RealESRGANUpscaler
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import (
@@ -59,6 +59,7 @@ from iopaint.schema import (
RemoveBGModel,
SwitchPluginModelRequest,
ModelInfo,
RealESRGANModel,
)
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@@ -192,6 +193,8 @@ class Api:
self.plugins[req.plugin_name].switch_model(req.model_name)
if req.plugin_name == RemoveBG.name:
self.config.remove_bg_model = req.model_name
if req.plugin_name == RealESRGANUpscaler.name:
self.config.realesrgan_model = req.model_name
def api_server_config(self) -> ServerConfigResponse:
plugins = []
@@ -209,6 +212,8 @@ class Api:
modelInfos=self.model_manager.scan_models(),
removeBGModel=self.config.remove_bg_model,
removeBGModels=RemoveBGModel.values(),
realesrganModel=self.config.realesrgan_model,
realesrganModels=RealESRGANModel.values(),
enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet,

View File

@@ -14,6 +14,12 @@ class RealESRGANUpscaler(BasePlugin):
def __init__(self, name, device, no_half=False):
super().__init__()
self.model_name = name
self.device = device
self.no_half = no_half
self._init_model(name)
def _init_model(self, name):
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
@@ -70,13 +76,19 @@ class RealESRGANUpscaler(BasePlugin):
scale=model_info["scale"],
model_path=model_path,
model=model_info["model"](),
half=True if "cuda" in str(device) and not no_half else False,
half=True if "cuda" in str(self.device) and not self.no_half else False,
tile=512,
tile_pad=10,
pre_pad=10,
device=device,
device=self.device,
)
def switch_model(self, new_model_name: str):
if self.model_name == new_model_name:
return
self._init_model(new_model_name)
self.model_name = new_model_name
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")

View File

@@ -423,7 +423,9 @@ class ServerConfigResponse(BaseModel):
plugins: List[PluginInfo]
modelInfos: List[ModelInfo]
removeBGModel: RemoveBGModel
removeBGModels: List[str]
removeBGModels: List[RemoveBGModel]
realesrganModel: RealESRGANModel
realesrganModels: List[RealESRGANModel]
enableFileManager: bool
enableAutoSaving: bool
enableControlnet: bool

View File

@@ -44,7 +44,7 @@ default_configs = dict(
interactive_seg_model=InteractiveSegModel.vit_b,
interactive_seg_device=Device.cpu,
enable_remove_bg=False,
remove_bg_model=RemoveBGModel.u2net,
remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
enable_anime_seg=False,
enable_realesrgan=False,
realesrgan_device=Device.cpu,