Improve code comments

This commit is contained in:
Xintao
2021-11-23 00:52:00 +08:00
parent c9023b3d7a
commit 35ee6f781e
20 changed files with 194 additions and 102 deletions

View File

@@ -1,12 +1,11 @@
name: No Response name: No Response
# TODO: it seems not to work
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml # Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
# **What it does**: Closes issues that don't have enough information to be # **What it does**: Closes issues that don't have enough information to be actionable.
# actionable. # **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
# **Why we have it**: To remove the need for maintainers to remember to check # to see if contributors have responded.
# back on issues periodically to see if contributors have
# responded.
# **Who does it impact**: Everyone that works on docs or docs-internal. # **Who does it impact**: Everyone that works on docs or docs-internal.
on: on:

View File

@@ -8,6 +8,8 @@ from realesrgan import RealESRGANer
def main(): def main():
"""Inference demo for Real-ESRGAN.
"""
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, default='inputs', help='Input image or folder') parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument( parser.add_argument(
@@ -53,7 +55,7 @@ def main():
pre_pad=args.pre_pad, pre_pad=args.pre_pad,
half=args.half) half=args.half)
if args.face_enhance: if args.face_enhance: # Use GFPGAN for face enhancement
from gfpgan import GFPGANer from gfpgan import GFPGANer
face_enhancer = GFPGANer( face_enhancer = GFPGANer(
model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth', model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
@@ -78,6 +80,7 @@ def main():
else: else:
img_mode = None img_mode = None
# give warnings for too large/small images
h, w = img.shape[0:2] h, w = img.shape[0:2]
if max(h, w) > 1000 and args.netscale == 4: if max(h, w) > 1000 and args.netscale == 4:
import warnings import warnings

View File

@@ -90,7 +90,6 @@ network_g:
num_block: 23 num_block: 23
num_grow_ch: 32 num_grow_ch: 32
network_d: network_d:
type: UNetDiscriminatorSN type: UNetDiscriminatorSN
num_in_ch: 3 num_in_ch: 3
@@ -169,7 +168,7 @@ train:
# save_img: True # save_img: True
# metrics: # metrics:
# psnr: # metric name, can be arbitrary # psnr: # metric name
# type: calculate_psnr # type: calculate_psnr
# crop_border: 4 # crop_border: 4
# test_y_channel: false # test_y_channel: false

View File

@@ -52,7 +52,6 @@ network_g:
num_block: 23 num_block: 23
num_grow_ch: 32 num_grow_ch: 32
network_d: network_d:
type: UNetDiscriminatorSN type: UNetDiscriminatorSN
num_in_ch: 3 num_in_ch: 3
@@ -131,7 +130,7 @@ train:
# save_img: True # save_img: True
# metrics: # metrics:
# psnr: # metric name, can be arbitrary # psnr: # metric name
# type: calculate_psnr # type: calculate_psnr
# crop_border: 4 # crop_border: 4
# test_y_channel: false # test_y_channel: false

View File

@@ -91,7 +91,6 @@ network_g:
num_grow_ch: 32 num_grow_ch: 32
scale: 2 scale: 2
network_d: network_d:
type: UNetDiscriminatorSN type: UNetDiscriminatorSN
num_in_ch: 3 num_in_ch: 3
@@ -167,7 +166,7 @@ train:
# save_img: True # save_img: True
# metrics: # metrics:
# psnr: # metric name, can be arbitrary # psnr: # metric name
# type: calculate_psnr # type: calculate_psnr
# crop_border: 4 # crop_border: 4
# test_y_channel: false # test_y_channel: false

View File

@@ -90,7 +90,6 @@ network_g:
num_block: 23 num_block: 23
num_grow_ch: 32 num_grow_ch: 32
network_d: network_d:
type: UNetDiscriminatorSN type: UNetDiscriminatorSN
num_in_ch: 3 num_in_ch: 3
@@ -166,7 +165,7 @@ train:
# save_img: True # save_img: True
# metrics: # metrics:
# psnr: # metric name, can be arbitrary # psnr: # metric name
# type: calculate_psnr # type: calculate_psnr
# crop_border: 4 # crop_border: 4
# test_y_channel: false # test_y_channel: false

View File

@@ -125,7 +125,7 @@ train:
# save_img: True # save_img: True
# metrics: # metrics:
# psnr: # metric name, can be arbitrary # psnr: # metric name
# type: calculate_psnr # type: calculate_psnr
# crop_border: 4 # crop_border: 4
# test_y_channel: false # test_y_channel: false

View File

@@ -124,7 +124,7 @@ train:
# save_img: True # save_img: True
# metrics: # metrics:
# psnr: # metric name, can be arbitrary # psnr: # metric name
# type: calculate_psnr # type: calculate_psnr
# crop_border: 4 # crop_border: 4
# test_y_channel: false # test_y_channel: false

View File

@@ -3,4 +3,4 @@ from .archs import *
from .data import * from .data import *
from .models import * from .models import *
from .utils import * from .utils import *
from .version import __gitsha__, __version__ from .version import __version__

View File

@@ -6,15 +6,23 @@ from torch.nn.utils import spectral_norm
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module): class UNetDiscriminatorSN(nn.Module):
"""Defines a U-Net discriminator with spectral normalization (SN)""" """Defines a U-Net discriminator with spectral normalization (SN)
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features. Default: 64.
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
"""
def __init__(self, num_in_ch, num_feat=64, skip_connection=True): def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
super(UNetDiscriminatorSN, self).__init__() super(UNetDiscriminatorSN, self).__init__()
self.skip_connection = skip_connection self.skip_connection = skip_connection
norm = spectral_norm norm = spectral_norm
# the first convolution
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
# downsample
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
@@ -22,14 +30,13 @@ class UNetDiscriminatorSN(nn.Module):
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
# extra convolutions
# extra
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
def forward(self, x): def forward(self, x):
# downsample
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
@@ -52,7 +59,7 @@ class UNetDiscriminatorSN(nn.Module):
if self.skip_connection: if self.skip_connection:
x6 = x6 + x0 x6 = x6 + x0
# extra # extra convolutions
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
out = self.conv9(out) out = self.conv9(out)

View File

@@ -15,18 +15,31 @@ from torch.utils import data as data
@DATASET_REGISTRY.register() @DATASET_REGISTRY.register()
class RealESRGANDataset(data.Dataset): class RealESRGANDataset(data.Dataset):
""" """Dataset used for Real-ESRGAN model:
Dataset used for Real-ESRGAN model. Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It loads gt (Ground-Truth) images, and augments them.
It also generates blur kernels and sinc kernels for generating low-quality images.
Note that the low-quality images are processed in tensors on GPUS for faster processing.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
Please see more options in the codes.
""" """
def __init__(self, opt): def __init__(self, opt):
super(RealESRGANDataset, self).__init__() super(RealESRGANDataset, self).__init__()
self.opt = opt self.opt = opt
# file client (io backend)
self.file_client = None self.file_client = None
self.io_backend_opt = opt['io_backend'] self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt'] self.gt_folder = opt['dataroot_gt']
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb': if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.gt_folder] self.io_backend_opt['db_paths'] = [self.gt_folder]
self.io_backend_opt['client_keys'] = ['gt'] self.io_backend_opt['client_keys'] = ['gt']
@@ -35,18 +48,20 @@ class RealESRGANDataset(data.Dataset):
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin] self.paths = [line.split('.')[0] for line in fin]
else: else:
# disk backend with meta_info
# Each line in the meta_info describes the relative path to an image
with open(self.opt['meta_info']) as fin: with open(self.opt['meta_info']) as fin:
paths = [line.strip() for line in fin] paths = [line.strip().split(' ')[0] for line in fin]
self.paths = [os.path.join(self.gt_folder, v) for v in paths] self.paths = [os.path.join(self.gt_folder, v) for v in paths]
# blur settings for the first degradation # blur settings for the first degradation
self.blur_kernel_size = opt['blur_kernel_size'] self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list'] self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob'] self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
self.blur_sigma = opt['blur_sigma'] self.blur_sigma = opt['blur_sigma']
self.betag_range = opt['betag_range'] self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
self.betap_range = opt['betap_range'] self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
self.sinc_prob = opt['sinc_prob'] self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
# blur settings for the second degradation # blur settings for the second degradation
self.blur_kernel_size2 = opt['blur_kernel_size2'] self.blur_kernel_size2 = opt['blur_kernel_size2']
@@ -61,6 +76,7 @@ class RealESRGANDataset(data.Dataset):
self.final_sinc_prob = opt['final_sinc_prob'] self.final_sinc_prob = opt['final_sinc_prob']
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
# TODO: kernel range is now hard-coded, should be in the configure file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1 self.pulse_tensor[10, 10] = 1
@@ -89,10 +105,11 @@ class RealESRGANDataset(data.Dataset):
retry -= 1 retry -= 1
img_gt = imfrombytes(img_bytes, float32=True) img_gt = imfrombytes(img_bytes, float32=True)
# -------------------- augmentation for training: flip, rotation -------------------- # # -------------------- Do augmentation for training: flip, rotation -------------------- #
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
# crop or pad to 400: 400 is hard-coded. You may change it accordingly # crop or pad to 400
# TODO: 400 is hard-coded. You may change it accordingly
h, w = img_gt.shape[0:2] h, w = img_gt.shape[0:2]
crop_pad_size = 400 crop_pad_size = 400
# pad # pad
@@ -154,7 +171,7 @@ class RealESRGANDataset(data.Dataset):
pad_size = (21 - kernel_size) // 2 pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- sinc kernel ------------------------------------- # # ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.opt['final_sinc_prob']: if np.random.uniform() < self.opt['final_sinc_prob']:
kernel_size = random.choice(self.kernel_range) kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi) omega_c = np.random.uniform(np.pi / 3, np.pi)

View File

@@ -11,8 +11,7 @@ from torchvision.transforms.functional import normalize
class RealESRGANPairedDataset(data.Dataset): class RealESRGANPairedDataset(data.Dataset):
"""Paired image dataset for image restoration. """Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
GT image pairs.
There are three modes: There are three modes:
1. 'lmdb': Use lmdb files. 1. 'lmdb': Use lmdb files.
@@ -28,8 +27,8 @@ class RealESRGANPairedDataset(data.Dataset):
dataroot_lq (str): Data root path for lq. dataroot_lq (str): Data root path for lq.
meta_info (str): Path for meta information file. meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg. io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
template excludes the file extension. Default: '{}'. Default: '{}'.
gt_size (int): Cropped patch size for gt patches. gt_size (int): Cropped patch size for gt patches.
use_hflip (bool): Use horizontal flips. use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h use_rot (bool): Use rotation (use vertical flip and transposing h
@@ -42,25 +41,25 @@ class RealESRGANPairedDataset(data.Dataset):
def __init__(self, opt): def __init__(self, opt):
super(RealESRGANPairedDataset, self).__init__() super(RealESRGANPairedDataset, self).__init__()
self.opt = opt self.opt = opt
# file client (io backend)
self.file_client = None self.file_client = None
self.io_backend_opt = opt['io_backend'] self.io_backend_opt = opt['io_backend']
# mean and std for normalizing the input images
self.mean = opt['mean'] if 'mean' in opt else None self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
if 'filename_tmpl' in opt: self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb': if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt'] self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
# disk backend with meta_info
# Each line in the meta_info describes the relative path to an image
with open(self.opt['meta_info']) as fin: with open(self.opt['meta_info']) as fin:
paths = [line.strip() for line in fin] paths = [line.strip().split(' ')[0] for line in fin]
self.paths = [] self.paths = []
for path in paths: for path in paths:
gt_path, lq_path = path.split(', ') gt_path, lq_path = path.split(', ')
@@ -68,6 +67,9 @@ class RealESRGANPairedDataset(data.Dataset):
lq_path = os.path.join(self.lq_folder, lq_path) lq_path = os.path.join(self.lq_folder, lq_path)
self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
else: else:
# disk backend
# it will scan the whole folder to get meta info
# it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index): def __getitem__(self, index):

View File

@@ -13,35 +13,45 @@ from torch.nn import functional as F
@MODEL_REGISTRY.register() @MODEL_REGISTRY.register()
class RealESRGANModel(SRGANModel): class RealESRGANModel(SRGANModel):
"""RealESRGAN Model""" """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It mainly performs:
1. randomly synthesize LQ images in GPU tensors
2. optimize the networks with GAN training.
"""
def __init__(self, opt): def __init__(self, opt):
super(RealESRGANModel, self).__init__(opt) super(RealESRGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda() self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda() self.usm_sharpener = USMSharp().cuda() # do usm sharpening
self.queue_size = opt.get('queue_size', 180) self.queue_size = opt.get('queue_size', 180)
@torch.no_grad() @torch.no_grad()
def _dequeue_and_enqueue(self): def _dequeue_and_enqueue(self):
# training pair pool """It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize # initialize
b, c, h, w = self.lq.size() b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_lr'): if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, 'queue size should be divisible by batch size' assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size() _, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_ptr = 0 self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # full if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue # do dequeue and enqueue
# shuffle # shuffle
idx = torch.randperm(self.queue_size) idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx] self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx] self.queue_gt = self.queue_gt[idx]
# get # get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone() lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone() gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
# update # update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone() self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone() self.queue_gt[0:b, :, :, :] = self.gt.clone()
@@ -55,6 +65,8 @@ class RealESRGANModel(SRGANModel):
@torch.no_grad() @torch.no_grad()
def feed_data(self, data): def feed_data(self, data):
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
"""
if self.is_train and self.opt.get('high_order_degradation', True): if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis # training data synthesis
self.gt = data['gt'].to(self.device) self.gt = data['gt'].to(self.device)
@@ -79,7 +91,7 @@ class RealESRGANModel(SRGANModel):
scale = 1 scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic']) mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode) out = F.interpolate(out, scale_factor=scale, mode=mode)
# noise # add noise
gray_noise_prob = self.opt['gray_noise_prob'] gray_noise_prob = self.opt['gray_noise_prob']
if np.random.uniform() < self.opt['gaussian_noise_prob']: if np.random.uniform() < self.opt['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt( out = random_add_gaussian_noise_pt(
@@ -93,7 +105,7 @@ class RealESRGANModel(SRGANModel):
rounds=False) rounds=False)
# JPEG compression # JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
out = torch.clamp(out, 0, 1) out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p) out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- # # ----------------------- The second degradation process ----------------------- #
@@ -111,7 +123,7 @@ class RealESRGANModel(SRGANModel):
mode = random.choice(['area', 'bilinear', 'bicubic']) mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate( out = F.interpolate(
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
# noise # add noise
gray_noise_prob = self.opt['gray_noise_prob2'] gray_noise_prob = self.opt['gray_noise_prob2']
if np.random.uniform() < self.opt['gaussian_noise_prob2']: if np.random.uniform() < self.opt['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt( out = random_add_gaussian_noise_pt(
@@ -162,7 +174,9 @@ class RealESRGANModel(SRGANModel):
self._dequeue_and_enqueue() self._dequeue_and_enqueue()
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
self.gt_usm = self.usm_sharpener(self.gt) self.gt_usm = self.usm_sharpener(self.gt)
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
else: else:
# for paired training or validation
self.lq = data['lq'].to(self.device) self.lq = data['lq'].to(self.device)
if 'gt' in data: if 'gt' in data:
self.gt = data['gt'].to(self.device) self.gt = data['gt'].to(self.device)
@@ -175,6 +189,7 @@ class RealESRGANModel(SRGANModel):
self.is_train = True self.is_train = True
def optimize_parameters(self, current_iter): def optimize_parameters(self, current_iter):
# usm sharpening
l1_gt = self.gt_usm l1_gt = self.gt_usm
percep_gt = self.gt_usm percep_gt = self.gt_usm
gan_gt = self.gt_usm gan_gt = self.gt_usm

View File

@@ -12,35 +12,46 @@ from torch.nn import functional as F
@MODEL_REGISTRY.register() @MODEL_REGISTRY.register()
class RealESRNetModel(SRModel): class RealESRNetModel(SRModel):
"""RealESRNet Model""" """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It is trained without GAN losses.
It mainly performs:
1. randomly synthesize LQ images in GPU tensors
2. optimize the networks with pixel-wise losses (no GAN training).
"""
def __init__(self, opt): def __init__(self, opt):
super(RealESRNetModel, self).__init__(opt) super(RealESRNetModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda() self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda() self.usm_sharpener = USMSharp().cuda() # do usm sharpening
self.queue_size = opt.get('queue_size', 180) self.queue_size = opt.get('queue_size', 180)
@torch.no_grad() @torch.no_grad()
def _dequeue_and_enqueue(self): def _dequeue_and_enqueue(self):
# training pair pool """It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize # initialize
b, c, h, w = self.lq.size() b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_lr'): if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, 'queue size should be divisible by batch size' assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size() _, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_ptr = 0 self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # full if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue # do dequeue and enqueue
# shuffle # shuffle
idx = torch.randperm(self.queue_size) idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx] self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx] self.queue_gt = self.queue_gt[idx]
# get # get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone() lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone() gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
# update # update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone() self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone() self.queue_gt[0:b, :, :, :] = self.gt.clone()
@@ -54,10 +65,12 @@ class RealESRNetModel(SRModel):
@torch.no_grad() @torch.no_grad()
def feed_data(self, data): def feed_data(self, data):
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
"""
if self.is_train and self.opt.get('high_order_degradation', True): if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis # training data synthesis
self.gt = data['gt'].to(self.device) self.gt = data['gt'].to(self.device)
# USM the GT images # USM sharpen the GT images
if self.opt['gt_usm'] is True: if self.opt['gt_usm'] is True:
self.gt = self.usm_sharpener(self.gt) self.gt = self.usm_sharpener(self.gt)
@@ -80,7 +93,7 @@ class RealESRNetModel(SRModel):
scale = 1 scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic']) mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode) out = F.interpolate(out, scale_factor=scale, mode=mode)
# noise # add noise
gray_noise_prob = self.opt['gray_noise_prob'] gray_noise_prob = self.opt['gray_noise_prob']
if np.random.uniform() < self.opt['gaussian_noise_prob']: if np.random.uniform() < self.opt['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt( out = random_add_gaussian_noise_pt(
@@ -94,7 +107,7 @@ class RealESRNetModel(SRModel):
rounds=False) rounds=False)
# JPEG compression # JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
out = torch.clamp(out, 0, 1) out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p) out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- # # ----------------------- The second degradation process ----------------------- #
@@ -112,7 +125,7 @@ class RealESRNetModel(SRModel):
mode = random.choice(['area', 'bilinear', 'bicubic']) mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate( out = F.interpolate(
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
# noise # add noise
gray_noise_prob = self.opt['gray_noise_prob2'] gray_noise_prob = self.opt['gray_noise_prob2']
if np.random.uniform() < self.opt['gaussian_noise_prob2']: if np.random.uniform() < self.opt['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt( out = random_add_gaussian_noise_pt(
@@ -160,7 +173,9 @@ class RealESRNetModel(SRModel):
# training pair pool # training pair pool
self._dequeue_and_enqueue() self._dequeue_and_enqueue()
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
else: else:
# for paired training or validation
self.lq = data['lq'].to(self.device) self.lq = data['lq'].to(self.device)
if 'gt' in data: if 'gt' in data:
self.gt = data['gt'].to(self.device) self.gt = data['gt'].to(self.device)

View File

@@ -12,6 +12,19 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class RealESRGANer(): class RealESRGANer():
"""A helper class for upsampling images with RealESRGAN.
Args:
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
input images into tiles, and then process each of them. Finally, they will be merged into one image.
0 denotes for do not use tile. Default: 0.
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
half (bool): Whether to use half precision during inference. Default: False.
"""
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False): def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
self.scale = scale self.scale = scale
@@ -26,10 +39,12 @@ class RealESRGANer():
if model is None: if model is None:
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'): if model_path.startswith('https://'):
model_path = load_file_from_url( model_path = load_file_from_url(
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None) url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
loadnet = torch.load(model_path) loadnet = torch.load(model_path)
# prefer to use params_ema
if 'params_ema' in loadnet: if 'params_ema' in loadnet:
keyname = 'params_ema' keyname = 'params_ema'
else: else:
@@ -41,6 +56,8 @@ class RealESRGANer():
self.model = self.model.half() self.model = self.model.half()
def pre_process(self, img): def pre_process(self, img):
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
"""
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device) self.img = img.unsqueeze(0).to(self.device)
if self.half: if self.half:
@@ -49,7 +66,7 @@ class RealESRGANer():
# pre_pad # pre_pad
if self.pre_pad != 0: if self.pre_pad != 0:
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
# mod pad # mod pad for divisible borders
if self.scale == 2: if self.scale == 2:
self.mod_scale = 2 self.mod_scale = 2
elif self.scale == 1: elif self.scale == 1:
@@ -64,10 +81,14 @@ class RealESRGANer():
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
def process(self): def process(self):
# model inference
self.output = self.model(self.img) self.output = self.model(self.img)
def tile_process(self): def tile_process(self):
"""Modified from: https://github.com/ata4/esrgan-launcher """It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/ata4/esrgan-launcher
""" """
batch, channel, height, width = self.img.shape batch, channel, height, width = self.img.shape
output_height = height * self.scale output_height = height * self.scale
@@ -188,7 +209,7 @@ class RealESRGANer():
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
else: else: # use the cv2 resize for alpha channel
h, w = alpha.shape[0:2] h, w = alpha.shape[0:2]
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
@@ -213,7 +234,9 @@ class RealESRGANer():
def load_file_from_url(url, model_dir=None, progress=True, file_name=None): def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py """Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
""" """
if model_dir is None: if model_dir is None:
hub_dir = get_dir() hub_dir = get_dir()

View File

@@ -14,34 +14,24 @@ def main(args):
opt (dict): Configuration dict. It contains: opt (dict): Configuration dict. It contains:
n_thread (int): Thread number. n_thread (int): Thread number.
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
A higher value means a smaller size and longer compression time. and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
Use 0 for faster CPU decompression. Default: 3, same in cv2.
input_folder (str): Path to the input folder. input_folder (str): Path to the input folder.
save_folder (str): Path to save folder. save_folder (str): Path to save folder.
crop_size (int): Crop size. crop_size (int): Crop size.
step (int): Step for overlapped sliding window. step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is lower thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
than thresh_size will be dropped.
Usage: Usage:
For each folder, run this script. For each folder, run this script.
Typically, there are four folders to be processed for DIV2K dataset. Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
DIV2K_train_HR After process, each sub_folder should have the same number of subimages.
DIV2K_train_LR_bicubic/X2
DIV2K_train_LR_bicubic/X3
DIV2K_train_LR_bicubic/X4
After process, each sub_folder should have the same number of
subimages.
Remember to modify opt configurations according to your settings. Remember to modify opt configurations according to your settings.
""" """
opt = {} opt = {}
opt['n_thread'] = args.n_thread opt['n_thread'] = args.n_thread
opt['compression_level'] = args.compression_level opt['compression_level'] = args.compression_level
# HR images
opt['input_folder'] = args.input opt['input_folder'] = args.input
opt['save_folder'] = args.output opt['save_folder'] = args.output
opt['crop_size'] = args.crop_size opt['crop_size'] = args.crop_size
@@ -68,6 +58,7 @@ def extract_subimages(opt):
print(f'Folder {save_folder} already exists. Exit.') print(f'Folder {save_folder} already exists. Exit.')
sys.exit(1) sys.exit(1)
# scan all images
img_list = list(scandir(input_folder, full_path=True)) img_list = list(scandir(input_folder, full_path=True))
pbar = tqdm(total=len(img_list), unit='image', desc='Extract') pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
@@ -88,8 +79,7 @@ def worker(path, opt):
opt (dict): Configuration dict. It contains: opt (dict): Configuration dict. It contains:
crop_size (int): Crop size. crop_size (int): Crop size.
step (int): Step for overlapped sliding window. step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is lower thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
than thresh_size will be dropped.
save_folder (str): Path to save folder. save_folder (str): Path to save folder.
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

View File

@@ -11,6 +11,7 @@ def main(args):
for img_path in img_paths: for img_path in img_paths:
status = True status = True
if args.check: if args.check:
# read the image once for check, as some images may have errors
try: try:
img = cv2.imread(img_path) img = cv2.imread(img_path)
except Exception as error: except Exception as error:
@@ -20,6 +21,7 @@ def main(args):
status = False status = False
print(f'Img is None: {img_path}') print(f'Img is None: {img_path}')
if status: if status:
# get the relative path
img_name = os.path.relpath(img_path, root) img_name = os.path.relpath(img_path, root)
print(img_name) print(img_name)
txt_file.write(f'{img_name}\n') txt_file.write(f'{img_name}\n')

View File

@@ -5,6 +5,7 @@ import os
def main(args): def main(args):
txt_file = open(args.meta_info, 'w') txt_file = open(args.meta_info, 'w')
# sca images
img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*'))) img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*'))) img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
@@ -12,6 +13,7 @@ def main(args):
f'{len(img_paths_gt)} and {len(img_paths_lq)}.') f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq): for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
# get the relative paths
img_name_gt = os.path.relpath(img_path_gt, args.root[0]) img_name_gt = os.path.relpath(img_path_gt, args.root[0])
img_name_lq = os.path.relpath(img_path_lq, args.root[1]) img_name_lq = os.path.relpath(img_path_lq, args.root[1])
print(f'{img_name_gt}, {img_name_lq}') print(f'{img_name_gt}, {img_name_lq}')
@@ -19,7 +21,7 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
"""Generate meta info (txt file) for paired images. """This script is used to generate meta info (txt file) for paired images.
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(

View File

@@ -5,7 +5,6 @@ from PIL import Image
def main(args): def main(args):
# For DF2K, we consider the following three scales, # For DF2K, we consider the following three scales,
# and the smallest image whose shortest edge is 400 # and the smallest image whose shortest edge is 400
scale_list = [0.75, 0.5, 1 / 3] scale_list = [0.75, 0.5, 1 / 3]
@@ -37,6 +36,9 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
"""Generate multi-scale versions for GT images with LANCZOS resampling.
It is now used for DF2K dataset (DIV2K + Flickr 2K)
"""
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder') parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder') parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')

View File

@@ -1,17 +1,36 @@
import argparse
import torch import torch
import torch.onnx import torch.onnx
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
# An instance of your model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
# set the train mode to false since we will only run the forward pass.
model.train(False)
model.cpu().eval()
# An example input you would normally provide to your model's forward() method def main(args):
x = torch.rand(1, 3, 64, 64) # An instance of the model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
if args.params:
keyname = 'params'
else:
keyname = 'params_ema'
model.load_state_dict(torch.load(args.input)[keyname])
# set the train mode to false since we will only run the forward pass.
model.train(False)
model.cpu().eval()
# Export the model # An example input
with torch.no_grad(): x = torch.rand(1, 3, 64, 64)
torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True) # Export the model
with torch.no_grad():
torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
print(torch_out.shape)
if __name__ == '__main__':
"""Convert pytorch model to onnx models"""
parser = argparse.ArgumentParser()
parser.add_argument(
'--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
parser.add_argument('--params', action='store_false', help='Use params instead of params_ema')
args = parser.parse_args()
main(args)