improve codes comments

This commit is contained in:
Xintao
2021-11-23 00:52:00 +08:00
parent c9023b3d7a
commit 35ee6f781e
20 changed files with 194 additions and 102 deletions

View File

@@ -6,15 +6,23 @@ from torch.nn.utils import spectral_norm
@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
"""Defines a U-Net discriminator with spectral normalization (SN)"""
"""Defines a U-Net discriminator with spectral normalization (SN)
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
Arg:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features. Default: 64.
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
"""
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
super(UNetDiscriminatorSN, self).__init__()
self.skip_connection = skip_connection
norm = spectral_norm
# the first convolution
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
# downsample
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
@@ -22,14 +30,13 @@ class UNetDiscriminatorSN(nn.Module):
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
# extra
# extra convolutions
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
def forward(self, x):
# downsample
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
@@ -52,7 +59,7 @@ class UNetDiscriminatorSN(nn.Module):
if self.skip_connection:
x6 = x6 + x0
# extra
# extra convolutions
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
out = self.conv9(out)