diff --git a/lama_cleaner/plugins/gfpgan_plugin.py b/lama_cleaner/plugins/gfpgan_plugin.py index 12e66b5..fe8a2a3 100644 --- a/lama_cleaner/plugins/gfpgan_plugin.py +++ b/lama_cleaner/plugins/gfpgan_plugin.py @@ -21,7 +21,7 @@ class GFPGANPlugin(BasePlugin): import facexlib if hasattr(facexlib.detection.retinaface, "device"): - facexlib.detection.retinaface.device = device + facexlib.detection.retinaface.device = face_det_device # Use GFPGAN for face enhancement self.face_enhancer = MyGFPGANer( @@ -32,9 +32,9 @@ class GFPGANPlugin(BasePlugin): device=device, bg_upsampler=upscaler.model if upscaler is not None else None, ) - self.face_enhancer.face_helper.face_det.mean_tensor.to(device) + self.face_enhancer.face_helper.face_det.mean_tensor.to(face_det_device) self.face_enhancer.face_helper.face_det = ( - self.face_enhancer.face_helper.face_det.to(device) + self.face_enhancer.face_helper.face_det.to(face_det_device) ) def __call__(self, rgb_np_img, files, form):