From 35ee6f781e9a5a80d5f2f1efb9102c9899a81ae1 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 23 Nov 2021 00:52:00 +0800 Subject: [PATCH] improve code comments --- .github/workflows/no-response.yml | 9 ++-- inference_realesrgan.py | 5 ++- options/finetune_realesrgan_x4plus.yml | 3 +- .../finetune_realesrgan_x4plus_pairdata.yml | 3 +- options/train_realesrgan_x2plus.yml | 3 +- options/train_realesrgan_x4plus.yml | 3 +- options/train_realesrnet_x2plus.yml | 2 +- options/train_realesrnet_x4plus.yml | 2 +- realesrgan/__init__.py | 2 +- realesrgan/archs/discriminator_arch.py | 21 ++++++---- realesrgan/data/realesrgan_dataset.py | 39 +++++++++++++----- realesrgan/data/realesrgan_paired_dataset.py | 22 +++++----- realesrgan/models/realesrgan_model.py | 37 ++++++++++++----- realesrgan/models/realesrnet_model.py | 39 ++++++++++++------ realesrgan/utils.py | 31 ++++++++++++-- scripts/extract_subimages.py | 24 ++++------- scripts/generate_meta_info.py | 2 + scripts/generate_meta_info_pairdata.py | 4 +- scripts/generate_multiscale_DF2K.py | 4 +- scripts/pytorch2onnx.py | 41 ++++++++++++++----- 20 files changed, 194 insertions(+), 102 deletions(-)
diff --git a/.github/workflows/no-response.yml b/.github/workflows/no-response.yml index de76ce6..fa702ee 100644 --- a/.github/workflows/no-response.yml +++ b/.github/workflows/no-response.yml @@ -1,12 +1,11 @@ name: No Response +# TODO: it does not seem to work # Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml -# **What it does**: Closes issues that don't have enough information to be -# actionable. -# **Why we have it**: To remove the need for maintainers to remember to check -# back on issues periodically to see if contributors have -# responded. +# **What it does**: Closes issues that don't have enough information to be actionable. +# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically +# to see if contributors have responded. # **Who does it impact**: Everyone that works on docs or docs-internal. on:
diff --git a/inference_realesrgan.py b/inference_realesrgan.py index 644ecf0..819a1b1 100644 --- a/inference_realesrgan.py +++ b/inference_realesrgan.py @@ -8,6 +8,8 @@ from realesrgan import RealESRGANer def main(): + """Inference demo for Real-ESRGAN.
+ """ parser = argparse.ArgumentParser() parser.add_argument('--input', type=str, default='inputs', help='Input image or folder') parser.add_argument( @@ -53,7 +55,7 @@ def main(): pre_pad=args.pre_pad, half=args.half) - if args.face_enhance: + if args.face_enhance: # Use GFPGAN for face enhancement from gfpgan import GFPGANer face_enhancer = GFPGANer( model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth', @@ -78,6 +80,7 @@ def main(): else: img_mode = None + # give warnings for too large/small images h, w = img.shape[0:2] if max(h, w) > 1000 and args.netscale == 4: import warnings diff --git a/options/finetune_realesrgan_x4plus.yml b/options/finetune_realesrgan_x4plus.yml index c4ff3fc..aa98065 100644 --- a/options/finetune_realesrgan_x4plus.yml +++ b/options/finetune_realesrgan_x4plus.yml @@ -90,7 +90,6 @@ network_g: num_block: 23 num_grow_ch: 32 - network_d: type: UNetDiscriminatorSN num_in_ch: 3 @@ -169,7 +168,7 @@ train: # save_img: True # metrics: -# psnr: # metric name, can be arbitrary +# psnr: # metric name # type: calculate_psnr # crop_border: 4 # test_y_channel: false diff --git a/options/finetune_realesrgan_x4plus_pairdata.yml b/options/finetune_realesrgan_x4plus_pairdata.yml index b10ca31..db45d4d 100644 --- a/options/finetune_realesrgan_x4plus_pairdata.yml +++ b/options/finetune_realesrgan_x4plus_pairdata.yml @@ -52,7 +52,6 @@ network_g: num_block: 23 num_grow_ch: 32 - network_d: type: UNetDiscriminatorSN num_in_ch: 3 @@ -131,7 +130,7 @@ train: # save_img: True # metrics: -# psnr: # metric name, can be arbitrary +# psnr: # metric name # type: calculate_psnr # crop_border: 4 # test_y_channel: false diff --git a/options/train_realesrgan_x2plus.yml b/options/train_realesrgan_x2plus.yml index 7573d7f..3c98a0f 100644 --- a/options/train_realesrgan_x2plus.yml +++ b/options/train_realesrgan_x2plus.yml @@ -91,7 +91,6 @@ network_g: num_grow_ch: 32 scale: 2 - network_d: type: UNetDiscriminatorSN num_in_ch: 3 @@ -167,7 +166,7 @@ train: # save_img: True # metrics: -# psnr: # metric name, can be arbitrary +# psnr: # metric name # type: calculate_psnr # crop_border: 4 # test_y_channel: false diff --git a/options/train_realesrgan_x4plus.yml b/options/train_realesrgan_x4plus.yml index f4b768a..763199a 100644 --- a/options/train_realesrgan_x4plus.yml +++ b/options/train_realesrgan_x4plus.yml @@ -90,7 +90,6 @@ network_g: num_block: 23 num_grow_ch: 32 - network_d: type: UNetDiscriminatorSN num_in_ch: 3 @@ -166,7 +165,7 @@ train: # save_img: True # metrics: -# psnr: # metric name, can be arbitrary +# psnr: # metric name # type: calculate_psnr # crop_border: 4 # test_y_channel: false diff --git a/options/train_realesrnet_x2plus.yml b/options/train_realesrnet_x2plus.yml index a3838a9..81ee9ef 100644 --- a/options/train_realesrnet_x2plus.yml +++ b/options/train_realesrnet_x2plus.yml @@ -125,7 +125,7 @@ train: # save_img: True # metrics: -# psnr: # metric name, can be arbitrary +# psnr: # metric name # type: calculate_psnr # crop_border: 4 # test_y_channel: false diff --git a/options/train_realesrnet_x4plus.yml b/options/train_realesrnet_x4plus.yml index 81bf156..45670ed 100644 --- a/options/train_realesrnet_x4plus.yml +++ b/options/train_realesrnet_x4plus.yml @@ -124,7 +124,7 @@ train: # save_img: True # metrics: -# psnr: # metric name, can be arbitrary +# psnr: # metric name # type: calculate_psnr # crop_border: 4 # test_y_channel: false diff --git a/realesrgan/__init__.py b/realesrgan/__init__.py index 4ccac57..f3c0535 100644 --- 
a/realesrgan/__init__.py +++ b/realesrgan/__init__.py @@ -3,4 +3,4 @@ from .archs import * from .data import * from .models import * from .utils import * -from .version import __gitsha__, __version__ +from .version import __version__
diff --git a/realesrgan/archs/discriminator_arch.py b/realesrgan/archs/discriminator_arch.py index 397b9af..4b66ab1 100644 --- a/realesrgan/archs/discriminator_arch.py +++ b/realesrgan/archs/discriminator_arch.py @@ -6,15 +6,23 @@ from torch.nn.utils import spectral_norm @ARCH_REGISTRY.register() class UNetDiscriminatorSN(nn.Module): - """Defines a U-Net discriminator with spectral normalization (SN)""" + """Defines a U-Net discriminator with spectral normalization (SN). + + It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features. Default: 64. + skip_connection (bool): Whether to use skip connections in the U-Net. Default: True. + """ def __init__(self, num_in_ch, num_feat=64, skip_connection=True): super(UNetDiscriminatorSN, self).__init__() self.skip_connection = skip_connection norm = spectral_norm - + # the first convolution self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) - + # downsample self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) @@ -22,14 +30,13 @@ class UNetDiscriminatorSN(nn.Module): self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) - - # extra + # extra convolutions self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) - self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) def forward(self, x): + # downsample x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) @@ -52,7 +59,7 @@ class UNetDiscriminatorSN(nn.Module): if self.skip_connection: x6 = x6 + x0 - # extra + # extra convolutions out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) out = self.conv9(out)
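The discriminator above keeps the input resolution: three stride-2 convolutions downsample, three bilinear upsamples restore the size, so the output is a per-pixel realness map. A minimal shape-check sketch (not part of the patch; it assumes the realesrgan package from this repo is installed):

import torch
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN

disc = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
x = torch.rand(1, 3, 256, 256)  # dummy batch; spatial size should be divisible by 8
with torch.no_grad():
    out = disc(x)
print(out.shape)  # torch.Size([1, 1, 256, 256]): one realness logit per pixel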
diff --git a/realesrgan/data/realesrgan_dataset.py b/realesrgan/data/realesrgan_dataset.py index 570fc46..e4c1109 100644 --- a/realesrgan/data/realesrgan_dataset.py +++ b/realesrgan/data/realesrgan_dataset.py @@ -15,18 +15,31 @@ from torch.utils import data as data @DATASET_REGISTRY.register() class RealESRGANDataset(data.Dataset): - """ - Dataset used for Real-ESRGAN model. + """Dataset used for Real-ESRGAN model: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It loads gt (Ground-Truth) images, and augments them. + It also generates blur kernels and sinc kernels for generating low-quality images. + Note that the low-quality images are processed in tensors on GPUs for faster processing. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwargs. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + Please see more options in the code. """ def __init__(self, opt): super(RealESRGANDataset, self).__init__() self.opt = opt - # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] self.gt_folder = opt['dataroot_gt'] + # file client (lmdb io backend) if self.io_backend_opt['type'] == 'lmdb': self.io_backend_opt['db_paths'] = [self.gt_folder] self.io_backend_opt['client_keys'] = ['gt'] @@ -35,18 +48,20 @@ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: self.paths = [line.split('.')[0] for line in fin] else: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image with open(self.opt['meta_info']) as fin: - paths = [line.strip() for line in fin] + paths = [line.strip().split(' ')[0] for line in fin] self.paths = [os.path.join(self.gt_folder, v) for v in paths] # blur settings for the first degradation self.blur_kernel_size = opt['blur_kernel_size'] self.kernel_list = opt['kernel_list'] - self.kernel_prob = opt['kernel_prob'] + self.kernel_prob = opt['kernel_prob'] # a list of probabilities for each kernel type self.blur_sigma = opt['blur_sigma'] - self.betag_range = opt['betag_range'] - self.betap_range = opt['betap_range'] - self.sinc_prob = opt['sinc_prob'] + self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels + self.betap_range = opt['betap_range'] # betap used in plateau blur kernels + self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters # blur settings for the second degradation self.blur_kernel_size2 = opt['blur_kernel_size2'] @@ -61,6 +76,7 @@ self.final_sinc_prob = opt['final_sinc_prob'] self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 + # TODO: kernel range is now hard-coded, should be in the config file self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect self.pulse_tensor[10, 10] = 1 @@ -89,10 +105,11 @@ retry -= 1 img_gt = imfrombytes(img_bytes, float32=True) - # -------------------- augmentation for training: flip, rotation -------------------- # + # -------------------- Do augmentation for training: flip, rotation -------------------- # img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) - # crop or pad to 400: 400 is hard-coded. You may change it accordingly + # crop or pad to 400 + # TODO: 400 is hard-coded. You may change it accordingly h, w = img_gt.shape[0:2] crop_pad_size = 400 # pad @@ -154,7 +171,7 @@ pad_size = (21 - kernel_size) // 2 kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) - # ------------------------------------- sinc kernel ------------------------------------- # + # ------------------------------------- the final sinc kernel ------------------------------------- # if np.random.uniform() < self.opt['final_sinc_prob']: kernel_size = random.choice(self.kernel_range) omega_c = np.random.uniform(np.pi / 3, np.pi)
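The changed line paths = [line.strip().split(' ')[0] for line in fin] reflects the meta_info convention: the first whitespace-separated token of each line is the relative image path, and any trailing annotation is ignored. A small illustrative sketch (file names and the data root are hypothetical):

import os

meta_lines = ['DF2K_HR_sub/000001_s001.png (400,400,3)',
              'DF2K_HR_sub/000001_s002.png (400,400,3)']
gt_folder = 'datasets/DF2K'  # hypothetical dataroot_gt
paths = [os.path.join(gt_folder, line.strip().split(' ')[0]) for line in meta_lines]
print(paths[0])  # datasets/DF2K/DF2K_HR_sub/000001_s001.png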
diff --git a/realesrgan/data/realesrgan_paired_dataset.py b/realesrgan/data/realesrgan_paired_dataset.py index b450c43..c8deb33 100644 --- a/realesrgan/data/realesrgan_paired_dataset.py +++ b/realesrgan/data/realesrgan_paired_dataset.py @@ -11,8 +11,7 @@ from torchvision.transforms.functional import normalize class RealESRGANPairedDataset(data.Dataset): """Paired image dataset for image restoration. - Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and - GT image pairs. + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. There are three modes: 1. 'lmdb': Use lmdb files. @@ -28,8 +27,8 @@ class RealESRGANPairedDataset(data.Dataset): dataroot_lq (str): Data root path for lq. meta_info (str): Path for meta information file. io_backend (dict): IO backend type and other kwarg. - filename_tmpl (str): Template for each filename. Note that the - template excludes the file extension. Default: '{}'. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. gt_size (int): Cropped patched size for gt patches. use_hflip (bool): Use horizontal flips. use_rot (bool): Use rotation (use vertical flip and transposing h @@ -42,25 +41,25 @@ class RealESRGANPairedDataset(data.Dataset): def __init__(self, opt): super(RealESRGANPairedDataset, self).__init__() self.opt = opt - # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] + # mean and std for normalizing the input images self.mean = opt['mean'] if 'mean' in opt else None self.std = opt['std'] if 'std' in opt else None self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] - if 'filename_tmpl' in opt: - self.filename_tmpl = opt['filename_tmpl'] - else: - self.filename_tmpl = '{}' + self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' + # file client (lmdb io backend) if self.io_backend_opt['type'] == 'lmdb': self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] self.io_backend_opt['client_keys'] = ['lq', 'gt'] self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: + # disk backend with meta_info + # Each line in the meta_info describes the relative paths of a gt-lq image pair, separated by ', ' with open(self.opt['meta_info']) as fin: paths = [line.strip() for line in fin] self.paths = [] for path in paths: gt_path, lq_path = path.split(', ') gt_path = os.path.join(self.gt_folder, gt_path) lq_path = os.path.join(self.lq_folder, lq_path) self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) else: + # disk backend + # it will scan the whole folder to get meta info + # it will be time-consuming for folders with too many files. It is recommended to use an extra meta txt file self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) def __getitem__(self, index):
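For the paired dataset, each meta_info line instead holds a 'gt_path, lq_path' pair (the format written by scripts/generate_meta_info_pairdata.py later in this patch), which is why the line is split on ', ' rather than on a plain space. A toy sketch of the parsing (paths are hypothetical):

import os

line = 'gt/000001.png, lq/000001.png'  # hypothetical paired meta_info line
gt_rel, lq_rel = line.strip().split(', ')
pair = dict(gt_path=os.path.join('datasets/gt_root', gt_rel),
            lq_path=os.path.join('datasets/lq_root', lq_rel))
print(pair)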
diff --git a/realesrgan/models/realesrgan_model.py b/realesrgan/models/realesrgan_model.py index c1813cf..c298a09 100644 --- a/realesrgan/models/realesrgan_model.py +++ b/realesrgan/models/realesrgan_model.py @@ -13,35 +13,45 @@ from torch.nn import functional as F @MODEL_REGISTRY.register() class RealESRGANModel(SRGANModel): - """RealESRGAN Model""" + """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ def __init__(self, opt): super(RealESRGANModel, self).__init__(opt) - self.jpeger = DiffJPEG(differentiable=False).cuda() - self.usm_sharpener = USMSharp().cuda() + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do USM sharpening self.queue_size = opt.get('queue_size', 180) @torch.no_grad() def _dequeue_and_enqueue(self): - # training pair pool + """A training pair pool for increasing the degradation diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch cannot have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ # initialize b, c, h, w = self.lq.size() if not hasattr(self, 'queue_lr'): - assert self.queue_size % b == 0, 'queue size should be divisible by batch size' + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() _, c, h, w = self.gt.size() self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() self.queue_ptr = 0 - if self.queue_ptr == self.queue_size: # full + if self.queue_ptr == self.queue_size: # the pool is full # do dequeue and enqueue # shuffle idx = torch.randperm(self.queue_size) self.queue_lr = self.queue_lr[idx] self.queue_gt = self.queue_gt[idx] - # get + # get first b samples lq_dequeue = self.queue_lr[0:b, :, :, :].clone() gt_dequeue = self.queue_gt[0:b, :, :, :].clone() - # update + # update the queue self.queue_lr[0:b, :, :, :] = self.lq.clone() self.queue_gt[0:b, :, :, :] = self.gt.clone() @@ -55,6 +65,8 @@ class RealESRGANModel(SRGANModel): @torch.no_grad() def feed_data(self, data): + """Accept data from the dataloader, and then add second-order degradations to obtain LQ images. + """ if self.is_train and self.opt.get('high_order_degradation', True): # training data synthesis self.gt = data['gt'].to(self.device) @@ -79,7 +91,7 @@ scale = 1 mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate(out, scale_factor=scale, mode=mode) - # noise + # add noise gray_noise_prob = self.opt['gray_noise_prob'] if np.random.uniform() < self.opt['gaussian_noise_prob']: out = random_add_gaussian_noise_pt( @@ -93,7 +105,7 @@ rounds=False) # JPEG compression jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) - out = torch.clamp(out, 0, 1) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts out = self.jpeger(out, quality=jpeg_p) # ----------------------- The second degradation process ----------------------- # @@ -111,7 +123,7 @@ mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate( out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) - # noise + # add noise gray_noise_prob = self.opt['gray_noise_prob2'] if np.random.uniform() < self.opt['gaussian_noise_prob2']: out = random_add_gaussian_noise_pt( @@ -162,7 +174,9 @@ self._dequeue_and_enqueue() # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract else: + # for paired training or validation self.lq = data['lq'].to(self.device) if 'gt' in data: self.gt = data['gt'].to(self.device) @@ -175,6 +189,7 @@ self.is_train = True def optimize_parameters(self, current_iter): + # USM sharpening l1_gt = self.gt_usm percep_gt = self.gt_usm gan_gt = self.gt_usm
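The pool logic above is compact, so here is a standalone, CPU-only sketch of the same dequeue/enqueue idea with toy sizes (an illustration, not the repo's code):

import torch

queue_size, b, c, h, w = 8, 2, 3, 4, 4
queue_lr = torch.zeros(queue_size, c, h, w)
queue_gt = torch.zeros(queue_size, c, h, w)
ptr = 0

def dequeue_and_enqueue(lq, gt):
    # once the pool is full: shuffle it, return its first b samples, and
    # overwrite them with the incoming batch; before that, just enqueue
    global ptr
    if ptr == queue_size:
        idx = torch.randperm(queue_size)
        queue_lr.copy_(queue_lr[idx])
        queue_gt.copy_(queue_gt[idx])
        lq_out, gt_out = queue_lr[:b].clone(), queue_gt[:b].clone()
        queue_lr[:b], queue_gt[:b] = lq.clone(), gt.clone()
        return lq_out, gt_out
    queue_lr[ptr:ptr + b], queue_gt[ptr:ptr + b] = lq.clone(), gt.clone()
    ptr += b
    return lq, gt

for _ in range(6):  # after 4 warm-up steps, batches start mixing across steps
    lq, gt = dequeue_and_enqueue(torch.rand(b, c, h, w), torch.rand(b, c, h, w))

This also shows why queue_size must be divisible by the batch size: the pool is filled and swapped in whole batches.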
diff --git a/realesrgan/models/realesrnet_model.py b/realesrgan/models/realesrnet_model.py index 2129dd4..d11668f 100644 --- a/realesrgan/models/realesrnet_model.py +++ b/realesrgan/models/realesrnet_model.py @@ -12,35 +12,46 @@ from torch.nn import functional as F @MODEL_REGISTRY.register() class RealESRNetModel(SRModel): - """RealESRNet Model""" + """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It is trained without GAN losses. + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with pixel losses (no GAN training). + """ def __init__(self, opt): super(RealESRNetModel, self).__init__(opt) - self.jpeger = DiffJPEG(differentiable=False).cuda() - self.usm_sharpener = USMSharp().cuda() + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do USM sharpening self.queue_size = opt.get('queue_size', 180) @torch.no_grad() def _dequeue_and_enqueue(self): - # training pair pool + """A training pair pool for increasing the degradation diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch cannot have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ # initialize b, c, h, w = self.lq.size() if not hasattr(self, 'queue_lr'): - assert self.queue_size % b == 0, 'queue size should be divisible by batch size' + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() _, c, h, w = self.gt.size() self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() self.queue_ptr = 0 - if self.queue_ptr == self.queue_size: # full + if self.queue_ptr == self.queue_size: # the pool is full # do dequeue and enqueue # shuffle idx = torch.randperm(self.queue_size) self.queue_lr = self.queue_lr[idx] self.queue_gt = self.queue_gt[idx] - # get + # get first b samples lq_dequeue = self.queue_lr[0:b, :, :, :].clone() gt_dequeue = self.queue_gt[0:b, :, :, :].clone() - # update + # update the queue self.queue_lr[0:b, :, :, :] = self.lq.clone() self.queue_gt[0:b, :, :, :] = self.gt.clone() @@ -54,10 +65,12 @@ class RealESRNetModel(SRModel): @torch.no_grad() def feed_data(self, data): + """Accept data from the dataloader, and then add second-order degradations to obtain LQ images. + """ if self.is_train and self.opt.get('high_order_degradation', True): # training data synthesis self.gt = data['gt'].to(self.device) - # USM the GT images + # USM sharpen the GT images if self.opt['gt_usm'] is True: self.gt = self.usm_sharpener(self.gt) @@ -80,7 +93,7 @@ scale = 1 mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate(out, scale_factor=scale, mode=mode) - # noise + # add noise gray_noise_prob = self.opt['gray_noise_prob'] if np.random.uniform() < self.opt['gaussian_noise_prob']: out = random_add_gaussian_noise_pt( @@ -94,7 +107,7 @@ rounds=False) # JPEG compression jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) - out = torch.clamp(out, 0, 1) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts out = self.jpeger(out, quality=jpeg_p) # ----------------------- The second degradation process ----------------------- # @@ -112,7 +125,7 @@ mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate( out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) - # noise + # add noise gray_noise_prob = self.opt['gray_noise_prob2'] if np.random.uniform() < self.opt['gaussian_noise_prob2']: out = random_add_gaussian_noise_pt( @@ -160,7 +173,9 @@ # training pair pool self._dequeue_and_enqueue() + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract else: + # for paired training or validation self.lq = data['lq'].to(self.device) if 'gt' in data: self.gt = data['gt'].to(self.device)
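Both models share the same synthesis recipe: two rounds of blur, random resize, noise, and JPEG compression. A condensed, CPU-only sketch of one such round (pure PyTorch; the repo's DiffJPEG and the basicsr kernel generators are deliberately omitted, and the pulse kernel stands in for a random blur kernel):

import random
import torch
import torch.nn.functional as F

def degrade_once(img, kernel, scale_range=(0.5, 1.5), noise_sigma=0.05):
    b, c, h, w = img.size()
    k = kernel.size(-1)
    # blur: depthwise conv, the same kernel applied to every channel
    weight = kernel.view(1, 1, k, k).repeat(c, 1, 1, 1)
    out = F.conv2d(F.pad(img, [k // 2] * 4, mode='reflect'), weight, groups=c)
    # random resize with a random interpolation mode
    scale = random.uniform(*scale_range)
    out = F.interpolate(out, scale_factor=scale, mode=random.choice(['area', 'bilinear', 'bicubic']))
    # additive Gaussian noise, then clamp to [0, 1] (as done before the JPEGer above)
    out = out + noise_sigma * torch.randn_like(out)
    return torch.clamp(out, 0, 1)

img = torch.rand(1, 3, 64, 64)
pulse = torch.zeros(21, 21)
pulse[10, 10] = 1  # the pulse kernel: blurring with it changes nothing
print(degrade_once(img, pulse).shape)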
diff --git a/realesrgan/utils.py b/realesrgan/utils.py index a815cb3..802d391 100644 --- a/realesrgan/utils.py +++ b/realesrgan/utils.py @@ -12,6 +12,19 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can also be a URL (the weights will first be + downloaded automatically). + model (nn.Module): The defined network. If None, the model will be constructed here. Default: None. + tile (int): Since overly large images can cause out-of-GPU-memory issues, this tile option first crops + input images into tiles, processes each of them, and finally merges them into one image. + 0 denotes not using tiles. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (bool): Whether to use half precision during inference. Default: False. + """ def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False): self.scale = scale @@ -26,10 +39,12 @@ if model is None: model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) + # if the model_path starts with https, it will first download models to the folder: realesrgan/weights if model_path.startswith('https://'): model_path = load_file_from_url( url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None) loadnet = torch.load(model_path) + # prefer to use params_ema if 'params_ema' in loadnet: keyname = 'params_ema' else: keyname = 'params' @@ -41,6 +56,8 @@ self.model = self.model.half() def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the image sizes are divisible by the scale factor. + """ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() self.img = img.unsqueeze(0).to(self.device) if self.half: @@ -49,7 +66,7 @@ # pre_pad if self.pre_pad != 0: self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') - # mod pad + # mod pad for divisible borders if self.scale == 2: self.mod_scale = 2 elif self.scale == 1: @@ -64,10 +81,14 @@ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') def process(self): + # model inference self.output = self.model(self.img) def tile_process(self): - """Modified from: https://github.com/ata4/esrgan-launcher + """It will first crop input images into tiles, and then process each tile. + Finally, all the processed tiles are merged into one image. + + Modified from: https://github.com/ata4/esrgan-launcher """ batch, channel, height, width = self.img.shape output_height = height * self.scale @@ -188,7 +209,7 @@ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) - else: + else: # use the cv2 resize for alpha channel h, w = alpha.shape[0:2] output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) @@ -213,7 +234,9 @@ def load_file_from_url(url, model_dir=None, progress=True, file_name=None): - """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + """Load file from http url; it will download models if necessary. + + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py """ if model_dir is None: hub_dir = get_dir()
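Putting the helper class together, a typical inference call looks roughly like this (a sketch with placeholder paths; it assumes the enhance() method of this class, whose alpha-channel handling appears in the hunk above, which returns the restored image and its mode):

import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upsampler = RealESRGANer(
    scale=4,
    model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
    model=model,
    tile=400,  # bound GPU memory by processing 400x400 tiles; 0 disables tiling
    tile_pad=10,
    pre_pad=10,
    half=False)
img = cv2.imread('inputs/example.png', cv2.IMREAD_UNCHANGED)  # hypothetical input
output, img_mode = upsampler.enhance(img, outscale=4)
cv2.imwrite('results/example_out.png', output)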
diff --git a/scripts/extract_subimages.py b/scripts/extract_subimages.py index 0b4d264..9b969ae 100644 --- a/scripts/extract_subimages.py +++ b/scripts/extract_subimages.py @@ -14,34 +14,24 @@ def main(args): opt (dict): Configuration dict. It contains: n_thread (int): Thread number. - compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. - A higher value means a smaller size and longer compression time.
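The overlapped sliding window in worker() advances by step pixels and keeps a final border-aligned patch when the remainder exceeds thresh_size. A small sketch of that index computation (mirroring the logic, not copied from the script):

import numpy as np

def subimage_positions(length, crop_size, step, thresh_size):
    pos = np.arange(0, length - crop_size + 1, step).tolist()
    if length - (pos[-1] + crop_size) > thresh_size:
        pos.append(length - crop_size)  # keep the border remainder as one more patch
    return pos

print(subimage_positions(2040, crop_size=480, step=240, thresh_size=0))
# [0, 240, 480, 720, 960, 1200, 1440, 1560]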
- Use 0 for faster CPU decompression. Default: 3, same in cv2. - + compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size + and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2. input_folder (str): Path to the input folder. save_folder (str): Path to save folder. crop_size (int): Crop size. step (int): Step for overlapped sliding window. - thresh_size (int): Threshold size. Patches whose size is lower - than thresh_size will be dropped. + thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. Usage: For each folder, run this script. - Typically, there are four folders to be processed for DIV2K dataset. - DIV2K_train_HR - DIV2K_train_LR_bicubic/X2 - DIV2K_train_LR_bicubic/X3 - DIV2K_train_LR_bicubic/X4 - After process, each sub_folder should have the same number of - subimages. + Typically, there are a GT folder and an LQ folder to be processed for the DIV2K dataset. + After processing, each sub_folder should have the same number of subimages. Remember to modify opt configurations according to your settings. """ opt = {} opt['n_thread'] = args.n_thread opt['compression_level'] = args.compression_level - - # HR images opt['input_folder'] = args.input opt['save_folder'] = args.output opt['crop_size'] = args.crop_size @@ -68,6 +58,7 @@ def extract_subimages(opt): print(f'Folder {save_folder} already exists. Exit.') sys.exit(1) + # scan all images img_list = list(scandir(input_folder, full_path=True)) pbar = tqdm(total=len(img_list), unit='image', desc='Extract') @@ -88,8 +79,7 @@ def worker(path, opt): opt (dict): Configuration dict. It contains: crop_size (int): Crop size. step (int): Step for overlapped sliding window. - thresh_size (int): Threshold size. Patches whose size is lower - than thresh_size will be dropped. + thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. save_folder (str): Path to save folder. compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
diff --git a/scripts/generate_meta_info.py b/scripts/generate_meta_info.py index b294f7b..51d028a 100644 --- a/scripts/generate_meta_info.py +++ b/scripts/generate_meta_info.py @@ -11,6 +11,7 @@ def main(args): for img_path in img_paths: status = True if args.check: + # read the image once to check, as some images may have errors try: img = cv2.imread(img_path) except Exception as error: @@ -20,6 +21,7 @@ def main(args): status = False print(f'Img is None: {img_path}') if status: + # get the relative path img_name = os.path.relpath(img_path, root) print(img_name) txt_file.write(f'{img_name}\n')
diff --git a/scripts/generate_meta_info_pairdata.py b/scripts/generate_meta_info_pairdata.py index 4d4bf1a..76dce7e 100644 --- a/scripts/generate_meta_info_pairdata.py +++ b/scripts/generate_meta_info_pairdata.py @@ -5,6 +5,7 @@ import os def main(args): txt_file = open(args.meta_info, 'w') + # scan images img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*'))) img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*'))) @@ -12,6 +13,7 @@ f'{len(img_paths_gt)} and {len(img_paths_lq)}.') for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq): + # get the relative paths img_name_gt = os.path.relpath(img_path_gt, args.root[0]) img_name_lq = os.path.relpath(img_path_lq, args.root[1]) print(f'{img_name_gt}, {img_name_lq}') @@ -19,7 +21,7 @@ if __name__ == '__main__': - """Generate meta info (txt file) for paired images. + """This script is used to generate meta info (txt file) for paired images. """ parser = argparse.ArgumentParser() parser.add_argument(
diff --git a/scripts/generate_multiscale_DF2K.py b/scripts/generate_multiscale_DF2K.py index 919c61f..d4f5d83 100644 --- a/scripts/generate_multiscale_DF2K.py +++ b/scripts/generate_multiscale_DF2K.py @@ -5,7 +5,6 @@ from PIL import Image def main(args): - # For DF2K, we consider the following three scales, # and the smallest image whose shortest edge is 400 scale_list = [0.75, 0.5, 1 / 3] @@ -37,6 +36,9 @@ if __name__ == '__main__': + """Generate multi-scale versions for GT images with LANCZOS resampling. + It is now used for the DF2K dataset (DIV2K + Flickr2K). + """ parser = argparse.ArgumentParser() parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder') parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')
diff --git a/scripts/pytorch2onnx.py b/scripts/pytorch2onnx.py index dc2ec0a..09d99b2 100644 --- a/scripts/pytorch2onnx.py +++ b/scripts/pytorch2onnx.py @@ -1,17 +1,36 @@ +import argparse import torch import torch.onnx from basicsr.archs.rrdbnet_arch import RRDBNet -# An instance of your model -model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) -model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema']) -# set the train mode to false since we will only run the forward pass. -model.train(False) -model.cpu().eval() -# An example input you would normally provide to your model's forward() method -x = torch.rand(1, 3, 64, 64) +def main(args): + # An instance of the model + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + if args.params: + keyname = 'params' + else: + keyname = 'params_ema' + model.load_state_dict(torch.load(args.input)[keyname]) + # set the train mode to false since we will only run the forward pass. + model.train(False) + model.cpu().eval() -# Export the model -with torch.no_grad(): - torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True) + # An example input + x = torch.rand(1, 3, 64, 64) + # Export the model + with torch.no_grad(): + torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True) + print(torch_out.shape) + + +if __name__ == '__main__': + """Convert PyTorch models to ONNX models""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path') + parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path') + parser.add_argument('--params', action='store_true', help='Use params instead of params_ema') + args = parser.parse_args() + + main(args)
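To sanity-check the exported graph, one can run it with onnxruntime and compare shapes against the PyTorch output printed by the script (a hedged sketch; it assumes onnxruntime is installed and the default output name realesrgan-x4.onnx was used):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('realesrgan-x4.onnx')
x = np.random.rand(1, 3, 64, 64).astype(np.float32)
y = sess.run(None, {sess.get_inputs()[0].name: x})[0]
print(y.shape)  # expected (1, 3, 256, 256) for the x4 model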