sd make change sampler work
This commit is contained in:
@@ -4,12 +4,13 @@ import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import PNDMScheduler, DDIMScheduler
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import norm_img
|
||||
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
from lama_cleaner.schema import Config, SDSampler
|
||||
|
||||
|
||||
#
|
||||
@@ -43,13 +44,12 @@ class SD(InpaintModel):
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
# return
|
||||
from .sd_pipeline import StableDiffusionInpaintPipeline
|
||||
|
||||
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
self.model_id_or_path,
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16" if torch.cuda.is_available() else 'main',
|
||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||
use_auth_token=kwargs["hf_access_token"],
|
||||
)
|
||||
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
@@ -59,7 +59,6 @@ class SD(InpaintModel):
|
||||
|
||||
@torch.cuda.amp.autocast()
|
||||
def forward(self, image, mask, config: Config):
|
||||
# return image
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
@@ -76,9 +75,30 @@ class SD(InpaintModel):
|
||||
#
|
||||
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||
# import time
|
||||
# time.sleep(2)
|
||||
# return image
|
||||
|
||||
if config.sd_sampler == SDSampler.ddim:
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
elif config.sd_sampler == SDSampler.pndm:
|
||||
PNDM_kwargs = {
|
||||
"tensor_format": "pt",
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"beta_end": 0.012,
|
||||
"num_train_timesteps": 1000,
|
||||
"skip_prk_steps": True
|
||||
}
|
||||
scheduler = PNDMScheduler(**PNDM_kwargs)
|
||||
else:
|
||||
raise ValueError(config.sd_sampler)
|
||||
|
||||
self.model.scheduler = scheduler
|
||||
|
||||
seed = config.sd_seed
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user