3d-photography-with-image-i.../Inpainting/main.py

259 lines
8.8 KiB
Python

# imports
import numpy as np
import argparse
import glob
import os
from functools import partial
import vispy
import scipy.misc as misc
from tqdm import tqdm
import yaml
import time
import sys
from Inpainting.mesh import write_ply, read_ply, output_3d_photo
from Inpainting.utils import get_MiDaS_samples, read_MiDaS_depth
import torch
import cv2
from skimage.transform import resize
import imageio
import copy
from Inpainting.networks import Inpaint_Color_Net, Inpaint_Depth_Net, Inpaint_Edge_Net
from Inpainting.MiDaS.run import run_depth
from Inpainting.MiDaS.monodepth_net import MonoDepthNet # model to compute depth
import Inpainting.MiDaS.MiDaS_utils as MiDaS_utils
from Inpainting.bilateral_filtering import sparse_bilateral_filtering
import yaml
import subprocess
def _load_config(file_name, argtarget="Inpainting/argument.yml"):
    """Rewrite the YAML argument file for this run and return it as a dict.

    The file at ``argtarget`` is read, every literal occurrence of
    ``offscreen_rendering: True`` is switched to ``False``, the run-specific
    keys (``src_folder``, ``depth_folder``, ``require_midas``, ``specific``)
    are overwritten, and the result is dumped back to the same path before
    being re-read as the effective configuration.

    ``src_folder`` is taken from ``sys.argv[1]``; an ``IndexError`` is raised
    if the process was started without that argument.
    """
    print(f"reading {argtarget} for arguments...")
    with open(argtarget) as f:
        raw = f.read()
    # Same substitution the previous ``sed -i`` shell call performed, done in
    # memory so it does not depend on a shell or on GNU-specific ``sed -i``.
    raw = raw.replace("offscreen_rendering: True", "offscreen_rendering: False")
    # NOTE(review): UnsafeLoader can construct arbitrary Python objects; this
    # is only acceptable while the argument file is trusted, local input.
    list_doc = yaml.load(raw, yaml.UnsafeLoader)
    list_doc["src_folder"] = sys.argv[1]
    list_doc["depth_folder"] = "Output"
    list_doc["require_midas"] = True
    # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"); a plain
    # split(".")[0] would truncate such names to "a".
    # presumably get_MiDaS_samples matches on the image stem - confirm.
    list_doc["specific"] = os.path.splitext(file_name)[0]
    with open(argtarget, "w") as f:
        yaml.dump(list_doc, f)
    # Re-read what was just written so the config reflects the file on disk.
    with open(argtarget, "r") as f:
        return yaml.load(f, yaml.UnsafeLoader)


def _resolve_device(config):
    """Return the torch device spec selected by ``config["gpu_ids"]``.

    A non-negative integer selects that CUDA ordinal (and sets
    ``CUDA_VISIBLE_DEVICES`` to "0", as before); anything else falls back to
    "cpu". The former fallback value "0" is not a valid torch device string.
    """
    gpu_id = config["gpu_ids"]
    if isinstance(gpu_id, int) and gpu_id >= 0:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        return gpu_id
    return "cpu"


def _load_inpainting_models(config, device):
    """Build the edge, depth and color inpainting nets in eval mode on ``device``.

    Returns ``(rgb_model, depth_edge_model, depth_feat_model)``. Checkpoint
    paths come from ``config``; all three are mapped to the same device.
    """
    map_location = torch.device(device)

    print(f"Loading edge model at {time.time()}")
    depth_edge_model = Inpaint_Edge_Net(init_weights=True)
    depth_edge_model.load_state_dict(
        torch.load(config["depth_edge_model_ckpt"], map_location=map_location)
    )
    depth_edge_model = depth_edge_model.to(device)
    depth_edge_model.eval()

    print(f"Loading depth model at {time.time()}")
    depth_feat_model = Inpaint_Depth_Net()
    depth_feat_model.load_state_dict(
        torch.load(config["depth_feat_model_ckpt"], map_location=map_location),
        strict=True,
    )
    depth_feat_model = depth_feat_model.to(device)
    depth_feat_model.eval()

    print(f"Loading rgb model at {time.time()}")
    rgb_model = Inpaint_Color_Net()
    rgb_model.load_state_dict(
        torch.load(config["rgb_feat_model_ckpt"], map_location=map_location)
    )
    rgb_model.eval()
    rgb_model = rgb_model.to(device)

    return rgb_model, depth_edge_model, depth_feat_model


def inpaint(file_name):
    """Run the depth + inpainting + rendering pipeline for one source image.

    Args:
        file_name: Name of the image file; its stem (extension removed) is
            stored as ``config["specific"]`` and passed to
            ``get_MiDaS_samples`` to select the sample(s) to process.

    Side effects:
        Rewrites ``Inpainting/argument.yml``, reads the source folder from
        ``sys.argv[1]``, creates the mesh/video/depth folders named in the
        config, and calls ``run_depth``, ``write_ply`` and
        ``output_3d_photo`` for every selected sample.

    Returns:
        None.
    """
    config = _load_config(file_name)
    if config["offscreen_rendering"] is True:
        vispy.use(app="egl")

    # create the output directories
    for folder_key in ("mesh_folder", "video_folder", "depth_folder"):
        os.makedirs(config[folder_key], exist_ok=True)

    sample_list = get_MiDaS_samples(
        config["src_folder"], config["depth_folder"], config, config["specific"]
    )
    # Canvases are threaded through successive output_3d_photo calls.
    normal_canvas, all_canvas = None, None

    device = _resolve_device(config)
    print(f"using {len(sample_list)} samples on device {device}")

    # iterate over each image.
    for sample in tqdm(sample_list):
        print("Current Source ==> ", sample["src_pair_name"])
        mesh_fi = os.path.join(config["mesh_folder"], sample["src_pair_name"] + ".ply")
        image = imageio.imread(sample["ref_img_fi"])

        print(f"Running depth extraction at {time.time()}")
        if config["require_midas"] is True:
            run_depth(
                [sample["ref_img_fi"]],
                config["src_folder"],
                config["depth_folder"],
                config["MiDaS_model_ckpt"],
                MonoDepthNet,
                MiDaS_utils,
                target_w=1280,
            )

        # Output size: depth-map size rescaled so the longer side equals
        # config["longer_side_len"].
        if "npy" in config["depth_format"]:
            depth_h, depth_w = np.load(sample["depth_fi"]).shape[:2]
        else:
            depth_h, depth_w = imageio.imread(sample["depth_fi"]).shape[:2]
        frac = config["longer_side_len"] / max(depth_h, depth_w)
        config["output_h"] = int(depth_h * frac)
        config["output_w"] = int(depth_w * frac)
        config["original_h"] = config["output_h"]
        config["original_w"] = config["output_w"]

        # Expand single-channel images to three identical channels.
        if image.ndim == 2:
            image = image[..., None].repeat(3, -1)
        # An image whose first three channels are identical is flagged as gray.
        config["gray_image"] = bool(
            np.sum(np.abs(image[..., 0] - image[..., 1])) == 0
            and np.sum(np.abs(image[..., 1] - image[..., 2])) == 0
        )
        image = cv2.resize(
            image,
            (config["output_w"], config["output_h"]),
            interpolation=cv2.INTER_AREA,
        )
        # NOTE(review): the meaning of the 3.0 argument is defined by
        # read_MiDaS_depth, which is not visible here.
        depth = read_MiDaS_depth(
            sample["depth_fi"], 3.0, config["output_h"], config["output_w"]
        )
        # Depth value at the image centre, forwarded to output_3d_photo.
        mean_loc_depth = depth[depth.shape[0] // 2, depth.shape[1] // 2]

        starty = time.time()
        # Build the mesh unless an existing .ply is to be reused.
        if not (config["load_ply"] is True and os.path.exists(mesh_fi)):
            _, vis_depths = sparse_bilateral_filtering(
                depth.copy(),
                image.copy(),
                config,
                num_iter=config["sparse_iter"],
                spdb=False,
            )
            depth = vis_depths[-1]
            torch.cuda.empty_cache()

            print("Start Running 3D_Photo ...")
            rgb_model, depth_edge_model, depth_feat_model = _load_inpainting_models(
                config, device
            )
            print(
                f"Writing depth ply (and basically doing everything) at {time.time()}"
            )
            # The edge model is intentionally passed twice, matching the
            # original call's argument list.
            rt_info = write_ply(
                image,
                depth,
                sample["int_mtx"],
                mesh_fi,
                config,
                rgb_model,
                depth_edge_model,
                depth_edge_model,
                depth_feat_model,
            )
            if rt_info is False:
                continue
            # Drop the model references before clearing the CUDA cache.
            del rgb_model, depth_edge_model, depth_feat_model
            torch.cuda.empty_cache()
        print(f"Total Time taken: {time.time()-starty}")

        if config["save_ply"] is True or config["load_ply"] is True:
            verts, colors, faces, Height, Width, hFov, vFov = read_ply(mesh_fi)
        else:
            verts, colors, faces, Height, Width, hFov, vFov = rt_info

        startx = time.time()
        print(f"Making video at {time.time()}")
        videos_poses = copy.deepcopy(sample["tgts_poses"])
        video_basename = sample["tgt_name"]

        # Crop border derived from the intrinsics' principal-point entries.
        top = (
            config["original_h"] // 2 - sample["int_mtx"][1, 2] * config["output_h"]
        )
        left = (
            config["original_w"] // 2 - sample["int_mtx"][0, 2] * config["output_w"]
        )
        down, right = top + config["output_h"], left + config["output_w"]
        border = [int(xx) for xx in [top, down, left, right]]

        normal_canvas, all_canvas = output_3d_photo(
            verts.copy(),
            colors.copy(),
            faces.copy(),
            copy.deepcopy(Height),
            copy.deepcopy(Width),
            copy.deepcopy(hFov),
            copy.deepcopy(vFov),
            copy.deepcopy(sample["tgt_pose"]),
            sample["video_postfix"],
            copy.deepcopy(sample["ref_pose"]),
            copy.deepcopy(config["video_folder"]),
            image.copy(),
            copy.deepcopy(sample["int_mtx"]),
            config,
            image,
            videos_poses,
            video_basename,
            config["original_h"],
            config["original_w"],
            border=border,
            depth=depth,
            normal_canvas=normal_canvas,
            all_canvas=all_canvas,
            mean_loc_depth=mean_loc_depth,
        )
        print(f"Total Time taken: {time.time()-startx}")