1084 lines
55 KiB
Python
1084 lines
55 KiB
Python
import os
|
|
import numpy as np
|
|
try:
|
|
import networkx as netx
|
|
except ImportError:
|
|
import networkx as netx
|
|
|
|
import json
|
|
import scipy.misc as misc
|
|
#import OpenEXR
|
|
import scipy.signal as signal
|
|
import matplotlib.pyplot as plt
|
|
import cv2
|
|
import scipy.misc as misc
|
|
from skimage import io
|
|
from functools import partial
|
|
from vispy import scene, io
|
|
from vispy.scene import visuals
|
|
from functools import reduce
|
|
# from moviepy.editor import ImageSequenceClip
|
|
import scipy.misc as misc
|
|
from vispy.visuals.filters import Alpha
|
|
import cv2
|
|
from skimage.transform import resize
|
|
import copy
|
|
import torch
|
|
import os
|
|
from Inpainting.utils import refine_depth_around_edge, smooth_cntsyn_gap
|
|
from Inpainting.utils import require_depth_edge, filter_irrelevant_edge_new, open_small_mask
|
|
from skimage.feature import canny
|
|
from scipy import ndimage
|
|
import time
|
|
import transforms3d
|
|
|
|
def relabel_node(mesh, nodes, cur_node, new_node):
    """Move ``cur_node`` (its attributes and adjacency) onto ``new_node``.

    Args:
        mesh: graph exposing ``add_node``/``add_edge``/``neighbors``/``remove_node``.
        nodes: node-attribute mapping of ``mesh`` (presumably ``mesh.nodes`` —
            confirm against callers).
        cur_node: node to be replaced; removed from ``mesh``.
        new_node: node that takes over ``cur_node``'s attributes and neighbours.

    Returns:
        ``mesh`` (modified in place).
    """
    if cur_node == new_node:
        return mesh
    mesh.add_node(new_node)
    for key, value in nodes[cur_node].items():
        nodes[new_node][key] = value
    # Snapshot the neighbours before mutating the graph, and skip the two nodes
    # being merged so that relabelling onto an adjacent node does not leave a
    # self-loop on new_node.
    neighbours = [ne for ne in mesh.neighbors(cur_node)
                  if ne != cur_node and ne != new_node]
    for ne in neighbours:
        mesh.add_edge(new_node, ne)
    mesh.remove_node(cur_node)

    return mesh
|
|
|
|
def filter_edge(mesh, edge_ccs, config, invalid=False):
    """Keep or blank each edge component depending on whether it has far-side context.

    For every component in ``edge_ccs`` the "context" is the union of the
    ``'far'`` lists of its nodes. When the component has more than two nodes,
    a far node that is the only member of its ``edge_id`` group is discarded
    from that context.

    Args:
        mesh: object whose ``nodes`` maps a node to its attribute dict.
        edge_ccs: sequence of node sets (iterated more than once).
        config: reads ``config['context_thickness']``; 0 means "no context".
        invalid: when exactly ``True``, keep the components WITHOUT context
            instead of those with context.

    Returns:
        A list aligned with ``edge_ccs``: the original set object where it is
        kept, an empty ``set()`` otherwise.
    """
    node_attrs = mesh.nodes
    contexts = []
    for edge_cc in edge_ccs:
        context = set()
        if config['context_thickness'] != 0:
            groups = {}
            for edge_node in edge_cc:
                far_nodes = node_attrs[edge_node].get('far')
                if far_nodes is None:
                    continue
                for far_node in far_nodes:
                    context.add(far_node)
                    group_id = node_attrs[far_node].get('edge_id')
                    if group_id is not None:
                        groups.setdefault(group_id, set()).add(far_node)
            if len(edge_cc) > 2:
                # Drop isolated far nodes: lone members of their edge group.
                for members in groups.values():
                    if len(members) == 1:
                        context.remove(next(iter(members)))
        contexts.append(context)

    keep_empty = invalid is True
    filtered = []
    for edge_cc, context in zip(edge_ccs, contexts):
        has_context = len(context) > 0
        wanted = (not has_context) if keep_empty else has_context
        filtered.append(edge_cc if wanted else set())
    return filtered
|
|
|
|
def _extrapolation_anchors(direc, h_off, w_off, noext_H, noext_W, thickness):
    """Compute the crop windows used by ``extrapolate`` for one direction.

    Every anchor is ``[row_start, row_end, col_start, col_end]`` into the padded image.

    Args:
        direc: ``'up'``, ``'down'``, ``'left'`` or ``'right'`` for a side band, or a
            hyphenated corner such as ``'left-up'`` / ``'right-down'``. Matching is
            case-insensitive and substring based, checked in that order.
        h_off, w_off: thickness of the padding band (rows / columns).
        noext_H, noext_W: size of the original, non-extended image.
        thickness: ``config['context_thickness']``.

    Returns:
        ``(all_anchor, mask_anchor, valid_line_anchor, valid_anchor)``;
        ``valid_line_anchor`` is ``None`` for corners.

    Raises:
        ValueError: if ``direc`` matches none of the supported directions.
    """
    d = direc.lower()
    if "-" not in d:
        if "up" in d:
            all_anchor = [0, h_off + thickness, w_off, w_off + noext_W]
            mask_anchor = [0, h_off, w_off, w_off + noext_W]
            context_anchor = [h_off, h_off + thickness, w_off, w_off + noext_W]
            valid_line_anchor = [h_off, h_off + 1, w_off, w_off + noext_W]
        elif "down" in d:
            all_anchor = [h_off + noext_H - thickness, 2 * h_off + noext_H, w_off, w_off + noext_W]
            mask_anchor = [h_off + noext_H, 2 * h_off + noext_H, w_off, w_off + noext_W]
            context_anchor = [h_off + noext_H - thickness, h_off + noext_H, w_off, w_off + noext_W]
            valid_line_anchor = [h_off + noext_H - 1, h_off + noext_H, w_off, w_off + noext_W]
        elif "left" in d:
            all_anchor = [h_off, h_off + noext_H, 0, w_off + thickness]
            mask_anchor = [h_off, h_off + noext_H, 0, w_off]
            context_anchor = [h_off, h_off + noext_H, w_off, w_off + thickness]
            valid_line_anchor = [h_off, h_off + noext_H, w_off, w_off + 1]
        elif "right" in d:
            all_anchor = [h_off, h_off + noext_H, w_off + noext_W - thickness, 2 * w_off + noext_W]
            mask_anchor = [h_off, h_off + noext_H, w_off + noext_W, 2 * w_off + noext_W]
            context_anchor = [h_off, h_off + noext_H, w_off + noext_W - thickness, w_off + noext_W]
            valid_line_anchor = [h_off, h_off + noext_H, w_off + noext_W - 1, w_off + noext_W]
        else:
            raise ValueError("Unsupported extrapolation direction: {!r}".format(direc))
        valid_anchor = [min(mask_anchor[0], context_anchor[0]), max(mask_anchor[1], context_anchor[1]),
                        min(mask_anchor[2], context_anchor[2]), max(mask_anchor[3], context_anchor[3])]
        return all_anchor, mask_anchor, valid_line_anchor, valid_anchor

    # Corner patches: the working window is the whole corner plus a context margin.
    if "left" in d and "up" in d:
        all_anchor = [0, h_off + thickness, 0, w_off + thickness]
        mask_anchor = [0, h_off, 0, w_off]
    elif "left" in d and "down" in d:
        all_anchor = [h_off + noext_H - thickness, 2 * h_off + noext_H, 0, w_off + thickness]
        mask_anchor = [h_off + noext_H, 2 * h_off + noext_H, 0, w_off]
    elif "right" in d and "up" in d:
        all_anchor = [0, h_off + thickness, w_off + noext_W - thickness, 2 * w_off + noext_W]
        mask_anchor = [0, h_off, w_off + noext_W, 2 * w_off + noext_W]
    elif "right" in d and "down" in d:
        all_anchor = [h_off + noext_H - thickness, 2 * h_off + noext_H,
                      w_off + noext_W - thickness, 2 * w_off + noext_W]
        mask_anchor = [h_off + noext_H, 2 * h_off + noext_H, w_off + noext_W, 2 * w_off + noext_W]
    else:
        raise ValueError("Unsupported extrapolation direction: {!r}".format(direc))
    return all_anchor, mask_anchor, None, all_anchor


def extrapolate(global_mesh,
                info_on_pix,
                image,
                depth,
                other_edge_with_id,
                edge_map,
                edge_ccs,
                depth_edge_model,
                depth_feat_model,
                rgb_feat_model,
                config,
                direc='right-up'):
    """Synthesize one padding band (or corner) of the enlarged image and merge it into the mesh.

    On a window selected by ``direc`` the three networks run in sequence:

    1. ``depth_edge_model`` predicts depth edges inside the band.
    2. Edges that touch the context line are traced into matched near/far paths.
    3. ``depth_feat_model`` fills zero-mean log depth.
    4. ``rgb_feat_model`` fills colour.

    The result is written back into ``image``/``depth``. The matching nodes of
    ``global_mesh``, ``info_on_pix`` and ``edge_ccs`` are re-linked.

    Args:
        global_mesh: graph whose ``graph`` dict provides ``hoffset``, ``woffset``,
            ``noext_H`` and ``noext_W``. Nodes are ``(row, col, depth)`` tuples.
        info_on_pix: per-pixel records keyed by ``(row, col)``.
        image, depth: padded colour / depth arrays; modified in place.
        other_edge_with_id: per-pixel edge id map; ``-1`` is treated as "no id".
        edge_map: per-pixel depth-edge map.
        edge_ccs: node sets indexed by edge id; extended in place.
        depth_edge_model, depth_feat_model, rgb_feat_model: networks exposing ``forward_3P``.
        config: reads ``context_thickness``, ``gpu_ids``, ``ext_edge_threshold`` and
            ``gray_image``, plus whatever ``refine_depth_around_edge`` reads.
        direc: band to fill; see ``_extrapolation_anchors``.

    Returns:
        ``(info_on_pix, global_mesh, image, depth, edge_ccs)``.

    Raises:
        ValueError: for an unsupported ``direc``.
    """
    h_off, w_off = global_mesh.graph['hoffset'], global_mesh.graph['woffset']
    noext_H, noext_W = global_mesh.graph['noext_H'], global_mesh.graph['noext_W']
    all_anchor, mask_anchor, valid_line_anchor, valid_anchor = _extrapolation_anchors(
        direc, h_off, w_off, noext_H, noext_W, config['context_thickness'])

    # mask = pixels to synthesize, context = known pixels, both cropped to the window.
    global_mask = np.zeros_like(depth)
    global_mask[mask_anchor[0]:mask_anchor[1], mask_anchor[2]:mask_anchor[3]] = 1
    mask = global_mask[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] * 1
    context = 1 - mask
    valid_area = mask + context
    input_rgb = image[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] / 255. * context[..., None]
    input_depth = depth[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] * context
    log_depth = np.log(input_depth + 1e-8)
    log_depth[mask > 0] = 0
    input_mean_depth = np.mean(log_depth[context > 0])
    input_zero_mean_depth = (log_depth - input_mean_depth) * context
    input_disp = 1. / np.abs(input_depth)
    input_disp[mask > 0] = 0
    input_disp = input_disp / input_disp.max()
    # valid_line marks the row/column of context pixels adjacent to the band.
    valid_line = np.zeros_like(depth)
    if valid_line_anchor is not None:
        valid_line[valid_line_anchor[0]:valid_line_anchor[1], valid_line_anchor[2]:valid_line_anchor[3]] = 1
    valid_line = valid_line[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]]
    input_edge_map = edge_map[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]] * context
    input_other_edge_with_id = other_edge_with_id[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]]
    end_depth_maps = ((valid_line * input_edge_map) > 0) * input_depth

    if isinstance(config["gpu_ids"], int) and (config["gpu_ids"] >= 0):
        device = config["gpu_ids"]
    else:
        device = "cpu"

    # Edge ids that touch the valid line. Drop every -1 ("no id"), not just the first
    # duplicate, so non-edge pixels never leak into `edge`.
    valid_edge_ids = np.unique(input_other_edge_with_id[(valid_line * input_edge_map) > 0])
    valid_edge_ids = valid_edge_ids[valid_edge_ids != -1]
    edge = np.isin(input_other_edge_with_id, valid_edge_ids).astype(mask.dtype)
    t_edge = torch.FloatTensor(edge).to(device)[None, None, ...]
    t_rgb = torch.FloatTensor(input_rgb).to(device).permute(2, 0, 1).unsqueeze(0)
    t_mask = torch.FloatTensor(mask).to(device)[None, None, ...]
    t_context = torch.FloatTensor(context).to(device)[None, None, ...]
    t_disp = torch.FloatTensor(input_disp).to(device)[None, None, ...]
    t_depth_zero_mean_depth = torch.FloatTensor(input_zero_mean_depth).to(device)[None, None, ...]

    depth_edge_output = depth_edge_model.forward_3P(t_mask, t_context, t_rgb, t_disp, t_edge, unit_length=128,
                                                    cuda=device)
    t_output_edge = (depth_edge_output > config['ext_edge_threshold']).float() * t_mask + t_edge
    output_raw_edge = t_output_edge.data.cpu().numpy().squeeze()

    # 8-connected graph over the predicted edge pixels inside the mask. Neighbours
    # lying on the valid line are tagged 'cnt' (connected to context) with their depth.
    mesh = netx.Graph()
    hxs, hys = np.where(output_raw_edge * mask > 0)
    for hx, hy in zip(hxs, hys):
        node = (hx, hy)
        mesh.add_node((hx, hy))
        eight_nes = [ne for ne in [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1),
                                   (hx + 1, hy + 1), (hx - 1, hy - 1), (hx - 1, hy + 1), (hx + 1, hy - 1)]
                     if 0 <= ne[0] < output_raw_edge.shape[0] and 0 <= ne[1] < output_raw_edge.shape[1]
                     and 0 < output_raw_edge[ne[0], ne[1]]]
        for ne in eight_nes:
            mesh.add_edge(node, ne, length=np.hypot(ne[0] - hx, ne[1] - hy))
            if end_depth_maps[ne[0], ne[1]] != 0:
                mesh.nodes[ne[0], ne[1]]['cnt'] = True
                mesh.nodes[ne[0], ne[1]]['depth'] = end_depth_maps[ne[0], ne[1]]
    ccs = [*netx.connected_components(mesh)]
    end_pts = []
    for cc in ccs:
        end_pts.append(set())
        for node in cc:
            if mesh.nodes[node].get('cnt') is not None:
                end_pts[-1].add((node[0], node[1], mesh.nodes[node]['depth']))
    fpath_map = np.zeros_like(output_raw_edge) - 1
    npath_map = np.zeros_like(output_raw_edge) - 1
    for end_pt, cc in zip(end_pts, ccs):
        # Only components anchored to the context at exactly one pixel are traced.
        if len(end_pt) != 1:
            continue
        sub_mesh = mesh.subgraph(list(cc)).copy()
        pnodes = netx.periphery(sub_mesh)
        ends = [*end_pt]
        anchor_node = (ends[0][0] + all_anchor[0], ends[0][1] + all_anchor[2], -ends[0][2])
        edge_id = global_mesh.nodes[anchor_node]['edge_id']
        # Near path: anchor -> periphery node farthest (Euclidean) from the anchor.
        far_end = sorted(pnodes,
                         key=lambda x: np.hypot((x[0] - ends[0][0]), (x[1] - ends[0][1])),
                         reverse=True)[0]
        npath = [*netx.shortest_path(sub_mesh, (ends[0][0], ends[0][1]), far_end, weight='length')]
        for np_node in npath:
            npath_map[np_node[0], np_node[1]] = edge_id
        fpath = []
        fnodes = global_mesh.nodes[anchor_node].get('far')
        if fnodes is None:
            # Nothing to pair the near path with (this used to open a pdb session).
            print("None far")
        else:
            fnodes = [(xx[0] - all_anchor[0], xx[1] - all_anchor[2], xx[2]) for xx in fnodes]
            # Find a far node in context within a few dilations of the mask.
            dmask = mask + 0
            did = 0
            while True:
                did += 1
                dmask = cv2.dilate(dmask, np.ones((3, 3)), iterations=1)
                if did > 3:
                    break
                ffnode = [fnode for fnode in fnodes
                          if (dmask[fnode[0], fnode[1]] > 0 and mask[fnode[0], fnode[1]] == 0)]
                if len(ffnode) > 0:
                    fnode = ffnode[0]
                    break
            if len(ffnode) == 0:
                continue
            # Far path: replay the near path's steps starting from the far node.
            fpath.append((fnode[0], fnode[1]))
            for step in range(0, len(npath) - 1):
                parr = (npath[step + 1][0] - npath[step][0], npath[step + 1][1] - npath[step][1])
                new_loc = (fpath[-1][0] + parr[0], fpath[-1][1] + parr[1])
                new_loc_nes = [xx for xx in [(new_loc[0] + 1, new_loc[1]), (new_loc[0] - 1, new_loc[1]),
                                             (new_loc[0], new_loc[1] + 1), (new_loc[0], new_loc[1] - 1)]
                               if xx[0] >= 0 and xx[0] < fpath_map.shape[0] and xx[1] >= 0 and xx[1] < fpath_map.shape[1]]
                if np.sum([fpath_map[nlne[0], nlne[1]] for nlne in new_loc_nes]) != -4:
                    break
                if npath_map[new_loc[0], new_loc[1]] != -1:
                    if npath_map[new_loc[0], new_loc[1]] != edge_id:
                        break
                    else:
                        continue
                if valid_area[new_loc[0], new_loc[1]] == 0:
                    break
                new_loc_nes_eight = [xx for xx in [(new_loc[0] + 1, new_loc[1]), (new_loc[0] - 1, new_loc[1]),
                                                   (new_loc[0], new_loc[1] + 1), (new_loc[0], new_loc[1] - 1),
                                                   (new_loc[0] + 1, new_loc[1] + 1), (new_loc[0] + 1, new_loc[1] - 1),
                                                   (new_loc[0] - 1, new_loc[1] - 1), (new_loc[0] - 1, new_loc[1] + 1)]
                                     if xx[0] >= 0 and xx[0] < fpath_map.shape[0] and xx[1] >= 0 and xx[1] < fpath_map.shape[1]]
                if np.sum([int(npath_map[nlne[0], nlne[1]] == edge_id) for nlne in new_loc_nes_eight]) == 0:
                    break
                fpath.append((fpath[-1][0] + parr[0], fpath[-1][1] + parr[1]))
            # Trim the near path back to the part that has a far counterpart.
            if step != len(npath) - 2:
                for xx in npath[step + 1:]:
                    if npath_map[xx[0], xx[1]] == edge_id:
                        npath_map[xx[0], xx[1]] = -1
        if len(fpath) > 0:
            for fp_node in fpath:
                fpath_map[fp_node[0], fp_node[1]] = edge_id

    update_edge = (npath_map > -1) * mask + edge
    t_update_edge = torch.FloatTensor(update_edge).to(device)[None, None, ...]
    depth_output = depth_feat_model.forward_3P(t_mask, t_context, t_depth_zero_mean_depth, t_update_edge,
                                               unit_length=128, cuda=device)
    depth_output = depth_output.cpu().data.numpy().squeeze()
    depth_output = np.exp(depth_output + input_mean_depth) * mask
    for near_id in np.unique(npath_map[npath_map > -1]):
        depth_output = refine_depth_around_edge(depth_output.copy(),
                                                (fpath_map == near_id).astype(np.uint8) * mask,  # far edge in mask
                                                (fpath_map == near_id).astype(np.uint8),  # far edge
                                                (npath_map == near_id).astype(np.uint8) * mask,
                                                mask.copy(),
                                                np.zeros_like(mask),
                                                config)
    rgb_output = rgb_feat_model.forward_3P(t_mask, t_context, t_rgb, t_update_edge, unit_length=128,
                                           cuda=device)
    if config.get('gray_image') is True:
        rgb_output = rgb_output.mean(1, keepdim=True).repeat((1, 3, 1, 1))
    rgb_output = ((rgb_output.squeeze().data.cpu().permute(1, 2, 0).numpy() * mask[..., None] + input_rgb) * 255).astype(np.uint8)
    # Write synthesized colour / depth back into the full-size arrays (in place).
    image[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]][mask > 0] = rgb_output[mask > 0]
    depth[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]][mask > 0] = depth_output[mask > 0]

    # Cut global-mesh edges between matched near/far path pixels.
    nxs, nys = np.where((npath_map > -1))
    for nx, ny in zip(nxs, nys):
        n_id = npath_map[nx, ny]
        four_nes = [xx for xx in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]
                    if 0 <= xx[0] < fpath_map.shape[0] and 0 <= xx[1] < fpath_map.shape[1]]
        for nex, ney in four_nes:
            if fpath_map[nex, ney] == n_id:
                na, nb = (nx + all_anchor[0], ny + all_anchor[2], info_on_pix[(nx + all_anchor[0], ny + all_anchor[2])][0]['depth']), \
                         (nex + all_anchor[0], ney + all_anchor[2], info_on_pix[(nex + all_anchor[0], ney + all_anchor[2])][0]['depth'])
                if global_mesh.has_edge(na, nb):
                    global_mesh.remove_edge(na, nb)
    nxs, nys = np.where((fpath_map > -1))
    for nx, ny in zip(nxs, nys):
        n_id = fpath_map[nx, ny]
        four_nes = [xx for xx in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]
                    if 0 <= xx[0] < npath_map.shape[0] and 0 <= xx[1] < npath_map.shape[1]]
        for nex, ney in four_nes:
            if npath_map[nex, ney] == n_id:
                na, nb = (nx + all_anchor[0], ny + all_anchor[2], info_on_pix[(nx + all_anchor[0], ny + all_anchor[2])][0]['depth']), \
                         (nex + all_anchor[0], ney + all_anchor[2], info_on_pix[(nex + all_anchor[0], ney + all_anchor[2])][0]['depth'])
                if global_mesh.has_edge(na, nb):
                    global_mesh.remove_edge(na, nb)

    # Re-key the synthesized pixels' nodes from depth 0 to their new depth.
    nxs, nys = np.where(mask > 0)
    for x, y in zip(nxs, nys):
        x = x + all_anchor[0]
        y = y + all_anchor[2]
        cur_node = (x, y, 0)
        new_node = (x, y, -abs(depth[x, y]))
        disp = 1. / -abs(depth[x, y])
        mapping_dict = {cur_node: new_node}
        info_on_pix, global_mesh = update_info(mapping_dict, info_on_pix, global_mesh)
        global_mesh.nodes[new_node]['color'] = image[x, y]
        global_mesh.nodes[new_node]['old_color'] = image[x, y]
        global_mesh.nodes[new_node]['disp'] = disp
        info_on_pix[(x, y)][0]['depth'] = -abs(depth[x, y])
        info_on_pix[(x, y)][0]['disp'] = disp
        info_on_pix[(x, y)][0]['color'] = image[x, y]

    # Near-path nodes: adopt the edge id and record their far neighbours.
    nxs, nys = np.where((npath_map > -1))
    for nx, ny in zip(nxs, nys):
        self_node = (nx + all_anchor[0], ny + all_anchor[2], info_on_pix[(nx + all_anchor[0], ny + all_anchor[2])][0]['depth'])
        if global_mesh.has_node(self_node) is False:
            # NOTE(review): `break` stops processing every remaining near-path
            # pixel; `continue` may have been intended. Kept as-is.
            break
        n_id = int(round(npath_map[nx, ny]))
        four_nes = [xx for xx in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]
                    if 0 <= xx[0] < fpath_map.shape[0] and 0 <= xx[1] < fpath_map.shape[1]]
        for nex, ney in four_nes:
            ne_node = (nex + all_anchor[0], ney + all_anchor[2], info_on_pix[(nex + all_anchor[0], ney + all_anchor[2])][0]['depth'])
            if global_mesh.has_node(ne_node) is False:
                continue
            if fpath_map[nex, ney] == n_id:
                if global_mesh.nodes[self_node].get('edge_id') is None:
                    global_mesh.nodes[self_node]['edge_id'] = n_id
                    edge_ccs[n_id].add(self_node)
                    info_on_pix[(self_node[0], self_node[1])][0]['edge_id'] = n_id
                if global_mesh.has_edge(self_node, ne_node) is True:
                    global_mesh.remove_edge(self_node, ne_node)
                if global_mesh.nodes[self_node].get('far') is None:
                    global_mesh.nodes[self_node]['far'] = []
                global_mesh.nodes[self_node]['far'].append(ne_node)

    # Map each far-path id to an id from other_edge_with_id under that path.
    global_fpath_map = np.zeros_like(other_edge_with_id) - 1
    global_fpath_map[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]] = fpath_map
    fpath_ids = np.unique(global_fpath_map)
    fpath_ids = fpath_ids[fpath_ids != -1]
    fpath_real_id_map = np.zeros_like(global_fpath_map) - 1
    for fpath_id in fpath_ids:
        fpath_real_id = np.unique(((global_fpath_map == fpath_id).astype(int) * (other_edge_with_id + 1)) - 1)
        fpath_real_id = fpath_real_id[fpath_real_id != -1]
        if fpath_real_id.shape[0] == 0:
            # No underlying edge id for this path (previously crashed on `[].astype`).
            continue
        # `np.int` was removed in NumPy 1.24; the builtin `int` is what it aliased.
        fpath_real_id = fpath_real_id.astype(int)
        # NOTE(review): ids are already de-duplicated by np.unique, so this picks the
        # smallest id rather than a majority vote. Kept as-is.
        fpath_real_id = np.bincount(fpath_real_id).argmax()
        fpath_real_id_map[global_fpath_map == fpath_id] = fpath_real_id

    # Far-path nodes: adopt an edge id and record their near neighbours.
    nxs, nys = np.where((fpath_map > -1))
    for nx, ny in zip(nxs, nys):
        self_node = (nx + all_anchor[0], ny + all_anchor[2], info_on_pix[(nx + all_anchor[0], ny + all_anchor[2])][0]['depth'])
        n_id = fpath_map[nx, ny]
        four_nes = [xx for xx in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]
                    if 0 <= xx[0] < npath_map.shape[0] and 0 <= xx[1] < npath_map.shape[1]]
        for nex, ney in four_nes:
            ne_node = (nex + all_anchor[0], ney + all_anchor[2], info_on_pix[(nex + all_anchor[0], ney + all_anchor[2])][0]['depth'])
            if global_mesh.has_node(ne_node) is False:
                continue
            if npath_map[nex, ney] == n_id or global_mesh.nodes[ne_node].get('edge_id') == n_id:
                if global_mesh.has_edge(self_node, ne_node) is True:
                    global_mesh.remove_edge(self_node, ne_node)
                if global_mesh.nodes[self_node].get('near') is None:
                    global_mesh.nodes[self_node]['near'] = []
                if global_mesh.nodes[self_node].get('edge_id') is None:
                    f_id = int(round(fpath_real_id_map[self_node[0], self_node[1]]))
                    global_mesh.nodes[self_node]['edge_id'] = f_id
                    info_on_pix[(self_node[0], self_node[1])][0]['edge_id'] = f_id
                    edge_ccs[f_id].add(self_node)
                global_mesh.nodes[self_node]['near'].append(ne_node)

    return info_on_pix, global_mesh, image, depth, edge_ccs
|
|
# for edge_cc in edge_ccs:
|
|
# for edge_node in edge_cc:
|
|
# edge_ccs
|
|
# context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs, init_mask_connect, edge_maps, extend_context_ccs, extend_edge_ccs
|
|
|
|
def get_valid_size(imap):
    """Return the tight bounding box of the rows/columns of ``imap`` whose sum is positive.

    Args:
        imap: array reduced along axis 1 (for rows) and axis 0 (for columns).

    Returns:
        dict with ``x_min``/``x_max`` (rows) and ``y_min``/``y_max`` (columns).
        The ``*_max`` values are exclusive (last index + 1).

    Raises:
        ValueError: if no row or no column has a positive sum.
    """
    # atleast_1d keeps single-row/column maps 1-D after squeeze(). np.where on a
    # 0-d array is deprecated in NumPy.
    rows = np.where(np.atleast_1d(imap.sum(1).squeeze()) > 0)[0]
    cols = np.where(np.atleast_1d(imap.sum(0).squeeze()) > 0)[0]
    if rows.size == 0 or cols.size == 0:
        raise ValueError("get_valid_size: imap has no row/column with a positive sum")
    size_dict = {'x_max': rows.max() + 1, 'y_max': cols.max() + 1,
                 'x_min': rows.min(), 'y_min': cols.min()}

    return size_dict
|
|
|
|
def dilate_valid_size(isize_dict, imap, dilate=(0, 0)):
    """Grow a bounding box by ``dilate = (rows, cols)`` pixels, clipped to ``imap``'s extent.

    Args:
        isize_dict: dict with ``x_min``/``x_max``/``y_min``/``y_max``; not modified.
        imap: array whose ``shape[0]``/``shape[1]`` bound the result.
        dilate: margin added on rows (index 0) and columns (index 1).

    Returns:
        A new dict with the dilated, clipped bounds.
    """
    osize_dict = copy.deepcopy(isize_dict)
    osize_dict['x_min'] = max(0, osize_dict['x_min'] - dilate[0])
    osize_dict['x_max'] = min(imap.shape[0], osize_dict['x_max'] + dilate[0])
    # Columns use the column margin (y_min previously used dilate[0] by mistake).
    osize_dict['y_min'] = max(0, osize_dict['y_min'] - dilate[1])
    osize_dict['y_max'] = min(imap.shape[1], osize_dict['y_max'] + dilate[1])

    return osize_dict
|
|
|
|
def size_operation(size_a, size_b, operation):
    """Combine two bounding-box dicts (``x_min``/``y_min``/``x_max``/``y_max``).

    Only ``'+'`` is implemented: it returns the smallest box containing both.

    Raises:
        ValueError: if ``operation`` is neither ``'+'`` nor ``'-'``.
        NotImplementedError: for ``'-'`` (exclusion is not defined).
    """
    # Explicit raises instead of asserts, which vanish under `python -O`.
    if operation not in ('+', '-'):
        raise ValueError("Operation must be '+' (union) or '-' (exclude)")
    if operation == '-':
        raise NotImplementedError("Operation '-' is undefined !")
    osize = {
        'x_min': min(size_a['x_min'], size_b['x_min']),
        'y_min': min(size_a['y_min'], size_b['y_min']),
        'x_max': max(size_a['x_max'], size_b['x_max']),
        'y_max': max(size_a['y_max'], size_b['y_max']),
    }

    return osize
|
|
|
|
def fill_dummy_bord(mesh, info_on_pix, image, depth, config):
    """Add a placeholder node for every pixel outside the original image window.

    The original window is ``[hoffset, hoffset + noext_H) x [woffset, woffset + noext_W)``.
    Every pixel outside it gets a node ``(row, col, 0)`` with black colour,
    ``disp=0``, ``synthesis=False`` and ``ext_pixel=True``, plus a matching
    ``info_on_pix`` record. Each new node is linked to those 4-neighbours
    (inside ``mesh.graph['H'] x mesh.graph['W']``) that already have a record.

    Args:
        mesh: graph whose ``graph`` dict provides ``H``, ``W``, ``hoffset``,
            ``woffset``, ``noext_H`` and ``noext_W``.
        info_on_pix: per-pixel records keyed by ``(row, col)``; extended in place.
        image: unused (kept for interface compatibility).
        depth: only its shape is used.
        config: unused (kept for interface compatibility).

    Returns:
        ``(mesh, info_on_pix)``.
    """
    context = np.zeros_like(depth).astype(np.uint8)
    context[mesh.graph['hoffset']:mesh.graph['hoffset'] + mesh.graph['noext_H'],
            mesh.graph['woffset']:mesh.graph['woffset'] + mesh.graph['noext_W']] = 1
    mask = 1 - context
    xs, ys = np.where(mask > 0)
    cur_depth = 0
    cur_disp = 0
    for x, y in zip(xs, ys):
        cur_node = (x, y, cur_depth)
        # A fresh list per node so an in-place colour edit cannot leak to all pixels.
        mesh.add_node(cur_node, color=[0, 0, 0],
                      synthesis=False,
                      disp=cur_disp,
                      cc_id=set(),
                      ext_pixel=True)
        info_on_pix[(x, y)] = [{'depth': cur_depth,
                                'color': mesh.nodes[cur_node]['color'],
                                'synthesis': False,
                                'disp': mesh.nodes[cur_node]['disp'],
                                'ext_pixel': True}]
        for xx, yy in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
            # Bounds-check the neighbour (previously the current pixel was tested).
            if not (0 <= xx < mesh.graph['H'] and 0 <= yy < mesh.graph['W']):
                continue
            ne_info = info_on_pix.get((xx, yy))
            if ne_info is None:
                continue
            mesh.add_edge(cur_node, (xx, yy, ne_info[0]['depth']))

    return mesh, info_on_pix
|
|
|
|
|
|
def enlarge_border(mesh, info_on_pix, depth, image, config):
    """Record padding offsets and the full-image border on ``mesh.graph``.

    Sets ``hoffset`` and ``woffset`` to ``config['extrapolation_thickness']``.
    Sets the border to ``bord_up = bord_left = 0``, ``bord_down = H`` and
    ``bord_right = W``. The arrays are returned untouched.

    Returns:
        ``(mesh, info_on_pix, depth, image)``.
    """
    graph = mesh.graph
    # Read everything first so a missing key leaves graph unchanged, as before.
    pad_h = config['extrapolation_thickness']
    pad_w = config['extrapolation_thickness']
    graph['hoffset'] = pad_h
    graph['woffset'] = pad_w
    full_h = graph['H']
    full_w = graph['W']
    graph['bord_up'] = 0
    graph['bord_left'] = 0
    graph['bord_down'] = full_h
    graph['bord_right'] = full_w
    return mesh, info_on_pix, depth, image
|
|
|
|
def fill_missing_node(mesh, info_on_pix, image, depth):
    """Create a node for every pixel inside the border that has no ``info_on_pix`` record.

    The missing pixel's depth is the mean of the existing records among its
    4-neighbours. Pixels filled earlier in the scan count as neighbours. When no
    neighbour has a record, it falls back to ``-abs(depth[x, y])``.

    Args:
        mesh: graph whose ``graph`` dict provides ``bord_up``/``bord_down``/
            ``bord_left``/``bord_right``.
        info_on_pix: per-pixel records keyed by ``(row, col)``; extended in place.
        image: colour array indexed ``image[x, y]``.
        depth: depth array; ``depth[x, y]`` is overwritten for filled pixels.

    Returns:
        ``(mesh, info_on_pix, depth)``.
    """
    for x in range(mesh.graph['bord_up'], mesh.graph['bord_down']):
        for y in range(mesh.graph['bord_left'], mesh.graph['bord_right']):
            if info_on_pix.get((x, y)) is not None:
                continue
            # Report and continue (this used to stop in an interactive pdb session).
            print("fill missing node = ", x, y)
            re_depth, re_count = 0, 0
            for ne in [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)]:
                if info_on_pix.get(ne) is not None:
                    re_depth += info_on_pix[ne][0]['depth']
                    re_count += 1
            if re_count == 0:
                re_depth = -abs(depth[x, y])
            else:
                re_depth = re_depth / re_count
            depth[x, y] = abs(re_depth)
            info_on_pix[(x, y)] = [{'depth': re_depth,
                                    'color': image[x, y],
                                    'synthesis': False,
                                    'disp': 1. / re_depth}]
            mesh.add_node((x, y, re_depth), color=image[x, y],
                          synthesis=False,
                          disp=1. / re_depth,
                          cc_id=set())
    return mesh, info_on_pix, depth
|
|
|
|
|
|
|
|
def refresh_bord_depth(mesh, info_on_pix, image, depth):
    """Re-seat the outermost ring of pixels onto the depth of their inward neighbour.

    The region is ``[bord_up, bord_down) x [bord_left, bord_right)``.

    * Non-corner border pixels are re-keyed to a node at the depth of the pixel
      one step towards the interior.
    * Corners take the mean depth of their in-region 4-neighbours.
    * ``depth`` is then refreshed from ``info_on_pix``.
    * Each border node's ``far``/``near`` lists are rebuilt from the in-region
      4-neighbours it is not connected to: ``far`` when the neighbour's
      ``|depth|`` is at least its own, ``near`` otherwise.

    Args:
        mesh: graph whose ``graph`` dict provides the ``bord_*`` bounds.
        info_on_pix: per-pixel records keyed by ``(row, col)``.
        image: colour array, used only when a border pixel has no record yet.
        depth: depth array; border entries are overwritten.

    Returns:
        ``(mesh, info_on_pix, depth)``. ``update_info`` may return new objects,
        so use the returned values.
    """
    bord_up, bord_down = mesh.graph['bord_up'], mesh.graph['bord_down']
    bord_left, bord_right = mesh.graph['bord_left'], mesh.graph['bord_right']
    corner_nodes = [(bord_up, bord_left),
                    (bord_up, bord_right - 1),
                    (bord_down - 1, bord_left),
                    (bord_down - 1, bord_right - 1)]
    bord_nodes = []
    bord_nodes += [(bord_up, xx) for xx in range(bord_left + 1, bord_right - 1)]
    bord_nodes += [(bord_down - 1, xx) for xx in range(bord_left + 1, bord_right - 1)]
    bord_nodes += [(xx, bord_left) for xx in range(bord_up + 1, bord_down - 1)]
    bord_nodes += [(xx, bord_right - 1) for xx in range(bord_up + 1, bord_down - 1)]

    for xy in bord_nodes:
        if xy[0] == bord_up:
            tgt_loc = (xy[0] + 1, xy[1])
        elif xy[0] == bord_down - 1:
            tgt_loc = (xy[0] - 1, xy[1])
        elif xy[1] == bord_left:
            tgt_loc = (xy[0], xy[1] + 1)
        elif xy[1] == bord_right - 1:
            tgt_loc = (xy[0], xy[1] - 1)
        else:
            continue
        ne_infos = info_on_pix.get(tgt_loc)
        if ne_infos is None:
            # No inward neighbour to copy from: leave this pixel as it is
            # (this used to drop into an interactive pdb session).
            continue
        tgt_depth = ne_infos[0]['depth']
        new_node = (xy[0], xy[1], tgt_depth)
        src_node = (tgt_loc[0], tgt_loc[1], tgt_depth)
        # Mirror the source node's neighbours that are diagonal to xy onto xy.
        tgt_nes_loc = [(xx[0] - tgt_loc[0] + xy[0], xx[1] - tgt_loc[1] + xy[1])
                       for xx in mesh.neighbors(src_node)
                       if abs(xx[0] - xy[0]) == 1 and abs(xx[1] - xy[1]) == 1]
        tgt_nes_loc = [xx for xx in tgt_nes_loc if info_on_pix.get(xx) is not None]
        tgt_nes_loc.append(tgt_loc)
        old_infos = info_on_pix.get(xy)
        if old_infos is not None and len(old_infos) > 0:
            old_node = (xy[0], xy[1], old_infos[0].get('depth'))
            mesh.remove_edges_from([(old_ne, old_node) for old_ne in mesh.neighbors(old_node)])
            mesh.add_edges_from([((zz[0], zz[1], info_on_pix[zz][0]['depth']), old_node) for zz in tgt_nes_loc])
            info_on_pix, mesh = update_info({old_node: new_node}, info_on_pix, mesh)
        else:
            # Copy (not alias) the neighbour's record and set the colour on the
            # pixel's record. Previously this indexed an empty list and wrote
            # 'color' into info_on_pix itself.
            new_info = dict(ne_infos[0])
            new_info['color'] = image[xy[0], xy[1]]
            new_info['old_color'] = image[xy[0], xy[1]]
            info_on_pix[xy] = [new_info]
            mesh.add_node(new_node)
            mesh.add_edges_from([((zz[0], zz[1], info_on_pix[zz][0]['depth']), new_node) for zz in tgt_nes_loc])
            mesh.nodes[new_node]['far'] = None
            mesh.nodes[new_node]['near'] = None
        # The source node must no longer list xy as a far/near partner.
        for key in ('far', 'near'):
            linked = mesh.nodes[src_node].get(key)
            if linked is not None:
                for aa in [ne for ne in linked if (ne[0], ne[1]) == xy]:
                    linked.remove(aa)

    for xy in corner_nodes:
        hx, hy = xy
        ne_nodes = []
        for ne_loc in [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1)]:
            if not (bord_up <= ne_loc[0] < bord_down and bord_left <= ne_loc[1] < bord_right):
                continue
            ne_infos = info_on_pix.get(ne_loc)
            if ne_infos is not None:
                ne_nodes.append((ne_loc[0], ne_loc[1], ne_infos[0]['depth']))
        new_node = (xy[0], xy[1], float(np.mean([ne[2] for ne in ne_nodes])))
        old_infos = info_on_pix.get(xy)
        if old_infos is not None and len(old_infos) > 0:
            old_node = (xy[0], xy[1], old_infos[0].get('depth'))
            mesh.remove_edges_from([(old_ne, old_node) for old_ne in mesh.neighbors(old_node)])
            mesh.add_edges_from([(zz, old_node) for zz in ne_nodes])
            info_on_pix, mesh = update_info({old_node: new_node}, info_on_pix, mesh)
        else:
            # Seed from the last neighbour that has a record (previously indexed
            # with a stale loop variable) and adopt the averaged depth so the
            # record matches new_node.
            new_info = dict(info_on_pix[(ne_nodes[-1][0], ne_nodes[-1][1])][0])
            new_info['depth'] = new_node[2]
            new_info['color'] = image[xy[0], xy[1]]
            new_info['old_color'] = image[xy[0], xy[1]]
            info_on_pix[xy] = [new_info]
            mesh.add_node(new_node)
            mesh.add_edges_from([(zz, new_node) for zz in ne_nodes])
            mesh.nodes[new_node]['far'] = None
            mesh.nodes[new_node]['near'] = None

    for xy in bord_nodes + corner_nodes:
        depth[xy[0], xy[1]] = abs(info_on_pix[xy][0]['depth'])

    for xy in bord_nodes:
        cur_node = (xy[0], xy[1], info_on_pix[xy][0]['depth'])
        connected = set([(ne[0], ne[1]) for ne in mesh.neighbors(cur_node)])
        four_nes = set([(xy[0] + 1, xy[1]), (xy[0] - 1, xy[1]), (xy[0], xy[1] + 1), (xy[0], xy[1] - 1)]) - connected
        four_nes = [ne for ne in four_nes
                    if bord_up <= ne[0] < bord_down and bord_left <= ne[1] < bord_right]
        four_nes = [(ne[0], ne[1], info_on_pix[(ne[0], ne[1])][0]['depth']) for ne in four_nes]
        mesh.nodes[cur_node]['far'] = []
        mesh.nodes[cur_node]['near'] = []
        for ne in four_nes:
            if abs(ne[2]) >= abs(cur_node[2]):
                mesh.nodes[cur_node]['far'].append(ne)
            else:
                mesh.nodes[cur_node]['near'].append(ne)

    return mesh, info_on_pix, depth
|
|
|
|
def get_union_size(mesh, dilate, *alls_cc):
    """Return the dilated bounding box that encloses every node of the given collections.

    Args:
        mesh: graph whose ``graph`` dict provides the image size under 'H' and 'W'.
        dilate: margins; ``dilate[0]`` pads the first coordinate (rows) on both
            sides and ``dilate[1]`` pads the second coordinate (columns).
        *alls_cc: sets of nodes; only ``node[0]`` and ``node[1]`` are read.

    Returns:
        dict with 'x_min'/'x_max' (row range) and 'y_min'/'y_max' (column range),
        clipped to ``[0, H]`` / ``[0, W]``. The max bounds are exclusive.
    """
    height, width = mesh.graph['H'], mesh.graph['W']
    union = set().union(*alls_cc)
    # Seed with the "empty" box (H, W, 0, 0), so an empty union still yields a box.
    row_lo = min([height] + [node[0] for node in union])
    col_lo = min([width] + [node[1] for node in union])
    row_hi = max([0] + [node[0] for node in union]) + 1  # exclusive bound
    col_hi = max([0] + [node[1] for node in union]) + 1  # exclusive bound
    return {
        'x_min': max(0, row_lo - dilate[0]),
        'x_max': min(height, row_hi + dilate[0]),
        'y_min': max(0, col_lo - dilate[1]),
        'y_max': min(width, col_hi + dilate[1]),
    }
|
|
|
|
def incomplete_node(mesh, edge_maps, info_on_pix):
    """Reconnect interior, non-synthesized nodes that have lost their grid links.

    A node whose 'synthesis' attribute is not True, lying strictly inside the
    image border, is reconnected when it has fewer than three non-synthesized
    neighbours and either
      * at most one such neighbour, or
      * exactly two that are more than one pixel apart in some coordinate.
    For each of its four grid neighbours, the node is linked to the first layer
    in ``info_on_pix`` that is not flagged 'synthesis' and whose node exists in
    ``mesh``.

    Args:
        mesh: graph of ``(x, y, depth)`` nodes; ``mesh.graph`` holds 'H'/'W'.
        edge_maps: unused; kept for call-site compatibility.
        info_on_pix: maps ``(x, y)`` to a list of per-layer dicts carrying
            'depth' and an optional 'synthesis' flag.

    Returns:
        The same ``mesh`` object, modified in place.
    """
    height, width = mesh.graph['H'], mesh.graph['W']
    for node in mesh.nodes:
        if mesh.nodes[node].get('synthesis') is True:
            continue
        nes = [xx for xx in mesh.neighbors(node) if mesh.nodes[xx].get('synthesis') is not True]
        interior = 0 < node[0] < height - 1 and 0 < node[1] < width - 1
        if len(nes) >= 3 or not interior:
            continue
        if len(nes) == 2:
            ne_a, ne_b = nes
            # Two neighbours that touch each other still form a valid strip.
            if abs(ne_a[0] - ne_b[0]) <= 1 and abs(ne_a[1] - ne_b[1]) <= 1:
                continue
        four_nes = ((node[0] - 1, node[1]), (node[0] + 1, node[1]),
                    (node[0], node[1] - 1), (node[0], node[1] + 1))
        for ne in four_nes:
            for info in info_on_pix[ne]:
                ne_node = (ne[0], ne[1], info['depth'])
                if info.get('synthesis') is not True and mesh.has_node(ne_node):
                    mesh.add_edge(node, ne_node)
                    break

    return mesh
|
|
|
|
def edge_inpainting(edge_id, context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc,
                    mesh, edge_map, edge_maps_with_id, config, union_size, depth_edge_model, inpaint_iter):
    """Produce the depth-edge map for one hole region.

    The node sets are rasterized with ``get_edge_from_nodes``. The candidate
    edges ('self_edge' + 'comp_edge') are filtered by
    ``filter_irrelevant_edge_new``, and the maps are cropped to ``union_size``.
    If ``require_depth_edge`` accepts the crop and ``inpaint_iter == 0``,
    ``depth_edge_model.forward_3P`` is run. Its output, thresholded at
    ``config['ext_edge_threshold']`` and masked, is added to the known edges.
    Otherwise the filtered edges are used unchanged.

    Returns:
        ``(edge_dict, end_depth_maps)``. ``edge_dict`` is the dict built by
        ``get_edge_from_nodes``, with 'edge' replaced by the filtered edges and
        a new full-size (H, W) 'output' map that is zero outside ``union_size``.
    """
    edge_dict = get_edge_from_nodes(context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc,
                                    mesh.graph['H'], mesh.graph['W'], mesh)
    candidate_edges = edge_dict['self_edge'] + edge_dict['comp_edge']
    # NOTE(review): spdb=True is passed unconditionally; its effect lives inside
    # filter_irrelevant_edge_new and cannot be verified from here.
    edge_dict['edge'], end_depth_maps, _ = filter_irrelevant_edge_new(
        candidate_edges, edge_map, edge_maps_with_id, edge_id, edge_dict['context'],
        edge_dict['depth'], mesh, context_cc | erode_context_cc, spdb=True)

    crop_keys = ('mask', 'context', 'rgb', 'disp', 'edge')
    cropped_maps = crop_maps_by_size(union_size, *[edge_dict[key] for key in crop_keys])
    patch_edge_dict = dict(zip(crop_keys, cropped_maps))
    tensor_edge_dict = convert2tensor(patch_edge_dict)

    if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
        gpu_ids = config["gpu_ids"]
        device = gpu_ids if isinstance(gpu_ids, int) and gpu_ids >= 0 else "cpu"
        with torch.no_grad():
            predicted = depth_edge_model.forward_3P(tensor_edge_dict['mask'],
                                                    tensor_edge_dict['context'],
                                                    tensor_edge_dict['rgb'],
                                                    tensor_edge_dict['disp'],
                                                    tensor_edge_dict['edge'],
                                                    unit_length=128,
                                                    cuda=device).cpu()
        # Keep only predicted edges inside the hole, on top of the known edges.
        new_edges = (predicted > config['ext_edge_threshold']).float() * tensor_edge_dict['mask']
        tensor_edge_dict['output'] = new_edges + tensor_edge_dict['edge']
    else:
        tensor_edge_dict['output'] = tensor_edge_dict['edge']

    patch_output = tensor_edge_dict['output'].squeeze().data.cpu().numpy()
    full_output = np.zeros((mesh.graph['H'], mesh.graph['W']))
    full_output[union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = patch_output
    edge_dict['output'] = full_output

    return edge_dict, end_depth_maps
|
|
|
|
def depth_inpainting(context_cc, extend_context_cc, erode_context_cc, mask_cc, mesh, config, union_size, depth_feat_model, edge_output, given_depth_dict=False, spdb=False):
    """Fill the depth inside a hole region with ``depth_feat_model``.

    Args:
        context_cc, extend_context_cc, erode_context_cc, mask_cc: node sets
            rasterized by ``get_depth_from_nodes`` (the first two are unioned
            as context). They are ignored when ``given_depth_dict`` is supplied.
        mesh: graph whose ``graph`` dict holds the image size 'H'/'W'.
        config: reads 'gpu_ids', and 'log_depth' when building from nodes.
        union_size: crop box with 'x_min'/'x_max'/'y_min'/'y_max'.
        depth_feat_model: object exposing ``forward_3P``.
        edge_output: (H, W) edge map stored as 'edge' when building from nodes.
            If it is None, an all-zero edge map is used. ``get_depth_from_nodes``
            provides no 'edge' map, so the crop below would otherwise fail.
        given_depth_dict: precomputed depth dict to use instead of the node
            sets. The default ``False`` means "build it from the nodes". A
            supplied dict is mutated: 'output' is written into it.
        spdb: when True, plot intermediate maps and enter pdb (debug only).

    Returns:
        The depth dict with a full-size (H, W) 'output' map. Inside 'mask' it
        holds the ``smooth_cntsyn_gap`` result; inside 'context' it holds the
        known depth.
    """
    if given_depth_dict is False:
        depth_dict = get_depth_from_nodes(context_cc | extend_context_cc, erode_context_cc, mask_cc,
                                          mesh.graph['H'], mesh.graph['W'], mesh, config['log_depth'])
        if edge_output is not None:
            depth_dict['edge'] = edge_output
        else:
            depth_dict.setdefault('edge', np.zeros_like(depth_dict['depth']))
    else:
        depth_dict = given_depth_dict

    patch_depth_dict = dict()
    (patch_depth_dict['mask'], patch_depth_dict['context'], patch_depth_dict['depth'],
     patch_depth_dict['zero_mean_depth'], patch_depth_dict['edge']) = \
        crop_maps_by_size(union_size, depth_dict['mask'], depth_dict['context'],
                          depth_dict['real_depth'], depth_dict['zero_mean_depth'], depth_dict['edge'])
    tensor_depth_dict = convert2tensor(patch_depth_dict)
    resize_mask = open_small_mask(tensor_depth_dict['mask'], tensor_depth_dict['context'], 3, 41)
    with torch.no_grad():
        device = config["gpu_ids"] if isinstance(config["gpu_ids"], int) and config["gpu_ids"] >= 0 else "cpu"
        depth_output = depth_feat_model.forward_3P(resize_mask,
                                                   tensor_depth_dict['context'],
                                                   tensor_depth_dict['zero_mean_depth'],
                                                   tensor_depth_dict['edge'],
                                                   unit_length=128,
                                                   cuda=device)
        depth_output = depth_output.cpu()
    # NOTE(review): exp(output + mean_depth) presumes the model predicts
    # zero-mean log depth; confirm this also holds when config['log_depth'] is off.
    tensor_depth_dict['output'] = torch.exp(depth_output + depth_dict['mean_depth']) * \
        tensor_depth_dict['mask'] + tensor_depth_dict['depth']
    patch_depth_dict['output'] = tensor_depth_dict['output'].data.cpu().numpy().squeeze()
    depth_dict['output'] = np.zeros((mesh.graph['H'], mesh.graph['W']))
    depth_dict['output'][union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
        patch_depth_dict['output']
    depth_output = smooth_cntsyn_gap(depth_dict['output'] * depth_dict['mask'] + depth_dict['depth'] * depth_dict['context'],
                                     depth_dict['mask'], depth_dict['context'],
                                     init_mask_region=depth_dict['mask'])
    if spdb is True:
        _, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
        ax1.imshow(depth_output * depth_dict['mask'] + depth_dict['depth'])
        ax2.imshow(depth_dict['output'] * depth_dict['mask'] + depth_dict['depth'])
        plt.show()
        import pdb
        pdb.set_trace()
    depth_dict['output'] = depth_output * depth_dict['mask'] + depth_dict['depth'] * depth_dict['context']

    return depth_dict
|
|
|
|
def update_info(mapping_dict, info_on_pix, *meshes):
    """Rename one node in every mesh and record its new depth in ``info_on_pix``.

    Only the first ``old_node -> new_node`` pair of ``mapping_dict`` is used.
    Each mesh is passed through ``relabel_node``. Then the first layer entry of
    ``info_on_pix[(x, y)]`` gets ``new_node[2]`` as its 'depth'.

    Returns:
        ``[info_on_pix, *relabeled_meshes]``: the mutated ``info_on_pix``
        followed by one mesh per input mesh, in order.

    Raises:
        IndexError: if ``mapping_dict`` is empty.
    """
    old_node, new_node = list(mapping_dict.items())[0]
    relabeled = [relabel_node(mesh, mesh.nodes, old_node, new_node) for mesh in meshes]
    x, y, _ = old_node
    info_on_pix[(x, y)][0]['depth'] = new_node[2]

    return [info_on_pix, *relabeled]
|
|
|
|
def build_connection(mesh, cur_node, dst_node):
    """Attach ``dst_node`` to ``cur_node`` or to nodes reachable from it.

    If the two nodes are 4-adjacent (Manhattan distance < 2 in the first two
    coordinates), an edge is added. If they are more than one pixel apart in
    either coordinate, the search stops. Otherwise the function recurses into
    every neighbour of ``cur_node`` (a snapshot taken before recursing) that is
    not ``dst_node`` and is not already linked to it.

    Returns:
        The (mutated) ``mesh``.
    """
    d_row = abs(cur_node[0] - dst_node[0])
    d_col = abs(cur_node[1] - dst_node[1])
    if d_row + d_col < 2:
        mesh.add_edge(cur_node, dst_node)
    if max(d_row, d_col) > 1:
        return mesh
    for ne_node in list(mesh.neighbors(cur_node)):
        if ne_node != dst_node and not mesh.has_edge(ne_node, dst_node):
            mesh = build_connection(mesh, ne_node, dst_node)

    return mesh
|
|
|
|
def recursive_add_edge(edge_mesh, mesh, info_on_pix, cur_node, mark):
    """Re-attach marked neighbours of ``cur_node`` into ``mesh``, depth-first.

    For each neighbour of ``cur_node`` in ``edge_mesh`` whose pixel has
    ``mark == 3``:
      1. Clear the mark (set it to 0).
      2. Drop the node's edges in ``mesh`` and reconnect it via ``build_connection``.
      3. Set its depth to the mean depth of its new neighbours, or keep the old
         depth if it has none.
      4. Relabel it in both meshes and recurse from the relabeled node.

    NOTE(review): the meaning of mark value 3 is defined by the caller.

    Returns:
        ``(edge_mesh, mesh, mark, info_on_pix)`` after all updates.
    """
    ne_nodes = [(x[0], x[1]) for x in edge_mesh.neighbors(cur_node)]
    for node_xy in ne_nodes:
        node = (node_xy[0], node_xy[1], info_on_pix[node_xy][0]['depth'])
        if mark[node[0], node[1]] != 3:
            continue
        mark[node[0], node[1]] = 0
        mesh.remove_edges_from([(xx, node) for xx in mesh.neighbors(node)])
        mesh = build_connection(mesh, cur_node, node)
        ne_depths = [re_ne[2] for re_ne in mesh.neighbors(node)]
        # Explicit empty check instead of a bare except around the division.
        re_depth = sum(ne_depths) / float(len(ne_depths)) if ne_depths else node[2]
        re_node = (node_xy[0], node_xy[1], re_depth)
        info_on_pix, edge_mesh, mesh = update_info({node: re_node}, info_on_pix, edge_mesh, mesh)
        edge_mesh, mesh, mark, info_on_pix = recursive_add_edge(edge_mesh, mesh, info_on_pix, re_node, mark)

    return edge_mesh, mesh, mark, info_on_pix
|
|
|
|
def resize_for_edge(tensor_dict, largest_size):
    """Downscale the edge-model inputs so their larger spatial side is about ``largest_size``.

    Every tensor is cloned. If the scale factor
    ``largest_size / max(H, W)`` of 'edge' is not below 1, the clones are
    returned unchanged. Otherwise:
      * mask/context are resized together (bilinear); mask keeps any positive
        coverage, and context keeps pixels that stay exactly 1 and are not mask;
      * 'edge' (bilinear, binarized), 'disp' (nearest) and 'rgb' (bilinear)
        are resized, then multiplied by the new context.
    """
    out = {key: tensor.clone() for key, tensor in tensor_dict.items()}
    frac = largest_size / np.array([*out['edge'].shape[-2:]]).max()
    if not frac < 1:
        return out
    interp = torch.nn.functional.interpolate
    marks = interp(torch.cat((out['mask'], out['context']), dim=1), scale_factor=frac, mode='bilinear')
    mask = (marks[:, 0:1] > 0).float()
    context = (marks[:, 1:2] == 1).float()
    context[mask > 0] = 0
    edge = (interp(out['edge'], scale_factor=frac, mode='bilinear') > 0).float() * context
    disp = interp(out['disp'], scale_factor=frac, mode='nearest') * context
    rgb = interp(out['rgb'], scale_factor=frac, mode='bilinear') * context
    out.update(mask=mask, context=context, edge=edge, disp=disp, rgb=rgb)
    return out
|
|
|
|
def get_map_from_nodes(nodes, height, width):
    """Rasterize node positions into a (height, width) float map of 0s and 1s.

    Only ``node[0]`` (row) and ``node[1]`` (column) are read.
    """
    binary_map = np.zeros((height, width))
    for node in nodes:
        row, col = node[0], node[1]
        binary_map[row, col] = 1

    return binary_map
|
|
|
|
def get_map_from_ccs(ccs, height, width, condition_input=None, condition=None, real_id=False, id_shift=0):
    """Rasterize connected components into a (height, width) map.

    Args:
        ccs: iterable of node collections; ``node[0]``/``node[1]`` index the map.
        condition: optional predicate ``condition(node, condition_input)``.
            Only nodes for which it is truthy are drawn. Default: draw all.
        real_id: if True, draw ``cc_index + id_shift`` and fill the background
            with ``id_shift - 1``. Otherwise draw 1 on a zero background.
        id_shift: offset applied to component ids when ``real_id`` is True.
    """
    if condition is None:
        def condition(_node, _condition_input):
            return True

    use_ids = real_id is True
    if use_ids:
        label_map = np.zeros((height, width)) + (-1) + id_shift
    else:
        label_map = np.zeros((height, width))
    for cc_id, cc in enumerate(ccs):
        value = cc_id + id_shift if use_ids else 1
        for node in cc:
            if condition(node, condition_input):
                label_map[node[0], node[1]] = value
    return label_map
|
|
|
|
def revise_map_by_nodes(nodes, imap, operation, limit_constr=None):
    """Set (``'+'``) or clear (``'-'``) node pixels on a copy of ``imap``, subject to a limit.

    Args:
        nodes: iterable of nodes; ``node[0]``/``node[1]`` index the map.
        imap: map to revise. It is deep-copied and never mutated.
        operation: ``'+'`` writes 1 at each node, ``'-'`` writes 0.
        limit_constr: optional bound on the revised map's sum. For ``'+'`` the
            revision is rejected if the sum exceeds it; for ``'-'`` it is
            rejected if the sum drops below it.

    Returns:
        ``(omap, revise_flag)``: the revised copy and True, or the original
        ``imap`` object itself and False when the limit rejected the revision.

    Raises:
        ValueError: if ``operation`` is neither ``'+'`` nor ``'-'``. A raise is
            used instead of an assert so the check survives ``python -O``.
    """
    if operation not in ('+', '-'):
        raise ValueError("Operation must be '+' (union) or '-' (exclude)")
    omap = copy.deepcopy(imap)
    fill_value = 1 if operation == '+' else 0
    for n in nodes:
        omap[n[0], n[1]] = fill_value
    if limit_constr is not None:
        total = omap.sum()
        rejected = total > limit_constr if operation == '+' else total < limit_constr
        if rejected:
            return imap, False

    return omap, True
|
|
|
|
def repaint_info(mesh, cc, x_anchor, y_anchor, source_type):
    """Paint node colours or depths into a channel-first patch.

    The patch covers rows ``[x_anchor[0], x_anchor[1])`` and columns
    ``[y_anchor[0], y_anchor[1])``.
      * ``'rgb'``: a 3-channel patch of ``mesh.nodes[node]['color'] / 255``.
      * ``'d'``: a 1-channel patch of ``abs(node[2])``.
      * any other value: an all-zero 1-channel patch.
    """
    channels = 3 if source_type == 'rgb' else 1
    feat = np.zeros((channels, x_anchor[1] - x_anchor[0], y_anchor[1] - y_anchor[0]))
    for node in cc:
        if source_type == 'rgb':
            value = np.array(mesh.nodes[node]['color']) / 255.
        elif source_type == 'd':
            value = abs(node[2])
        else:
            continue
        feat[:, node[0] - x_anchor[0], node[1] - y_anchor[0]] = value

    return feat
|
|
|
|
def get_context_from_nodes(mesh, cc, H, W, source_type=''):
    """Rasterize node features and a context mask from ``cc``.

    If ``source_type`` contains 'rgb' or 'color', ``feat`` is (H, W, 3) and
    holds ``mesh.nodes[node]['color'] / 255``. Otherwise ``feat`` is (H, W) and
    holds ``abs(node[2])``. ``context`` is 1 at every node position in both
    modes. Previously it was only filled in colour mode, which left depth-mode
    callers with an all-zero context.

    Returns:
        ``(feat, context)``.
    """
    use_color = 'rgb' in source_type or 'color' in source_type
    feat = np.zeros((H, W, 3)) if use_color else np.zeros((H, W))
    context = np.zeros((H, W))
    for node in cc:
        if use_color:
            feat[node[0], node[1]] = np.array(mesh.nodes[node]['color']) / 255.
        else:
            feat[node[0], node[1]] = abs(node[2])
        context[node[0], node[1]] = 1

    return feat, context
|
|
|
|
def get_mask_from_nodes(mesh, cc, H, W):
    """Return an (H, W) map holding ``abs(node[2])`` at each node position of ``cc``.

    Despite the name, the map is not binary: it stores absolute depth values.
    ``mesh`` is not read.
    """
    depth_mask = np.zeros((H, W))
    for node in cc:
        row, col = node[0], node[1]
        depth_mask[row, col] = abs(node[2])

    return depth_mask
|
|
|
|
|
|
def get_edge_from_nodes(context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc, H, W, mesh):
    """Rasterize the node sets around one depth edge into (H, W) maps.

    Context nodes (``context_cc`` then ``erode_context_cc``; later writes win)
    provide:
      * 'rgb': ``mesh.nodes[n]['color'] / 255``, shape (H, W, 3);
      * 'disp': ``abs(mesh.nodes[n]['disp'])`` divided by its maximum. It is
        left unnormalised (all zeros) when the maximum is 0, instead of
        producing NaN from 0/0;
      * 'depth': ``node[2]`` as stored (not abs'd);
      * 'real_depth': 'depth' overridden by the node's 'real_depth' attribute
        where present;
      * 'context': 1 at every context node.
    ``mask_cc`` sets 'mask' to 1 and 'near_depth' to ``node[2]``.
    ``edge_cc`` sets 'self_edge' to 1, and ``extend_edge_cc`` sets 'comp_edge'
    to 1. 'fpath_map'/'npath_map' are all -1 here. 'valid_area' is
    ``context + mask``.

    Returns:
        dict of the maps named above.
    """
    context = np.zeros((H, W))
    mask = np.zeros((H, W))
    rgb = np.zeros((H, W, 3))
    disp = np.zeros((H, W))
    depth = np.zeros((H, W))
    edge = np.zeros((H, W))
    comp_edge = np.zeros((H, W))
    fpath_map = np.zeros((H, W)) - 1
    npath_map = np.zeros((H, W)) - 1
    near_depth = np.zeros((H, W))
    for group in (context_cc, erode_context_cc):
        for node in group:
            attrs = mesh.nodes[node]
            rgb[node[0], node[1]] = np.array(attrs['color'])
            disp[node[0], node[1]] = attrs['disp']
            depth[node[0], node[1]] = node[2]
            context[node[0], node[1]] = 1
    rgb = rgb / 255.
    disp = np.abs(disp)
    disp_max = disp.max()
    if disp_max > 0:
        disp = disp / disp_max
    real_depth = depth.copy()
    for group in (context_cc, erode_context_cc):
        for node in group:
            node_real_depth = mesh.nodes[node].get('real_depth')
            if node_real_depth is not None:
                real_depth[node[0], node[1]] = node_real_depth
    for node in mask_cc:
        mask[node[0], node[1]] = 1
        near_depth[node[0], node[1]] = node[2]
    for node in edge_cc:
        edge[node[0], node[1]] = 1
    for node in extend_edge_cc:
        comp_edge[node[0], node[1]] = 1

    return {'rgb': rgb, 'disp': disp, 'depth': depth, 'real_depth': real_depth, 'self_edge': edge, 'context': context,
            'mask': mask, 'fpath_map': fpath_map, 'npath_map': npath_map, 'comp_edge': comp_edge, 'valid_area': context + mask,
            'near_depth': near_depth}
|
|
|
|
def get_depth_from_maps(context_map, mask_map, depth_map, H, W, log_depth=False):
    """Build the depth-model input dict directly from dense maps.

    Args:
        context_map, mask_map: maps cast to uint8 for 'context' / 'mask'.
        depth_map: depth map; its absolute value is used.
        H, W: unused; kept for call-site compatibility.
        log_depth: if True, 'zero_mean_depth' is ``log(depth + 1e-8)`` minus its
            mean over context pixels, zeroed outside context. 'mean_depth' is
            that mean, which is NaN when the context is empty. Otherwise
            'zero_mean_depth' is the 'real_depth' array itself and 'mean_depth' is 0.

    Returns:
        dict with 'depth', 'real_depth', 'context', 'mask', 'mean_depth',
        'zero_mean_depth' and an all-zero 'edge' map shaped like 'depth'.
    """
    context = context_map.astype(np.uint8)
    mask = mask_map.astype(np.uint8)  # astype already returns a new array
    depth = np.abs(depth_map)
    real_depth = depth.copy()
    if log_depth is True:
        log_map = np.log(real_depth + 1e-8) * context
        mean_depth = np.mean(log_map[context > 0])
        zero_mean_depth = (log_map - mean_depth) * context
    else:
        zero_mean_depth = real_depth
        mean_depth = 0

    return {'depth': depth, 'real_depth': real_depth, 'context': context, 'mask': mask,
            'mean_depth': mean_depth, 'zero_mean_depth': zero_mean_depth, 'edge': np.zeros_like(depth)}
|
|
|
|
def get_depth_from_nodes(context_cc, erode_context_cc, mask_cc, H, W, mesh, log_depth=False):
    """Rasterize context/mask node sets into the depth-model input dict.

    Context nodes (``context_cc`` then ``erode_context_cc``) set 'context' to 1
    and 'depth' to ``abs(node[2])``. 'real_depth' starts as a copy of 'depth'.
    It is then overridden by each node's 'real_depth' attribute where present,
    and abs'd. ``mask_cc`` sets 'mask' to 1.

    With ``log_depth`` True, 'zero_mean_depth' is the log of 'real_depth'
    minus its mean over context pixels (zero outside context), and 'mean_depth'
    is that mean. Otherwise 'zero_mean_depth' is the 'real_depth' array itself
    and 'mean_depth' is 0. No 'edge' key is produced.
    """
    context = np.zeros((H, W))
    mask = np.zeros((H, W))
    depth = np.zeros((H, W))
    for group in (context_cc, erode_context_cc):
        for node in group:
            depth[node[0], node[1]] = node[2]
            context[node[0], node[1]] = 1
    depth = np.abs(depth)

    real_depth = depth.copy()
    for group in (context_cc, erode_context_cc):
        for node in group:
            stored = mesh.nodes[node].get('real_depth')
            if stored is not None:
                real_depth[node[0], node[1]] = stored
    real_depth = np.abs(real_depth)

    for node in mask_cc:
        mask[node[0], node[1]] = 1

    if log_depth is True:
        log_map = np.log(real_depth + 1e-8) * context
        mean_depth = np.mean(log_map[context > 0])
        zero_mean_depth = (log_map - mean_depth) * context
    else:
        zero_mean_depth = real_depth
        mean_depth = 0

    return {'depth': depth, 'real_depth': real_depth, 'context': context, 'mask': mask,
            'mean_depth': mean_depth, 'zero_mean_depth': zero_mean_depth}
|
|
|
|
def get_rgb_from_nodes(context_cc, erode_context_cc, mask_cc, H, W, mesh):
    """Rasterize node sets into the colour-model input dict.

    ``context_cc`` nodes fill 'rgb' with ``mesh.nodes[n]['color'] / 255`` and
    set 'context' to 1. ``mask_cc`` nodes set 'mask' to 1. ``erode_context_cc``
    nodes set both 'erode' and 'mask' to 1, so eroded context is treated as
    part of the region to fill.

    Returns:
        dict with keys 'rgb' (H, W, 3), 'context', 'mask', 'erode' (H, W).
    """
    context = np.zeros((H, W))
    mask = np.zeros((H, W))
    rgb = np.zeros((H, W, 3))
    erode_context = np.zeros((H, W))
    for node in context_cc:
        rgb[node[0], node[1]] = np.array(mesh.nodes[node]['color'])
        context[node[0], node[1]] = 1
    for node in mask_cc:
        mask[node[0], node[1]] = 1
    for node in erode_context_cc:
        erode_context[node[0], node[1]] = 1
        mask[node[0], node[1]] = 1

    return {'rgb': rgb / 255., 'context': context, 'mask': mask,
            'erode': erode_context}
|
|
|
|
def crop_maps_by_size(size, *imaps):
    """Crop each map to rows [x_min, x_max) and columns [y_min, y_max) of ``size``.

    Trailing dimensions (e.g. colour channels) are kept.

    Returns:
        list of independent copies, in input order.
    """
    return [imap[size['x_min']:size['x_max'], size['y_min']:size['y_max']].copy()
            for imap in imaps]
|
|
|
|
def convert2tensor(input_dict):
    """Convert a dict of numpy maps into batched float32 tensors.

    Keys containing 'rgb' or 'color' are treated as (H, W, C) images and become
    (1, C, H, W). All other maps, assumed (H, W), become (1, 1, H, W).
    """
    def _to_tensor(key, value):
        tensor = torch.FloatTensor(value)
        if 'rgb' in key or 'color' in key:
            return tensor.permute(2, 0, 1)[None, ...]
        return tensor[None, None, ...]

    return {key: _to_tensor(key, value) for key, value in input_dict.items()}
|