From 0d57e552cfbbcef7114e28a979408549f4e459b9 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 29 Sep 2022 09:42:19 +0800 Subject: [PATCH] add sd-disable-nsfw arg --- lama_cleaner/model/sd.py | 36 +++++++++++++++++++++++++++++++++++- lama_cleaner/parse_args.py | 5 +++++ lama_cleaner/server.py | 1 + 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index c7a3389..1bc47e0 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -6,6 +6,7 @@ import numpy as np import torch from diffusers import PNDMScheduler, DDIMScheduler from loguru import logger +from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin from lama_cleaner.helper import norm_img @@ -38,19 +39,52 @@ from lama_cleaner.schema import Config, SDSampler # mask = torch.from_numpy(mask) # return mask +class DummyFeatureExtractorOutput: + def __init__(self, pixel_values): + self.pixel_values = pixel_values + + def to(self, device): + return self + + +class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def __call__(self, *args, **kwargs): + return DummyFeatureExtractorOutput(torch.empty(0, 3)) + + +class DummySafetyChecker: + def __init__(self, *args, **kwargs): + pass + + def __call__(self, clip_input, images): + return images, False + class SD(InpaintModel): - pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505 + pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505 min_size = 512 def init_model(self, device: torch.device, **kwargs): from .sd_pipeline import StableDiffusionInpaintPipeline + model_kwargs = {} + sd_disable_nsfw = kwargs.pop('sd_disable_nsfw', False) + if sd_disable_nsfw: + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update(dict( + feature_extractor=DummyFeatureExtractor(), + safety_checker=DummySafetyChecker(), + )) + self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model_id_or_path, revision="fp16" if torch.cuda.is_available() else "main", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, use_auth_token=kwargs["hf_access_token"], + **model_kwargs ) # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing self.model.enable_attention_slicing() diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index f4b720e..359621b 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -17,6 +17,11 @@ def parse_args(): default="", help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens", ) + parser.add_argument( + "--sd-disable-nsfw", + action="store_true", + help="disable stable diffusion nsfw checker", + ) parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"]) parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument( diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 6397c0b..0fd8119 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -218,6 +218,7 @@ def main(args): name=args.model, device=device, hf_access_token=args.hf_access_token, + sd_disable_nsfw=args.sd_disable_nsfw, callbacks=[diffuser_callback], )