zits use structure_upsample_model

This commit is contained in:
Qing
2022-07-19 21:47:21 +08:00
parent b0c5d22a5a
commit cfcaf82a21
3 changed files with 58 additions and 30 deletions

View File

@@ -41,7 +41,7 @@ def resize(img, height, width, center_crop=False):
side = np.minimum(imgh, imgw)
j = (imgh - side) // 2
i = (imgw - side) // 2
img = img[j: j + side, i: i + side, ...]
img = img[j : j + side, i : i + side, ...]
if imgh > height and imgw > width:
inter = cv2.INTER_AREA
@@ -219,7 +219,9 @@ class ZITS(InpaintModel):
def init_model(self, device):
self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device)
self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device)
# self.structure_upsample = load_jit_model(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device)
self.structure_upsample = load_jit_model(
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device
)
self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device)
@staticmethod
@@ -227,7 +229,7 @@ class ZITS(InpaintModel):
model_paths = [
get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL),
get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL),
# get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
get_cache_path_by_url(ZITS_INPAINT_MODEL_URL),
]
return all([os.path.exists(it) for it in model_paths])
@@ -272,20 +274,27 @@ class ZITS(InpaintModel):
# cv2.imwrite("line_pred.jpg", np_line_pred)
# exit()
# No structure_upsample_model
input_size = min(items["h"], items["w"])
edge_pred = F.interpolate(
edge_pred,
size=(input_size, input_size),
mode="bilinear",
align_corners=False,
)
line_pred = F.interpolate(
line_pred,
size=(input_size, input_size),
mode="bilinear",
align_corners=False,
)
if input_size != 256 and input_size > 256:
while edge_pred.shape[2] < input_size:
edge_pred = self.structure_upsample(edge_pred)
edge_pred = torch.sigmoid((edge_pred + 2) * 2)
line_pred = self.structure_upsample(line_pred)
line_pred = torch.sigmoid((line_pred + 2) * 2)
edge_pred = F.interpolate(
edge_pred,
size=(input_size, input_size),
mode="bilinear",
align_corners=False,
)
line_pred = F.interpolate(
line_pred,
size=(input_size, input_size),
mode="bilinear",
align_corners=False,
)
# np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
# cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred)
@@ -308,12 +317,19 @@ class ZITS(InpaintModel):
self.wireframe_edge_and_line(items, config.zits_wireframe)
inpainted_image = self.inpaint(items["images"], items["masks"],
items["edge"], items["line"],
items["rel_pos"], items["direct"])
inpainted_image = self.inpaint(
items["images"],
items["masks"],
items["edge"],
items["line"],
items["rel_pos"],
items["direct"],
)
inpainted_image = inpainted_image * 255.0
inpainted_image = inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
inpainted_image = (
inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
)
inpainted_image = inpainted_image[:, :, ::-1]
# cv2.imwrite("inpainted.jpg", inpainted_image)
@@ -362,7 +378,9 @@ class ZITS(InpaintModel):
lines_tensor = torch.cat(lines_tensor, dim=0)
return lines_tensor.detach().to(self.device)
def sample_edge_line_logits(self, context, mask=None, iterations=1, add_v=0, mul_v=4):
def sample_edge_line_logits(
self, context, mask=None, iterations=1, add_v=0, mul_v=4
):
[img, edge, line] = context
img = img * (1 - mask)
@@ -391,7 +409,9 @@ class ZITS(InpaintModel):
edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100)
line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100)
indices = torch.sort(edge_max_probs + line_max_probs, dim=-1, descending=True)[1]
indices = torch.sort(
edge_max_probs + line_max_probs, dim=-1, descending=True
)[1]
for ii in range(b):
keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))