lots update
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -16,8 +17,11 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
LCMScheduler
|
||||
LCMScheduler,
|
||||
)
|
||||
from huggingface_hub.utils import RevisionNotFoundError
|
||||
from loguru import logger
|
||||
from requests import HTTPError
|
||||
|
||||
from lama_cleaner.schema import SDSampler
|
||||
from torch import conv2d, conv_transpose2d
|
||||
@@ -944,3 +948,20 @@ def get_scheduler(sd_sampler, scheduler_config):
|
||||
return LCMScheduler.from_config(scheduler_config)
|
||||
else:
|
||||
raise ValueError(sd_sampler)
|
||||
|
||||
|
||||
def handle_from_pretrained_exceptions(func, **kwargs):
|
||||
try:
|
||||
return func(**kwargs)
|
||||
except ValueError as e:
|
||||
# 处理异常的逻辑
|
||||
if "You are trying to load the model files of the `variant=fp16`" in str(e):
|
||||
logger.info("variant=fp16 not found, try revision=fp16")
|
||||
return func(**{**kwargs, "variant": None, "revision": "fp16"})
|
||||
except OSError as e:
|
||||
previous_traceback = traceback.format_exc()
|
||||
if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
|
||||
logger.info("revision=fp16 not found, try revision=main")
|
||||
return func(**{**kwargs, "variant": None, "revision": "main"})
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user