From ec7b2d8e2d7a6c764bed927269639830fc3292c6 Mon Sep 17 00:00:00 2001
From: Qing
Date: Thu, 29 Sep 2022 12:20:55 +0800
Subject: [PATCH] add sd-cpu-textencoder args

---
 lama_cleaner/model/sd.py   | 8 ++++++--
 lama_cleaner/parse_args.py | 9 +++++++--
 lama_cleaner/server.py     | 1 +
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py
index 1bc47e0..8aaeab8 100644
--- a/lama_cleaner/model/sd.py
+++ b/lama_cleaner/model/sd.py
@@ -71,8 +71,7 @@ class SD(InpaintModel):
         from .sd_pipeline import StableDiffusionInpaintPipeline
 
         model_kwargs = {}
-        sd_disable_nsfw = kwargs.pop('sd_disable_nsfw', False)
-        if sd_disable_nsfw:
+        if kwargs['sd_disable_nsfw']:
             logger.info("Disable Stable Diffusion Model NSFW checker")
             model_kwargs.update(dict(
                 feature_extractor=DummyFeatureExtractor(),
@@ -89,6 +88,11 @@ class SD(InpaintModel):
         # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
         self.model.enable_attention_slicing()
         self.model = self.model.to(device)
+
+        if kwargs['sd_cpu_textencoder']:
+            logger.info("Run Stable Diffusion TextEncoder on CPU")
+            self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'))
+
         self.callbacks = kwargs.pop("callbacks", None)
 
     @torch.cuda.amp.autocast()
diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py
index 359621b..b876dde 100644
--- a/lama_cleaner/parse_args.py
+++ b/lama_cleaner/parse_args.py
@@ -15,12 +15,17 @@ def parse_args():
     parser.add_argument(
         "--hf_access_token",
         default="",
-        help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
+        help="Huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
     )
     parser.add_argument(
         "--sd-disable-nsfw",
         action="store_true",
-        help="disable stable diffusion nsfw checker",
+        help="Disable Stable Diffusion nsfw checker",
+    )
+    parser.add_argument(
+        "--sd-cpu-textencoder",
+        action="store_true",
+        help="Always run Stable Diffusion TextEncoder model on CPU",
     )
     parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
     parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index 0fd8119..8fcd332 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -219,6 +219,7 @@ def main(args):
         device=device,
         hf_access_token=args.hf_access_token,
         sd_disable_nsfw=args.sd_disable_nsfw,
+        sd_cpu_textencoder=args.sd_cpu_textencoder,
         callbacks=[diffuser_callback],
     )
 
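
Note (not part of the patch): the sketch below illustrates the plumbing pattern this change follows: a store_true CLI flag is forwarded as a keyword argument, and when set, the text-encoder submodule is moved to CPU while the rest of the model stays on the chosen device, presumably to reduce GPU memory use. TinyPipeline, load_model, and their submodules are hypothetical stand-ins for illustration only, not lama-cleaner's actual classes.

import argparse

import torch
from torch import nn


class TinyPipeline(nn.Module):
    """Hypothetical stand-in for a diffusion pipeline with a text encoder."""

    def __init__(self):
        super().__init__()
        self.unet = nn.Linear(8, 8)          # stand-in for the UNet
        self.text_encoder = nn.Linear(8, 8)  # stand-in for the CLIP text encoder


def load_model(device, **kwargs):
    # Move the whole model to the requested device first.
    model = TinyPipeline().to(device)
    if kwargs["sd_cpu_textencoder"]:
        # Relocate only the text encoder to CPU, mirroring what the patch does
        # with self.model.text_encoder in lama_cleaner/model/sd.py.
        model.text_encoder = model.text_encoder.to(torch.device("cpu"))
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse converts --sd-cpu-textencoder to args.sd_cpu_textencoder.
    parser.add_argument("--sd-cpu-textencoder", action="store_true")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(device, sd_cpu_textencoder=args.sd_cpu_textencoder)
    print("unet on:", next(model.unet.parameters()).device)
    print("text encoder on:", next(model.text_encoder.parameters()).device)

Run the sketch with --sd-cpu-textencoder and the text encoder reports cpu while the stand-in UNet stays on the GPU when one is available; without the flag, both submodules share the same device.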