codeformer/inference_codeformer.py
2022-07-16 22:03:45 +08:00

133 lines
5.6 KiB
Python

import argparse
import glob
import time
from xml.sax import parse
import numpy as np
import os
import cv2
import torch
import torchvision.transforms as transforms
from skimage import io
from basicsr.utils import imwrite, tensor2img
from basicsr.utils.face_util import FaceRestorationHelper
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY
def build_arg_parser():
    """Build the command-line parser for the CodeFormer inference script.

    Returns:
        argparse.ArgumentParser: parser exposing the same options and
        defaults the script has always accepted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--w', type=float, default=0.5, help='Balance the quality and fidelity')
    parser.add_argument('--upscale_factor', type=int, default=2)
    parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
    parser.add_argument('--upsample_num_times', type=int, default=1, help='Upsample the image before face detection')
    # NOTE: --save_inverse_affine is accepted for CLI compatibility but is not
    # read anywhere in this script (warp_crop_faces is always called with
    # save_inverse_affine_path=None).
    parser.add_argument('--save_inverse_affine', action='store_true')
    parser.add_argument('--only_keep_largest', action='store_true')
    parser.add_argument('--draw_box', action='store_true')

    # The following are the paths for dlib models
    parser.add_argument(
        '--detection_path', type=str,
        default='weights/dlib/mmod_human_face_detector-4cb19393.dat'
    )
    parser.add_argument(
        '--landmark5_path', type=str,
        default='weights/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat'
    )
    parser.add_argument(
        '--landmark68_path', type=str,
        default='weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat'
    )
    return parser


def get_result_root(test_path, w):
    """Return the output directory for a given input folder and weight.

    Args:
        test_path: input folder as given on the command line; any number of
            trailing ``/`` characters is ignored when taking the folder name.
        w: fidelity weight, appended to the folder name.

    Returns:
        str: ``results/<input folder name>_<w>``.
    """
    # Strip every trailing slash; otherwise os.path.basename() returns ''.
    folder_name = os.path.basename(test_path.rstrip('/'))
    return f'results/{folder_name}_{w}'


def load_codeformer(device, ckpt_path='weights/CodeFormer/codeformer.pth'):
    """Create the CodeFormer network and load its EMA weights.

    Args:
        device: torch device the network is moved to and the checkpoint is
            mapped onto.
        ckpt_path: checkpoint file; its ``'params_ema'`` entry is loaded.

    Returns:
        The network in eval mode.
    """
    net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                                          connect_list=['32', '64', '128', '256']).to(device)
    # map_location lets a checkpoint saved from a CUDA device load on a
    # CPU-only host instead of failing during deserialization.
    # NOTE(review): torch.load unpickles the file; only use trusted weights.
    checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
    net.load_state_dict(checkpoint)
    net.eval()
    return net


def restore_face(net, face_tensor, w):
    """Run the network on one preprocessed face and return the result image.

    Args:
        net: the restoration network.
        face_tensor: batched, normalized face tensor already on the network's
            device.
        w: fidelity weight forwarded to the network.

    Returns:
        The image produced by ``tensor2img`` for the network output, or for
        the unmodified input when inference raises.
    """
    try:
        with torch.no_grad():
            output = net(face_tensor, w=w, adain=True)[0]
            restored_face = tensor2img(output, min_max=(-1, 1))
        del output
        torch.cuda.empty_cache()
    except Exception as error:  # deliberate best effort: keep the batch running
        print(f'\tFailed inference for CodeFormer: {error}')
        restored_face = tensor2img(face_tensor, min_max=(-1, 1))
    return restored_face


def main():
    """Restore the faces of every matching image under ``--test_path``.

    Outputs are written below ``results/<input folder>_<w>``: cropped faces,
    restored faces and, for non-aligned inputs, the final pasted images.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = build_arg_parser().parse_args()

    w = args.w
    result_root = get_result_root(args.test_path, w)

    # set up the Network
    net = load_codeformer(device)

    save_crop_root = os.path.join(result_root, 'cropped_faces')
    save_restore_root = os.path.join(result_root, 'restored_faces')
    save_final_root = os.path.join(result_root, 'final_results')

    face_helper = FaceRestorationHelper(args.upscale_factor, face_size=512)
    face_helper.init_dlib(args.detection_path, args.landmark5_path, args.landmark68_path)

    # Built once and reused for every face. Normalize with mean 0.5 / std 0.5
    # computes (x - 0.5) / 0.5, matching the min_max=(-1, 1) given to tensor2img.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # Scan the input images. The pattern matches three-letter extensions such
    # as .jpg and .png; it does not match .jpeg.
    for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
        img_name = os.path.basename(img_path)
        print(f'Processing: {img_name}')

        if args.has_aligned:
            # the input faces are already cropped and aligned
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None:
                # cv2.imread signals an unreadable file by returning None;
                # skip it rather than crash the whole batch in cvtColor.
                print(f'\tFailed to read image: {img_path}, skipped.')
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.cropped_faces = [img]
            cropped_faces = face_helper.cropped_faces
        else:
            # detect faces
            num_det_faces = face_helper.detect_faces(
                img_path, upsample_num_times=args.upsample_num_times, only_keep_largest=args.only_keep_largest)
            # get 5 face landmarks for each face
            num_landmarks = face_helper.get_face_landmarks_5()
            print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
            # warp and crop each face
            save_crop_path = os.path.join(save_crop_root, img_name)
            face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path=None)
            cropped_faces = face_helper.cropped_faces
            # TODO: 68-point landmark detection and freeing dlib GPU memory
            # are currently not performed here.

        # face restoration for each cropped face
        for idx, cropped_face in enumerate(cropped_faces):
            # prepare data
            face_tensor = normalize(to_tensor(cropped_face)).unsqueeze(0).to(device)
            restored_face = restore_face(net, face_tensor, w)

            path = os.path.splitext(os.path.join(save_restore_root, img_name))[0]
            if not args.has_aligned:
                # several faces may come from one image, so index the file name
                save_path = f'{path}_{idx:02d}.png'
                face_helper.add_restored_face(restored_face)
            else:
                save_path = f'{path}.png'
            imwrite(restored_face, save_path)

        if not args.has_aligned:
            # paste each restored face to the input image
            face_helper.paste_faces_to_input_image(os.path.join(save_final_root, img_name), draw_box=args.draw_box)

        # clean all the intermediate results to process the next image
        face_helper.clean_all()

    print(f'\nAll results are saved in {result_root}')


if __name__ == '__main__':
    main()