enable text_encoder cpu
This commit is contained in:
@@ -9,6 +9,8 @@ from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / 'result'
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
@@ -40,7 +42,7 @@ def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
res = model(img, mask, config)
|
||||
cv2.imwrite(
|
||||
str(current_dir / gt_name),
|
||||
str(save_dir / gt_name),
|
||||
res,
|
||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
||||
)
|
||||
@@ -163,7 +165,12 @@ def test_sd(strategy, sampler):
|
||||
print(f"sd_step_{step}")
|
||||
|
||||
sd_steps = 50
|
||||
model = ModelManager(name="sd1.4", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'],
|
||||
model = ModelManager(name="sd1.4",
|
||||
device=device,
|
||||
hf_access_token=os.environ['HF_ACCESS_TOKEN'],
|
||||
sd_run_local=False,
|
||||
sd_disable_nsfw=False,
|
||||
sd_cpu_textencoder=False,
|
||||
callbacks=[callback])
|
||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||
cfg.sd_sampler = sampler
|
||||
@@ -187,7 +194,8 @@ def test_sd(strategy, sampler):
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_sd_run_local(strategy, sampler):
|
||||
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||
def test_sd_run_local(strategy, sampler, disable_nsfw):
|
||||
def callback(step: int):
|
||||
print(f"sd_step_{step}")
|
||||
|
||||
@@ -195,11 +203,11 @@ def test_sd_run_local(strategy, sampler):
|
||||
model = ModelManager(
|
||||
name="sd1.4",
|
||||
device=device,
|
||||
# hf_access_token=os.environ.get('HF_ACCESS_TOKEN', None),
|
||||
hf_access_token=None,
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=True,
|
||||
sd_disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=True,
|
||||
callbacks=[callback]
|
||||
)
|
||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||
cfg.sd_sampler = sampler
|
||||
@@ -219,3 +227,4 @@ def test_sd_run_local(strategy, sampler):
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user