add realesrGAN selection
This commit is contained in:
@@ -14,6 +14,12 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
|
||||
def __init__(self, name, device, no_half=False):
|
||||
super().__init__()
|
||||
self.model_name = name
|
||||
self.device = device
|
||||
self.no_half = no_half
|
||||
self._init_model(name)
|
||||
|
||||
def _init_model(self, name):
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
@@ -70,13 +76,19 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
scale=model_info["scale"],
|
||||
model_path=model_path,
|
||||
model=model_info["model"](),
|
||||
half=True if "cuda" in str(device) and not no_half else False,
|
||||
half=True if "cuda" in str(self.device) and not self.no_half else False,
|
||||
tile=512,
|
||||
tile_pad=10,
|
||||
pre_pad=10,
|
||||
device=device,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def switch_model(self, new_model_name: str):
|
||||
if self.model_name == new_model_name:
|
||||
return
|
||||
self._init_model(new_model_name)
|
||||
self.model_name = new_model_name
|
||||
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
||||
|
||||
Reference in New Issue
Block a user