# 3d-photography-with-image-i.../Inpainting/networks.py
import math
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
def init_weights(self, init_type='normal', gain=0.02):
'''
initialize network's weights
init_type: normal | xavier | kaiming | orthogonal
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
'''
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
nn.init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
nn.init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
nn.init.orthogonal_(m.weight.data, gain=gain)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
nn.init.normal_(m.weight.data, 1.0, gain)
nn.init.constant_(m.bias.data, 0.0)
self.apply(init_func)
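
# Example (illustrative sketch, not part of the original file): subclasses of
# BaseNetwork, e.g. Inpaint_Edge_Net and Discriminator below, build their layers
# and then call init_weights() to re-initialize Conv/Linear/BatchNorm parameters.
def _example_init_weights():
    net = BaseNetwork()
    net.fc = nn.Linear(8, 4)                  # hypothetical layer for illustration
    net.init_weights(init_type='xavier', gain=0.02)
    return net
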
def weights_init(init_type='gaussian'):
def init_fun(m):
classname = m.__class__.__name__
if (classname.find('Conv') == 0 or classname.find(
'Linear') == 0) and hasattr(m, 'weight'):
if init_type == 'gaussian':
nn.init.normal_(m.weight, 0.0, 0.02)
elif init_type == 'xavier':
nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
elif init_type == 'kaiming':
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
elif init_type == 'orthogonal':
nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
elif init_type == 'default':
pass
else:
assert 0, "Unsupported initialization: {}".format(init_type)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
return init_fun
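
# Example (sketch, not in the original file): weights_init() returns an
# initializer closure meant to be passed to nn.Module.apply, exactly as
# PartialConv.__init__ below does with weights_init('kaiming').
def _example_weights_init():
    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)    # hypothetical layer
    conv.apply(weights_init('gaussian'))                  # re-initialize in place
    return conv
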
class PartialConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super().__init__()
self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, bias)
self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, False)
self.input_conv.apply(weights_init('kaiming'))
self.slide_winsize = in_channels * kernel_size * kernel_size
torch.nn.init.constant_(self.mask_conv.weight, 1.0)
# mask is not updated
for param in self.mask_conv.parameters():
param.requires_grad = False
def forward(self, input, mask):
# http://masc.cs.gmu.edu/wiki/partialconv
# C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
# W^T* (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)
output = self.input_conv(input * mask)
if self.input_conv.bias is not None:
output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(
output)
else:
output_bias = torch.zeros_like(output)
with torch.no_grad():
output_mask = self.mask_conv(mask)
no_update_holes = output_mask == 0
mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)
output_pre = ((output - output_bias) * self.slide_winsize) / mask_sum + output_bias
output = output_pre.masked_fill_(no_update_holes, 0.0)
new_mask = torch.ones_like(output)
new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
return output, new_mask
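
# Example (sketch, not in the original file): PartialConv follows the partial
# convolution of Liu et al. 2018: the input is masked before the convolution, the
# result is re-normalized by the number of valid pixels under each window, and the
# updated mask marks windows that saw at least one valid pixel.
def _example_partial_conv():
    layer = PartialConv(3, 8, kernel_size=3, stride=1, padding=1)
    x = torch.randn(1, 3, 64, 64)
    mask = torch.ones(1, 3, 64, 64)
    mask[..., 16:32, 16:32] = 0.0             # simulate a hole to be inpainted
    out, new_mask = layer(x, mask)
    return out.shape, new_mask.shape          # both torch.Size([1, 8, 64, 64])
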
class PCBActiv(nn.Module):
def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='relu',
conv_bias=False):
super().__init__()
if sample == 'down-5':
self.conv = PartialConv(in_ch, out_ch, 5, 2, 2, bias=conv_bias)
elif sample == 'down-7':
self.conv = PartialConv(in_ch, out_ch, 7, 2, 3, bias=conv_bias)
elif sample == 'down-3':
self.conv = PartialConv(in_ch, out_ch, 3, 2, 1, bias=conv_bias)
else:
self.conv = PartialConv(in_ch, out_ch, 3, 1, 1, bias=conv_bias)
if bn:
self.bn = nn.BatchNorm2d(out_ch)
if activ == 'relu':
self.activation = nn.ReLU()
elif activ == 'leaky':
self.activation = nn.LeakyReLU(negative_slope=0.2)
def forward(self, input, input_mask):
h, h_mask = self.conv(input, input_mask)
if hasattr(self, 'bn'):
h = self.bn(h)
if hasattr(self, 'activation'):
h = self.activation(h)
return h, h_mask
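
# Example (sketch, not in the original file): PCBActiv bundles a PartialConv with
# optional BatchNorm and an activation; the `sample` string picks kernel/stride,
# e.g. 'down-5' uses kernel 5, stride 2, padding 2 and halves the resolution.
def _example_pcbactiv():
    block = PCBActiv(64, 128, sample='down-5')
    feat = torch.randn(2, 64, 128, 128)
    mask = torch.ones(2, 64, 128, 128)
    out, out_mask = block(feat, mask)
    return out.shape                           # torch.Size([2, 128, 64, 64])
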
class Inpaint_Depth_Net(nn.Module):
def __init__(self, layer_size=7, upsampling_mode='nearest'):
super().__init__()
in_channels = 4
out_channels = 1
self.freeze_enc_bn = False
self.upsampling_mode = upsampling_mode
self.layer_size = layer_size
self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7', conv_bias=True)
self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True)
self.enc_3 = PCBActiv(128, 256, sample='down-5')
self.enc_4 = PCBActiv(256, 512, sample='down-3')
for i in range(4, self.layer_size):
name = 'enc_{:d}'.format(i + 1)
setattr(self, name, PCBActiv(512, 512, sample='down-3'))
for i in range(4, self.layer_size):
name = 'dec_{:d}'.format(i + 1)
setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky'))
self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
self.dec_1 = PCBActiv(64 + in_channels, out_channels,
bn=False, activ=None, conv_bias=True)
def add_border(self, input, mask_flag, PCONV=True):
with torch.no_grad():
h = input.shape[-2]
w = input.shape[-1]
require_len_unit = 2 ** self.layer_size
residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit
residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit
enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
if mask_flag:
if PCONV is False:
enlarge_input += 1.0
enlarge_input = enlarge_input.clamp(0.0, 1.0)
else:
enlarge_input[:, 2, ...] = 0.0
anchor_h = residual_h//2
anchor_w = residual_w//2
enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
return enlarge_input, [anchor_h, anchor_h+h, anchor_w, anchor_w+w]
def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None):
with torch.no_grad():
input = torch.cat((depth, edge, context, mask), dim=1)
n, c, h, w = input.shape
residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
anchor_h = residual_h//2
anchor_w = residual_w//2
enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda)
enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
# enlarge_input[:, 3] = 1. - enlarge_input[:, 3]
depth_output = self.forward(enlarge_input)
depth_output = depth_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w]
# import pdb; pdb.set_trace()
return depth_output
def forward(self, input_feat, refine_border=False, sample=False, PCONV=True):
input = input_feat
input_mask = (input_feat[:, -2:-1] + input_feat[:, -1:]).clamp(0, 1).repeat(1, input.shape[1], 1, 1)
H, W = input.shape[-2:]
if refine_border is True:
input, anchor = self.add_border(input, mask_flag=False)
input_mask, anchor = self.add_border(input_mask, mask_flag=True, PCONV=PCONV)
h_dict = {} # for the output of enc_N
h_mask_dict = {} # for the output of enc_N
h_dict['h_0'], h_mask_dict['h_0'] = input, input_mask
h_key_prev = 'h_0'
for i in range(1, self.layer_size + 1):
l_key = 'enc_{:d}'.format(i)
h_key = 'h_{:d}'.format(i)
h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)(
h_dict[h_key_prev], h_mask_dict[h_key_prev])
h_key_prev = h_key
h_key = 'h_{:d}'.format(self.layer_size)
h, h_mask = h_dict[h_key], h_mask_dict[h_key]
for i in range(self.layer_size, 0, -1):
enc_h_key = 'h_{:d}'.format(i - 1)
dec_l_key = 'dec_{:d}'.format(i)
h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)
h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')
h = torch.cat([h, h_dict[enc_h_key]], dim=1)
h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim=1)
h, h_mask = getattr(self, dec_l_key)(h, h_mask)
output = h
if refine_border is True:
h_mask = h_mask[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
output = output[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
return output
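
# Example (sketch, not part of the original file): Inpaint_Depth_Net.forward_3P
# concatenates (depth, edge, context, mask) into a 4-channel stack, zero-pads it so
# both sides are multiples of `unit_length`, runs the partial-conv U-Net, and crops
# the prediction back to the original resolution. The tensors below are random
# placeholders; the real pipeline supplies them from the 3D-photo preprocessing.
def _example_depth_inpainting(device='cpu'):
    net = Inpaint_Depth_Net().to(device)
    depth   = torch.rand(1, 1, 200, 300, device=device)
    edge    = torch.zeros(1, 1, 200, 300, device=device)
    context = torch.ones(1, 1, 200, 300, device=device)
    context[..., 80:120, 100:180] = 0.0        # hypothetical unknown region
    mask    = 1.0 - context                    # 1 where depth must be synthesized
    out = net.forward_3P(mask, context, depth, edge, unit_length=128, cuda=device)
    return out.shape                           # torch.Size([1, 1, 200, 300])
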
class Inpaint_Edge_Net(BaseNetwork):
def __init__(self, residual_blocks=8, init_weights=True):
super(Inpaint_Edge_Net, self).__init__()
in_channels = 7
out_channels = 1
self.encoder = []
# 0
self.encoder_0 = nn.Sequential(
nn.ReflectionPad2d(3),
spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), True),
nn.InstanceNorm2d(64, track_running_stats=False),
nn.ReLU(True))
# 1
self.encoder_1 = nn.Sequential(
spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), True),
nn.InstanceNorm2d(128, track_running_stats=False),
nn.ReLU(True))
# 2
self.encoder_2 = nn.Sequential(
spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), True),
nn.InstanceNorm2d(256, track_running_stats=False),
nn.ReLU(True))
# 3
blocks = []
for _ in range(residual_blocks):
block = ResnetBlock(256, 2)
blocks.append(block)
self.middle = nn.Sequential(*blocks)
# + 3
self.decoder_0 = nn.Sequential(
spectral_norm(nn.ConvTranspose2d(in_channels=256+256, out_channels=128, kernel_size=4, stride=2, padding=1), True),
nn.InstanceNorm2d(128, track_running_stats=False),
nn.ReLU(True))
# + 2
self.decoder_1 = nn.Sequential(
spectral_norm(nn.ConvTranspose2d(in_channels=128+128, out_channels=64, kernel_size=4, stride=2, padding=1), True),
nn.InstanceNorm2d(64, track_running_stats=False),
nn.ReLU(True))
# + 1
self.decoder_2 = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(in_channels=64+64, out_channels=out_channels, kernel_size=7, padding=0),
)
if init_weights:
self.init_weights()
def add_border(self, input, channel_pad_1=None):
h = input.shape[-2]
w = input.shape[-1]
require_len_unit = 16
residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit
residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit
enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
if channel_pad_1 is not None:
for channel in channel_pad_1:
enlarge_input[:, channel] = 1
anchor_h = residual_h//2
anchor_w = residual_w//2
enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
return enlarge_input, [anchor_h, anchor_h+h, anchor_w, anchor_w+w]
def forward_3P(self, mask, context, rgb, disp, edge, unit_length=128, cuda=None):
with torch.no_grad():
input = torch.cat((rgb, disp/disp.max(), edge, context, mask), dim=1)
n, c, h, w = input.shape
residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
anchor_h = residual_h//2
anchor_w = residual_w//2
enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda)
enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
edge_output = self.forward(enlarge_input)
edge_output = edge_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w]
return edge_output
def forward(self, x, refine_border=False):
if refine_border:
x, anchor = self.add_border(x, [5])
x1 = self.encoder_0(x)
x2 = self.encoder_1(x1)
x3 = self.encoder_2(x2)
x4 = self.middle(x3)
x5 = self.decoder_0(torch.cat((x4, x3), dim=1))
x6 = self.decoder_1(torch.cat((x5, x2), dim=1))
x7 = self.decoder_2(torch.cat((x6, x1), dim=1))
x = torch.sigmoid(x7)
if refine_border:
x = x[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
return x
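
# Example (sketch, not part of the original file): the edge network takes a
# 7-channel stack (RGB, disparity normalized by its max, edge, context, mask) and
# predicts a completed edge map in [0, 1]. Random placeholder tensors are used
# here; the context/mask semantics follow the surrounding 3D-photo pipeline.
def _example_edge_inpainting(device='cpu'):
    net = Inpaint_Edge_Net().to(device)
    rgb     = torch.rand(1, 3, 200, 300, device=device)
    disp    = torch.rand(1, 1, 200, 300, device=device) + 1e-3
    edge    = torch.zeros(1, 1, 200, 300, device=device)
    context = torch.ones(1, 1, 200, 300, device=device)
    context[..., 80:120, 100:180] = 0.0        # hypothetical unknown region
    mask    = 1.0 - context
    out = net.forward_3P(mask, context, rgb, disp, edge, unit_length=128, cuda=device)
    return out.shape                           # torch.Size([1, 1, 200, 300])
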
class Inpaint_Color_Net(nn.Module):
def __init__(self, layer_size=7, upsampling_mode='nearest', add_hole_mask=False, add_two_layer=False, add_border=False):
super().__init__()
self.freeze_enc_bn = False
self.upsampling_mode = upsampling_mode
self.layer_size = layer_size
in_channels = 6
self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7')
self.enc_2 = PCBActiv(64, 128, sample='down-5')
self.enc_3 = PCBActiv(128, 256, sample='down-5')
self.enc_4 = PCBActiv(256, 512, sample='down-3')
self.enc_5 = PCBActiv(512, 512, sample='down-3')
self.enc_6 = PCBActiv(512, 512, sample='down-3')
self.enc_7 = PCBActiv(512, 512, sample='down-3')
self.dec_7 = PCBActiv(512+512, 512, activ='leaky')
self.dec_6 = PCBActiv(512+512, 512, activ='leaky')
self.dec_5A = PCBActiv(512 + 512, 512, activ='leaky')
self.dec_4A = PCBActiv(512 + 256, 256, activ='leaky')
self.dec_3A = PCBActiv(256 + 128, 128, activ='leaky')
self.dec_2A = PCBActiv(128 + 64, 64, activ='leaky')
self.dec_1A = PCBActiv(64 + in_channels, 3, bn=False, activ=None, conv_bias=True)
'''
self.dec_5B = PCBActiv(512 + 512, 512, activ='leaky')
self.dec_4B = PCBActiv(512 + 256, 256, activ='leaky')
self.dec_3B = PCBActiv(256 + 128, 128, activ='leaky')
self.dec_2B = PCBActiv(128 + 64, 64, activ='leaky')
self.dec_1B = PCBActiv(64 + 4, 1, bn=False, activ=None, conv_bias=True)
'''
def cat(self, A, B):
return torch.cat((A, B), dim=1)
def upsample(self, feat, mask):
feat = F.interpolate(feat, scale_factor=2, mode=self.upsampling_mode)
mask = F.interpolate(mask, scale_factor=2, mode='nearest')
return feat, mask
def forward_3P(self, mask, context, rgb, edge, unit_length=128, cuda=None):
with torch.no_grad():
input = torch.cat((rgb, edge, context, mask), dim=1)
n, c, h, w = input.shape
residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h) # + 128
residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w) # + 256
anchor_h = residual_h//2
anchor_w = residual_w//2
enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda)
enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
# enlarge_input[:, 3] = 1. - enlarge_input[:, 3]
enlarge_input = enlarge_input.to(cuda)
rgb_output = self.forward(enlarge_input)
rgb_output = rgb_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w]
return rgb_output
def forward(self, input, add_border=False):
input_mask = (input[:, -2:-1] + input[:, -1:]).clamp(0, 1)
H, W = input.shape[-2:]
f_0, h_0 = input, input_mask.repeat((1,input.shape[1],1,1))
f_1, h_1 = self.enc_1(f_0, h_0)
f_2, h_2 = self.enc_2(f_1, h_1)
f_3, h_3 = self.enc_3(f_2, h_2)
f_4, h_4 = self.enc_4(f_3, h_3)
f_5, h_5 = self.enc_5(f_4, h_4)
f_6, h_6 = self.enc_6(f_5, h_5)
f_7, h_7 = self.enc_7(f_6, h_6)
o_7, k_7 = self.upsample(f_7, h_7)
o_6, k_6 = self.dec_7(self.cat(o_7, f_6), self.cat(k_7, h_6))
o_6, k_6 = self.upsample(o_6, k_6)
o_5, k_5 = self.dec_6(self.cat(o_6, f_5), self.cat(k_6, h_5))
o_5, k_5 = self.upsample(o_5, k_5)
o_5A, k_5A = o_5, k_5
o_5B, k_5B = o_5, k_5
###############
o_4A, k_4A = self.dec_5A(self.cat(o_5A, f_4), self.cat(k_5A, h_4))
o_4A, k_4A = self.upsample(o_4A, k_4A)
o_3A, k_3A = self.dec_4A(self.cat(o_4A, f_3), self.cat(k_4A, h_3))
o_3A, k_3A = self.upsample(o_3A, k_3A)
o_2A, k_2A = self.dec_3A(self.cat(o_3A, f_2), self.cat(k_3A, h_2))
o_2A, k_2A = self.upsample(o_2A, k_2A)
o_1A, k_1A = self.dec_2A(self.cat(o_2A, f_1), self.cat(k_2A, h_1))
o_1A, k_1A = self.upsample(o_1A, k_1A)
o_0A, k_0A = self.dec_1A(self.cat(o_1A, f_0), self.cat(k_1A, h_0))
return torch.sigmoid(o_0A)
def train(self, mode=True):
"""
Override the default train() to freeze the BN parameters
"""
super().train(mode)
if self.freeze_enc_bn:
for name, module in self.named_modules():
if isinstance(module, nn.BatchNorm2d) and 'enc' in name:
module.eval()
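
# Example (sketch, not part of the original file): the color network consumes a
# 6-channel stack (RGB, edge, context, mask) and returns an inpainted RGB image in
# [0, 1]; like the other networks, forward_3P pads the input to a multiple of
# `unit_length` before running the 7-level partial-conv encoder/decoder.
def _example_color_inpainting(device='cpu'):
    net = Inpaint_Color_Net().to(device)
    rgb     = torch.rand(1, 3, 200, 300, device=device)
    edge    = torch.zeros(1, 1, 200, 300, device=device)
    context = torch.ones(1, 1, 200, 300, device=device)
    context[..., 80:120, 100:180] = 0.0        # hypothetical unknown region
    mask    = 1.0 - context
    out = net.forward_3P(mask, context, rgb, edge, unit_length=128, cuda=device)
    return out.shape                           # torch.Size([1, 3, 200, 300])
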
class Discriminator(BaseNetwork):
def __init__(self, use_sigmoid=True, use_spectral_norm=True, init_weights=True, in_channels=None):
super(Discriminator, self).__init__()
self.use_sigmoid = use_sigmoid
self.conv1 = self.features = nn.Sequential(
spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
nn.LeakyReLU(0.2, inplace=True),
)
self.conv2 = nn.Sequential(
spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
nn.LeakyReLU(0.2, inplace=True),
)
self.conv3 = nn.Sequential(
spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
nn.LeakyReLU(0.2, inplace=True),
)
self.conv4 = nn.Sequential(
spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
nn.LeakyReLU(0.2, inplace=True),
)
self.conv5 = nn.Sequential(
spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
)
if init_weights:
self.init_weights()
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
conv4 = self.conv4(conv3)
conv5 = self.conv5(conv4)
outputs = conv5
if self.use_sigmoid:
outputs = torch.sigmoid(conv5)
return outputs, [conv1, conv2, conv3, conv4, conv5]
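
# Example (sketch, not in the original file): the PatchGAN-style discriminator
# returns per-patch scores (passed through a sigmoid when use_sigmoid=True) along
# with the intermediate feature maps, which is convenient for a feature-matching
# loss. `in_channels` has no default and must be supplied by the caller.
def _example_discriminator():
    d = Discriminator(use_sigmoid=True, in_channels=3)
    img = torch.rand(1, 3, 256, 256)
    scores, feats = d(img)
    return scores.shape, len(feats)            # torch.Size([1, 1, 30, 30]), 5
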
class ResnetBlock(nn.Module):
def __init__(self, dim, dilation=1):
super(ResnetBlock, self).__init__()
self.conv_block = nn.Sequential(
nn.ReflectionPad2d(dilation),
spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=False), True),
nn.InstanceNorm2d(dim, track_running_stats=False),
nn.LeakyReLU(negative_slope=0.2),
nn.ReflectionPad2d(1),
spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=False), True),
nn.InstanceNorm2d(dim, track_running_stats=False),
)
def forward(self, x):
out = x + self.conv_block(x)
# Remove ReLU at the end of the residual block
# http://torch.ch/blog/2016/02/04/resnets.html
return out
def spectral_norm(module, mode=True):
if mode:
return nn.utils.spectral_norm(module)
return module
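
# Example (sketch, not in the original file): spectral_norm() is a thin switch
# around torch.nn.utils.spectral_norm so that spectral weight normalization can be
# toggled with a single flag, as the Discriminator and Inpaint_Edge_Net do above.
def _example_spectral_norm():
    with_sn    = spectral_norm(nn.Conv2d(3, 8, 3), mode=True)    # wrapped
    without_sn = spectral_norm(nn.Conv2d(3, 8, 3), mode=False)   # returned unchanged
    return with_sn, without_sn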