huggingface_hub deprecate resume_download
This commit is contained in:
@@ -15,8 +15,12 @@ from ..utils import (
|
|||||||
)
|
)
|
||||||
from .brushnet import BrushNetModel
|
from .brushnet import BrushNetModel
|
||||||
from .brushnet_unet_forward import brushnet_unet_forward
|
from .brushnet_unet_forward import brushnet_unet_forward
|
||||||
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \
|
from .unet_2d_blocks import (
|
||||||
UpBlock2D_forward
|
CrossAttnDownBlock2D_forward,
|
||||||
|
DownBlock2D_forward,
|
||||||
|
CrossAttnUpBlock2D_forward,
|
||||||
|
UpBlock2D_forward,
|
||||||
|
)
|
||||||
from ...schema import InpaintRequest, ModelType
|
from ...schema import InpaintRequest, ModelType
|
||||||
|
|
||||||
|
|
||||||
@@ -26,6 +30,7 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
|||||||
|
|
||||||
def init_model(self, device: torch.device, **kwargs):
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
from .pipeline_brushnet import StableDiffusionBrushNetPipeline
|
from .pipeline_brushnet import StableDiffusionBrushNetPipeline
|
||||||
|
|
||||||
self.model_info = kwargs["model_info"]
|
self.model_info = kwargs["model_info"]
|
||||||
self.brushnet_method = kwargs["brushnet_method"]
|
self.brushnet_method = kwargs["brushnet_method"]
|
||||||
|
|
||||||
@@ -52,7 +57,9 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Loading BrushNet model from {self.brushnet_method}")
|
logger.info(f"Loading BrushNet model from {self.brushnet_method}")
|
||||||
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype)
|
brushnet = BrushNetModel.from_pretrained(
|
||||||
|
self.brushnet_method, torch_dtype=torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
if self.model_info.is_single_file_diffusers:
|
if self.model_info.is_single_file_diffusers:
|
||||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
@@ -64,7 +71,7 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
|||||||
self.model_id_or_path,
|
self.model_id_or_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
load_safety_checker=not disable_nsfw_checker,
|
load_safety_checker=not disable_nsfw_checker,
|
||||||
original_config_file=get_config_files()['v1'],
|
original_config_file=get_config_files()["v1"],
|
||||||
brushnet=brushnet,
|
brushnet=brushnet,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -94,31 +101,42 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
|||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||||
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__)
|
self.model.unet.forward = brushnet_unet_forward.__get__(
|
||||||
|
self.model.unet, self.model.unet.__class__
|
||||||
|
)
|
||||||
|
|
||||||
for down_block in self.model.brushnet.down_blocks:
|
for down_block in self.model.brushnet.down_blocks:
|
||||||
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
down_block.forward = DownBlock2D_forward.__get__(
|
||||||
|
down_block, down_block.__class__
|
||||||
|
)
|
||||||
for up_block in self.model.brushnet.up_blocks:
|
for up_block in self.model.brushnet.up_blocks:
|
||||||
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||||
|
|
||||||
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||||
for down_block in self.model.unet.down_blocks:
|
for down_block in self.model.unet.down_blocks:
|
||||||
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||||
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__)
|
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
|
||||||
|
down_block, down_block.__class__
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
down_block.forward = DownBlock2D_forward.__get__(
|
||||||
|
down_block, down_block.__class__
|
||||||
|
)
|
||||||
|
|
||||||
for up_block in self.model.unet.up_blocks:
|
for up_block in self.model.unet.up_blocks:
|
||||||
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||||
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__)
|
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
|
||||||
|
up_block, up_block.__class__
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
up_block.forward = UpBlock2D_forward.__get__(
|
||||||
|
up_block, up_block.__class__
|
||||||
|
)
|
||||||
|
|
||||||
def switch_brushnet_method(self, new_method: str):
|
def switch_brushnet_method(self, new_method: str):
|
||||||
self.brushnet_method = new_method
|
self.brushnet_method = new_method
|
||||||
brushnet = BrushNetModel.from_pretrained(
|
brushnet = BrushNetModel.from_pretrained(
|
||||||
new_method,
|
new_method,
|
||||||
resume_download=True,
|
|
||||||
local_files_only=self.local_files_only,
|
local_files_only=self.local_files_only,
|
||||||
torch_dtype=self.torch_dtype,
|
torch_dtype=self.torch_dtype,
|
||||||
).to(self.model.device)
|
).to(self.model.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user