huggingface_hub deprecate resume_download

This commit is contained in:
Qing
2024-11-23 19:01:48 +08:00
parent 36e65f0097
commit b58c333a73

View File

@@ -15,8 +15,12 @@ from ..utils import (
) )
from .brushnet import BrushNetModel from .brushnet import BrushNetModel
from .brushnet_unet_forward import brushnet_unet_forward from .brushnet_unet_forward import brushnet_unet_forward
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \ from .unet_2d_blocks import (
UpBlock2D_forward CrossAttnDownBlock2D_forward,
DownBlock2D_forward,
CrossAttnUpBlock2D_forward,
UpBlock2D_forward,
)
from ...schema import InpaintRequest, ModelType from ...schema import InpaintRequest, ModelType
@@ -26,6 +30,7 @@ class BrushNetWrapper(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
from .pipeline_brushnet import StableDiffusionBrushNetPipeline from .pipeline_brushnet import StableDiffusionBrushNetPipeline
self.model_info = kwargs["model_info"] self.model_info = kwargs["model_info"]
self.brushnet_method = kwargs["brushnet_method"] self.brushnet_method = kwargs["brushnet_method"]
@@ -52,7 +57,9 @@ class BrushNetWrapper(DiffusionInpaintModel):
) )
logger.info(f"Loading BrushNet model from {self.brushnet_method}") logger.info(f"Loading BrushNet model from {self.brushnet_method}")
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype) brushnet = BrushNetModel.from_pretrained(
self.brushnet_method, torch_dtype=torch_dtype
)
if self.model_info.is_single_file_diffusers: if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD: if self.model_info.model_type == ModelType.DIFFUSERS_SD:
@@ -64,7 +71,7 @@ class BrushNetWrapper(DiffusionInpaintModel):
self.model_id_or_path, self.model_id_or_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker, load_safety_checker=not disable_nsfw_checker,
original_config_file=get_config_files()['v1'], original_config_file=get_config_files()["v1"],
brushnet=brushnet, brushnet=brushnet,
**model_kwargs, **model_kwargs,
) )
@@ -94,31 +101,42 @@ class BrushNetWrapper(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None) self.callback = kwargs.pop("callback", None)
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method # Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__) self.model.unet.forward = brushnet_unet_forward.__get__(
self.model.unet, self.model.unet.__class__
)
for down_block in self.model.brushnet.down_blocks: for down_block in self.model.brushnet.down_blocks:
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__) down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)
for up_block in self.model.brushnet.up_blocks: for up_block in self.model.brushnet.up_blocks:
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__) up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward # Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
for down_block in self.model.unet.down_blocks: for down_block in self.model.unet.down_blocks:
if down_block.__class__.__name__ == "CrossAttnDownBlock2D": if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__) down_block.forward = CrossAttnDownBlock2D_forward.__get__(
down_block, down_block.__class__
)
else: else:
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__) down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)
for up_block in self.model.unet.up_blocks: for up_block in self.model.unet.up_blocks:
if up_block.__class__.__name__ == "CrossAttnUpBlock2D": if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__) up_block.forward = CrossAttnUpBlock2D_forward.__get__(
up_block, up_block.__class__
)
else: else:
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__) up_block.forward = UpBlock2D_forward.__get__(
up_block, up_block.__class__
)
def switch_brushnet_method(self, new_method: str): def switch_brushnet_method(self, new_method: str):
self.brushnet_method = new_method self.brushnet_method = new_method
brushnet = BrushNetModel.from_pretrained( brushnet = BrushNetModel.from_pretrained(
new_method, new_method,
resume_download=True,
local_files_only=self.local_files_only, local_files_only=self.local_files_only,
torch_dtype=self.torch_dtype, torch_dtype=self.torch_dtype,
).to(self.model.device) ).to(self.model.device)