huggingface_hub deprecate resume_download
This commit is contained in:
@@ -15,8 +15,12 @@ from ..utils import (
|
||||
)
|
||||
from .brushnet import BrushNetModel
|
||||
from .brushnet_unet_forward import brushnet_unet_forward
|
||||
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \
|
||||
UpBlock2D_forward
|
||||
from .unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D_forward,
|
||||
DownBlock2D_forward,
|
||||
CrossAttnUpBlock2D_forward,
|
||||
UpBlock2D_forward,
|
||||
)
|
||||
from ...schema import InpaintRequest, ModelType
|
||||
|
||||
|
||||
@@ -26,6 +30,7 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from .pipeline_brushnet import StableDiffusionBrushNetPipeline
|
||||
|
||||
self.model_info = kwargs["model_info"]
|
||||
self.brushnet_method = kwargs["brushnet_method"]
|
||||
|
||||
@@ -52,7 +57,9 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
||||
)
|
||||
|
||||
logger.info(f"Loading BrushNet model from {self.brushnet_method}")
|
||||
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype)
|
||||
brushnet = BrushNetModel.from_pretrained(
|
||||
self.brushnet_method, torch_dtype=torch_dtype
|
||||
)
|
||||
|
||||
if self.model_info.is_single_file_diffusers:
|
||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||
@@ -64,7 +71,7 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
||||
self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
load_safety_checker=not disable_nsfw_checker,
|
||||
original_config_file=get_config_files()['v1'],
|
||||
original_config_file=get_config_files()["v1"],
|
||||
brushnet=brushnet,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -94,31 +101,42 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__)
|
||||
self.model.unet.forward = brushnet_unet_forward.__get__(
|
||||
self.model.unet, self.model.unet.__class__
|
||||
)
|
||||
|
||||
for down_block in self.model.brushnet.down_blocks:
|
||||
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
down_block.forward = DownBlock2D_forward.__get__(
|
||||
down_block, down_block.__class__
|
||||
)
|
||||
for up_block in self.model.brushnet.up_blocks:
|
||||
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
|
||||
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||
for down_block in self.model.unet.down_blocks:
|
||||
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
|
||||
down_block, down_block.__class__
|
||||
)
|
||||
else:
|
||||
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
down_block.forward = DownBlock2D_forward.__get__(
|
||||
down_block, down_block.__class__
|
||||
)
|
||||
|
||||
for up_block in self.model.unet.up_blocks:
|
||||
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
|
||||
up_block, up_block.__class__
|
||||
)
|
||||
else:
|
||||
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
up_block.forward = UpBlock2D_forward.__get__(
|
||||
up_block, up_block.__class__
|
||||
)
|
||||
|
||||
def switch_brushnet_method(self, new_method: str):
|
||||
self.brushnet_method = new_method
|
||||
brushnet = BrushNetModel.from_pretrained(
|
||||
new_method,
|
||||
resume_download=True,
|
||||
local_files_only=self.local_files_only,
|
||||
torch_dtype=self.torch_dtype,
|
||||
).to(self.model.device)
|
||||
|
||||
Reference in New Issue
Block a user