wip
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gc
|
||||
from typing import Union, List, Optional, Callable, Dict, Any
|
||||
|
||||
# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
|
||||
@@ -217,6 +218,38 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
num_in_channels=9,
|
||||
from_safetensors=pretrained_model_link_or_path.endswith("safetensors"),
|
||||
device="cpu",
|
||||
load_safety_checker=False,
|
||||
)
|
||||
|
||||
inpaint_pipe = cls(
|
||||
vae=pipe.vae,
|
||||
text_encoder=pipe.text_encoder,
|
||||
tokenizer=pipe.tokenizer,
|
||||
unet=pipe.unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=pipe.scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
return inpaint_pipe
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
|
||||
Reference in New Issue
Block a user