# Source listing metadata: 502 lines, 22 KiB, Python.
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseNetwork(nn.Module):
    """nn.Module base class that adds configurable weight initialisation."""

    def __init__(self):
        super().__init__()

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        Initialize the weights of every Conv*/Linear and BatchNorm2d submodule.

        init_type: normal | xavier | kaiming | orthogonal
        gain: std for 'normal' (and for BatchNorm2d weights), gain for
            'xavier' and 'orthogonal'; unused by 'kaiming'.
        Conv*/Linear biases and BatchNorm2d biases are set to 0.

        Raises NotImplementedError for any other init_type (as the reference
        implementation does) instead of silently keeping PyTorch's defaults.

        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''
        # Validate up front so a typo fails even if no Conv/Linear layer is visited.
        if init_type not in ('normal', 'xavier', 'kaiming', 'orthogonal'):
            raise NotImplementedError(
                'initialization method [%s] is not implemented' % init_type)

        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                else:  # 'orthogonal' (the only remaining validated option)
                    nn.init.orthogonal_(m.weight.data, gain=gain)

                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

            elif classname.find('BatchNorm2d') != -1:
                # BatchNorm2d(affine=False) has weight/bias == None: nothing to initialise.
                if m.weight is not None:
                    nn.init.normal_(m.weight.data, 1.0, gain)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)


def weights_init(init_type='gaussian'):
    """Build an initialiser to pass to ``nn.Module.apply``.

    The returned function re-initialises modules whose class name starts with
    ``Conv`` or ``Linear``:

    - 'gaussian':   weight ~ N(0, 0.02)
    - 'xavier':     xavier_normal_ with gain sqrt(2)
    - 'kaiming':    kaiming_normal_ (a=0, fan_in)
    - 'orthogonal': orthogonal_ with gain sqrt(2)
    - 'default':    weight left untouched

    In every supported case the bias (if any) is zeroed. For an unsupported
    ``init_type`` the returned function raises AssertionError on the first
    matching module.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find(
                'Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                # Explicit raise instead of `assert 0` so the check survives `python -O`;
                # the exception type is unchanged for existing callers.
                raise AssertionError("Unsupported initialization: {}".format(init_type))
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun


class PartialConv(nn.Module):
    """Partial convolution: convolve only valid pixels and re-normalise by coverage.

    ``input_conv`` holds the learnable weights. ``mask_conv`` is a frozen
    all-ones convolution with the same geometry; applied to the mask it counts
    how many valid mask entries fall inside each sliding window.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)
        self.input_conv.apply(weights_init('kaiming'))
        # Window size seen by one output unit = (in_channels / groups) * kh * kw,
        # read from the weight shape so tuple kernel sizes and groups > 1 are
        # handled. For an int kernel and groups == 1 this equals
        # in_channels * kernel_size * kernel_size, as before.
        self.slide_winsize = self.mask_conv.weight[0].numel()

        torch.nn.init.constant_(self.mask_conv.weight, 1.0)

        # mask is not updated: mask_conv is a fixed window-sum, never trained.
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, input, mask):
        """Return ``(output, new_mask)``.

        Windows containing no valid mask entries produce 0 in ``output`` and 0
        in ``new_mask``; every other position gets 1 in ``new_mask``.
        """
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T* (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)
        output = self.input_conv(input * mask)
        if self.input_conv.bias is not None:
            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(
                output)
        else:
            output_bias = torch.zeros_like(output)

        with torch.no_grad():
            output_mask = self.mask_conv(mask)

        no_update_holes = output_mask == 0

        # Replace zero coverage by 1 to avoid dividing by zero; those positions
        # are zeroed below anyway.
        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)

        output_pre = ((output - output_bias) * self.slide_winsize) / mask_sum + output_bias
        output = output_pre.masked_fill_(no_update_holes, 0.0)

        new_mask = torch.ones_like(output)
        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)

        return output, new_mask


class PCBActiv(nn.Module):
    """PartialConv followed by optional BatchNorm2d and an optional activation.

    sample selects the partial-conv geometry (kernel, stride, padding):
        'down-7' -> (7, 2, 3), 'down-5' -> (5, 2, 2), 'down-3' -> (3, 2, 1),
        anything else -> (3, 1, 1).
    activ: 'relu' -> ReLU, 'leaky' -> LeakyReLU(0.2), anything else -> none.
    """

    def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='relu',
                 conv_bias=False):
        super().__init__()
        geometries = (
            ('down-5', (5, 2, 2)),
            ('down-7', (7, 2, 3)),
            ('down-3', (3, 2, 1)),
        )
        kernel, stride, pad = next(
            (geom for key, geom in geometries if key == sample), (3, 1, 1))
        self.conv = PartialConv(in_ch, out_ch, kernel, stride, pad, bias=conv_bias)

        if bn:
            self.bn = nn.BatchNorm2d(out_ch)

        if activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)
        elif activ == 'relu':
            self.activation = nn.ReLU()

    def forward(self, input, input_mask):
        """Return ``(features, mask)``; the mask comes straight from the partial conv."""
        feat, feat_mask = self.conv(input, input_mask)
        # The optional stages only exist as attributes when they were requested.
        for stage_name in ('bn', 'activation'):
            if hasattr(self, stage_name):
                feat = getattr(self, stage_name)(feat)
        return feat, feat_mask


class Inpaint_Depth_Net(nn.Module):
    """Partial-convolution U-Net that inpaints a single-channel depth map.

    The network takes a 4-channel input; ``forward_3P`` builds it as
    ``cat(depth, edge, context, mask)``. The union of the last two channels is
    the validity mask fed to every partial convolution.
    """

    def __init__(self, layer_size=7, upsampling_mode='nearest'):
        super().__init__()
        in_channels = 4
        out_channels = 1
        self.freeze_enc_bn = False  # NOTE(review): not read by this class
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7', conv_bias=True)
        self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True)
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512, 512, sample='down-3'))

        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky'))
        self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1 = PCBActiv(64 + in_channels, out_channels,
                              bn=False, activ=None, conv_bias=True)

    def add_border(self, input, mask_flag, PCONV=True):
        """Zero-pad H and W up to multiples of ``2 ** layer_size``, centring ``input``.

        With ``mask_flag`` and ``PCONV is False`` the padded border is filled
        with 1 instead of 0. Runs under ``no_grad``, so no gradient flows back
        to ``input`` through the padded tensor.

        Returns ``(padded, [top, bottom, left, right])``, the crop box of the
        original content.
        """
        with torch.no_grad():
            h = input.shape[-2]
            w = input.shape[-1]
            require_len_unit = 2 ** self.layer_size
            residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h)
            residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w)
            enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
            if mask_flag:
                if PCONV is False:
                    enlarge_input += 1.0
                    enlarge_input = enlarge_input.clamp(0.0, 1.0)
                else:
                    # Canvas is already zero here; kept for parity with the original code.
                    enlarge_input[:, 2, ...] = 0.0
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input

        return enlarge_input, [anchor_h, anchor_h + h, anchor_w, anchor_w + w]

    def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None):
        """Inference helper: pad to multiples of ``unit_length``, run, crop back.

        ``cuda`` is the device for the padded canvas; when it is None the
        input's own device is used, so the canvas always lives next to the data.
        """
        with torch.no_grad():
            input = torch.cat((depth, edge, context, mask), dim=1)
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            device = input.device if cuda is None else cuda
            # Allocate directly on the target device instead of CPU-then-copy.
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
            depth_output = self.forward(enlarge_input)
            depth_output = depth_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

        return depth_output

    def forward(self, input_feat, refine_border=False, sample=False, PCONV=True):
        """Run the U-Net; ``sample`` is accepted for compatibility and unused."""
        input = input_feat
        # Validity mask: union of the last two channels, repeated over all channels.
        input_mask = (input_feat[:, -2:-1] + input_feat[:, -1:]).clamp(0, 1).repeat(1, input.shape[1], 1, 1)

        if refine_border is True:
            input, anchor = self.add_border(input, mask_flag=False)
            input_mask, anchor = self.add_border(input_mask, mask_flag=True, PCONV=PCONV)

        h_dict = {}  # features after enc_N, keyed 'h_N' ('h_0' is the input)
        h_mask_dict = {}  # masks after enc_N, same keys
        h_dict['h_0'], h_mask_dict['h_0'] = input, input_mask

        h_key_prev = 'h_0'
        for i in range(1, self.layer_size + 1):
            l_key = 'enc_{:d}'.format(i)
            h_key = 'h_{:d}'.format(i)
            h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)(
                h_dict[h_key_prev], h_mask_dict[h_key_prev])
            h_key_prev = h_key

        h_key = 'h_{:d}'.format(self.layer_size)
        h, h_mask = h_dict[h_key], h_mask_dict[h_key]

        for i in range(self.layer_size, 0, -1):
            enc_h_key = 'h_{:d}'.format(i - 1)
            dec_l_key = 'dec_{:d}'.format(i)

            h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)
            h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')

            # Skip connection from the matching encoder level.
            h = torch.cat([h, h_dict[enc_h_key]], dim=1)
            h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim=1)
            h, h_mask = getattr(self, dec_l_key)(h, h_mask)

        output = h
        if refine_border is True:
            output = output[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]

        return output


class Inpaint_Edge_Net(BaseNetwork):
    """Encoder / residual middle / decoder network predicting a 1-channel edge map.

    Takes a 7-channel input (``forward_3P`` builds it as
    ``cat(rgb, disp / disp.max(), edge, context, mask)``). The encoder
    downsamples by 4 with two stride-2 convs, and the decoder mirrors it with
    skip connections. The output passes through a sigmoid.
    """

    def __init__(self, residual_blocks=8, init_weights=True):
        super(Inpaint_Edge_Net, self).__init__()
        in_channels = 7
        out_channels = 1
        self.encoder = []  # NOTE(review): unused; kept so the attribute still exists
        # 0
        self.encoder_0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), True),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True))
        # 1
        self.encoder_1 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True))
        # 2
        self.encoder_2 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True))
        # 3
        blocks = []
        for _ in range(residual_blocks):
            block = ResnetBlock(256, 2)
            blocks.append(block)

        self.middle = nn.Sequential(*blocks)
        # + 3
        self.decoder_0 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=256+256, out_channels=128, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True))
        # + 2
        self.decoder_1 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=128+128, out_channels=64, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True))
        # + 1
        self.decoder_2 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64+64, out_channels=out_channels, kernel_size=7, padding=0),
        )

        if init_weights:
            self.init_weights()

    def add_border(self, input, channel_pad_1=None):
        """Zero-pad H and W up to multiples of 16, centring ``input``.

        Channels listed in ``channel_pad_1`` are filled with 1 in the border.
        Returns ``(padded, [top, bottom, left, right])``, the crop box of the
        original content.
        """
        h = input.shape[-2]
        w = input.shape[-1]
        require_len_unit = 16
        residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h)
        residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w)
        enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
        if channel_pad_1 is not None:
            for channel in channel_pad_1:
                enlarge_input[:, channel] = 1
        anchor_h = residual_h // 2
        anchor_w = residual_w // 2
        enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input

        return enlarge_input, [anchor_h, anchor_h + h, anchor_w, anchor_w + w]

    def forward_3P(self, mask, context, rgb, disp, edge, unit_length=128, cuda=None):
        """Inference helper: pad to multiples of ``unit_length``, run, crop back.

        ``cuda`` is the device for the padded canvas; when it is None the
        input's own device is used.
        """
        with torch.no_grad():
            # NOTE(review): disp.max() == 0 would produce NaNs here; confirm callers never pass that.
            input = torch.cat((rgb, disp/disp.max(), edge, context, mask), dim=1)
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            device = input.device if cuda is None else cuda
            # Allocate directly on the target device instead of CPU-then-copy.
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
            edge_output = self.forward(enlarge_input)
            edge_output = edge_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

        return edge_output

    def forward(self, x, refine_border=False):
        """Return the sigmoid edge map; with ``refine_border`` pad first, crop after."""
        if refine_border:
            # Border of channel 5 (the context channel in forward_3P's layout) is set to 1.
            x, anchor = self.add_border(x, [5])
        x1 = self.encoder_0(x)
        x2 = self.encoder_1(x1)
        x3 = self.encoder_2(x2)
        x4 = self.middle(x3)
        x5 = self.decoder_0(torch.cat((x4, x3), dim=1))
        x6 = self.decoder_1(torch.cat((x5, x2), dim=1))
        x7 = self.decoder_2(torch.cat((x6, x1), dim=1))
        x = torch.sigmoid(x7)
        if refine_border:
            x = x[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]

        return x


class Inpaint_Color_Net(nn.Module):
    """Partial-convolution U-Net (7 levels) that inpaints a 3-channel color image.

    Takes a 6-channel input; ``forward_3P`` builds it as
    ``cat(rgb, edge, context, mask)``. The output passes through a sigmoid.
    ``add_hole_mask``, ``add_two_layer`` and ``add_border`` are accepted for
    compatibility and are unused.
    """

    def __init__(self, layer_size=7, upsampling_mode='nearest', add_hole_mask=False, add_two_layer=False, add_border=False):
        super().__init__()
        # When True, train() keeps encoder BatchNorm layers in eval mode.
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        in_channels = 6
        self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7')
        self.enc_2 = PCBActiv(64, 128, sample='down-5')
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        self.enc_5 = PCBActiv(512, 512, sample='down-3')
        self.enc_6 = PCBActiv(512, 512, sample='down-3')
        self.enc_7 = PCBActiv(512, 512, sample='down-3')

        self.dec_7 = PCBActiv(512+512, 512, activ='leaky')
        self.dec_6 = PCBActiv(512+512, 512, activ='leaky')

        self.dec_5A = PCBActiv(512 + 512, 512, activ='leaky')
        self.dec_4A = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3A = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2A = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1A = PCBActiv(64 + in_channels, 3, bn=False, activ=None, conv_bias=True)

    def cat(self, A, B):
        """Concatenate two tensors along the channel dimension."""
        return torch.cat((A, B), dim=1)

    def upsample(self, feat, mask):
        """Upsample features (``upsampling_mode``) and mask (nearest) by 2."""
        feat = F.interpolate(feat, scale_factor=2, mode=self.upsampling_mode)
        mask = F.interpolate(mask, scale_factor=2, mode='nearest')

        return feat, mask

    def forward_3P(self, mask, context, rgb, edge, unit_length=128, cuda=None):
        """Inference helper: pad to multiples of ``unit_length``, run, crop back.

        ``cuda`` is the device for the padded canvas; when it is None the
        input's own device is used.
        """
        with torch.no_grad():
            input = torch.cat((rgb, edge, context, mask), dim=1)
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            device = input.device if cuda is None else cuda
            # Allocate directly on the target device (one allocation, no extra copies).
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
            rgb_output = self.forward(enlarge_input)
            rgb_output = rgb_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

        return rgb_output

    def forward(self, input, add_border=False):
        """Run the U-Net and return ``sigmoid`` of the 3-channel output.

        ``add_border`` is accepted for compatibility and unused.
        """
        # Validity mask: union of the last two channels, repeated over all channels.
        input_mask = (input[:, -2:-1] + input[:, -1:]).clamp(0, 1)
        f_0, h_0 = input, input_mask.repeat((1, input.shape[1], 1, 1))

        # Encoder: seven stride-2 partial-conv stages.
        f_1, h_1 = self.enc_1(f_0, h_0)
        f_2, h_2 = self.enc_2(f_1, h_1)
        f_3, h_3 = self.enc_3(f_2, h_2)
        f_4, h_4 = self.enc_4(f_3, h_3)
        f_5, h_5 = self.enc_5(f_4, h_4)
        f_6, h_6 = self.enc_6(f_5, h_5)
        f_7, h_7 = self.enc_7(f_6, h_6)

        # Decoder: upsample, concatenate the matching encoder output, partial conv.
        o_7, k_7 = self.upsample(f_7, h_7)
        o_6, k_6 = self.dec_7(self.cat(o_7, f_6), self.cat(k_7, h_6))
        o_6, k_6 = self.upsample(o_6, k_6)
        o_5, k_5 = self.dec_6(self.cat(o_6, f_5), self.cat(k_6, h_5))
        o_5, k_5 = self.upsample(o_5, k_5)
        o_4A, k_4A = self.dec_5A(self.cat(o_5, f_4), self.cat(k_5, h_4))
        o_4A, k_4A = self.upsample(o_4A, k_4A)
        o_3A, k_3A = self.dec_4A(self.cat(o_4A, f_3), self.cat(k_4A, h_3))
        o_3A, k_3A = self.upsample(o_3A, k_3A)
        o_2A, k_2A = self.dec_3A(self.cat(o_3A, f_2), self.cat(k_3A, h_2))
        o_2A, k_2A = self.upsample(o_2A, k_2A)
        o_1A, k_1A = self.dec_2A(self.cat(o_2A, f_1), self.cat(k_2A, h_1))
        o_1A, k_1A = self.upsample(o_1A, k_1A)
        o_0A, _ = self.dec_1A(self.cat(o_1A, f_0), self.cat(k_1A, h_0))

        return torch.sigmoid(o_0A)

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        (encoder BatchNorm2d layers stay in eval mode when freeze_enc_bn is set).
        """
        super().train(mode)
        if self.freeze_enc_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and 'enc' in name:
                    module.eval()


class Discriminator(BaseNetwork):
    """Five-layer convolutional discriminator (4x4 kernels) with optional spectral norm.

    ``forward`` returns ``(outputs, [conv1, ..., conv5])``, where ``outputs`` is
    the last layer's map, passed through a sigmoid when ``use_sigmoid`` is set.
    ``features`` is an alias of ``conv1``, so both names appear in the state_dict.
    """

    def __init__(self, use_sigmoid=True, use_spectral_norm=True, init_weights=True, in_channels=None):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid

        def make_stage(c_in, c_out, stride, with_activation=True):
            # Bias is only used when spectral norm is disabled.
            conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4,
                             stride=stride, padding=1, bias=not use_spectral_norm)
            layers = [spectral_norm(conv, use_spectral_norm)]
            if with_activation:
                layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        self.conv1 = make_stage(in_channels, 64, 2)
        self.features = self.conv1
        self.conv2 = make_stage(64, 128, 2)
        self.conv3 = make_stage(128, 256, 2)
        self.conv4 = make_stage(256, 512, 1)
        self.conv5 = make_stage(512, 1, 1, with_activation=False)

        if init_weights:
            self.init_weights()

    def forward(self, x):
        stage_outputs = []
        feat = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            feat = stage(feat)
            stage_outputs.append(feat)

        last = stage_outputs[-1]
        outputs = torch.sigmoid(last) if self.use_sigmoid else last
        return outputs, stage_outputs


class ResnetBlock(nn.Module):
    """Residual block: two spectral-normalised 3x3 convs with instance norm.

    The first conv uses ``dilation`` (with matching reflection padding). A
    LeakyReLU(0.2) sits between the two convs, and the identity skip is added
    with no trailing activation.
    """

    def __init__(self, dim, dilation=1):
        super(ResnetBlock, self).__init__()
        layers = [
            nn.ReflectionPad2d(dilation),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3,
                                    padding=0, dilation=dilation, bias=False), True),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.LeakyReLU(negative_slope=0.2),

            nn.ReflectionPad2d(1),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3,
                                    padding=0, dilation=1, bias=False), True),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        # No activation after the addition, following
        # http://torch.ch/blog/2016/02/04/resnets.html
        residual = self.conv_block(x)
        return x + residual


def spectral_norm(module, mode=True):
    """Apply ``nn.utils.spectral_norm`` to ``module`` when ``mode`` is truthy.

    Otherwise ``module`` is returned unchanged.
    """
    return nn.utils.spectral_norm(module) if mode else module