add download command
This commit is contained in:
@@ -5,7 +5,7 @@ import cv2
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
|
||||
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
|
||||
from lama_cleaner.schema import Config
|
||||
import numpy as np
|
||||
|
||||
@@ -171,14 +171,19 @@ def load_image(img, mask, device, sigma256=3.0):
|
||||
|
||||
try:
|
||||
import skimage
|
||||
|
||||
gray_256 = skimage.color.rgb2gray(img_256)
|
||||
edge_256 = skimage.feature.canny(gray_256, sigma=3.0, mask=None).astype(float)
|
||||
# cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8))
|
||||
# cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8))
|
||||
except:
|
||||
gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY)
|
||||
gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256)
|
||||
edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2))
|
||||
gray_256_blured = cv2.GaussianBlur(
|
||||
gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256
|
||||
)
|
||||
edge_256 = cv2.Canny(
|
||||
gray_256_blured, threshold1=int(255 * 0.1), threshold2=int(255 * 0.2)
|
||||
)
|
||||
|
||||
# cv2.imwrite("opencv_edge.jpg", edge_256)
|
||||
|
||||
@@ -233,12 +238,27 @@ class ZITS(InpaintModel):
|
||||
self.sample_edge_line_iterations = 1
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5)
|
||||
self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5)
|
||||
self.wireframe = load_jit_model(
|
||||
ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5
|
||||
)
|
||||
self.edge_line = load_jit_model(
|
||||
ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5
|
||||
)
|
||||
self.structure_upsample = load_jit_model(
|
||||
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
|
||||
)
|
||||
self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5)
|
||||
self.inpaint = load_jit_model(
|
||||
ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
download_model(ZITS_WIRE_FRAME_MODEL_URL, ZITS_WIRE_FRAME_MODEL_MD5)
|
||||
download_model(ZITS_EDGE_LINE_MODEL_URL, ZITS_EDGE_LINE_MODEL_MD5)
|
||||
download_model(
|
||||
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
|
||||
)
|
||||
download_model(ZITS_INPAINT_MODEL_URL, ZITS_INPAINT_MODEL_MD5)
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
@@ -385,12 +405,20 @@ class ZITS(InpaintModel):
|
||||
if score > mask_th:
|
||||
try:
|
||||
import skimage
|
||||
|
||||
rr, cc, value = skimage.draw.line_aa(
|
||||
*to_int(line[0:2]), *to_int(line[2:4])
|
||||
)
|
||||
lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
|
||||
except:
|
||||
cv2.line(lmap, to_int(line[0:2][::-1]), to_int(line[2:4][::-1]), (1, 1, 1), 1, cv2.LINE_AA)
|
||||
cv2.line(
|
||||
lmap,
|
||||
to_int(line[0:2][::-1]),
|
||||
to_int(line[2:4][::-1]),
|
||||
(1, 1, 1),
|
||||
1,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8)
|
||||
lines_tensor.append(to_tensor(lmap).unsqueeze(0))
|
||||
|
||||
Reference in New Issue
Block a user