From 9baeba566b390cc803ed2e2d8c12e00022a46584 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sat, 24 Jul 2021 18:56:20 +0800 Subject: [PATCH] add Real-ESRGAN training codes --- README.md | 2 +- archs/__init__.py | 10 ++ archs/discriminator_arch.py | 60 +++++++ data/realesrgan_dataset.py | 175 ++++++++++++++++++++ inference_realesrgan.py | 4 +- models/realesrgan_model.py | 240 ++++++++++++++++++++++++++++ models/realesrnet_model.py | 172 ++++++++++++++++++++ options/train_realesrgan_x4plus.yml | 186 +++++++++++++++++++++ options/train_realesrnet_x4plus.yml | 144 +++++++++++++++++ train.py | 10 ++ 10 files changed, 999 insertions(+), 4 deletions(-) create mode 100644 archs/__init__.py create mode 100644 archs/discriminator_arch.py create mode 100644 data/realesrgan_dataset.py create mode 100644 models/realesrgan_model.py create mode 100644 models/realesrnet_model.py create mode 100644 options/train_realesrgan_x4plus.yml create mode 100644 options/train_realesrnet_x4plus.yml create mode 100644 train.py diff --git a/README.md b/README.md index c9fc818..b9ff864 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ You can simply run the following command: Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together. -This executable file is based on the wonderful [Tecent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui). +This executable file is based on the wonderful [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui). --- diff --git a/archs/__init__.py b/archs/__init__.py new file mode 100644 index 0000000..4ec725e --- /dev/null +++ b/archs/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import arch modules for registry +# scan all the files that end with '_arch.py' under the archs folder +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] diff --git a/archs/discriminator_arch.py b/archs/discriminator_arch.py new file mode 100644 index 0000000..397b9af --- /dev/null +++ b/archs/discriminator_arch.py @@ -0,0 +1,60 @@ +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn as nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm + + +@ARCH_REGISTRY.register() +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN)""" + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + + # extra + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out diff --git a/data/realesrgan_dataset.py b/data/realesrgan_dataset.py new file mode 100644 index 0000000..570fc46 --- /dev/null +++ b/data/realesrgan_dataset.py @@ -0,0 +1,175 @@ +import cv2 +import math +import numpy as np +import os +import os.path as osp +import random +import time +import torch +from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from torch.utils import data as data + + +@DATASET_REGISTRY.register() +class RealESRGANDataset(data.Dataset): + """ + Dataset used for Real-ESRGAN model. + """ + + def __init__(self, opt): + super(RealESRGANDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.gt_folder = opt['dataroot_gt'] + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.gt_folder] + self.io_backend_opt['client_keys'] = ['gt'] + if not self.gt_folder.endswith('.lmdb'): + raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + with open(self.opt['meta_info']) as fin: + paths = [line.strip() for line in fin] + self.paths = [os.path.join(self.gt_folder, v) for v in paths] + + # blur settings for the first degradation + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] + self.blur_sigma = opt['blur_sigma'] + self.betag_range = opt['betag_range'] + self.betap_range = opt['betap_range'] + self.sinc_prob = opt['sinc_prob'] + + # blur settings for the second degradation + self.blur_kernel_size2 = opt['blur_kernel_size2'] + self.kernel_list2 = opt['kernel_list2'] + self.kernel_prob2 = opt['kernel_prob2'] + self.blur_sigma2 = opt['blur_sigma2'] + self.betag_range2 = opt['betag_range2'] + self.betap_range2 = opt['betap_range2'] + self.sinc_prob2 = opt['sinc_prob2'] + + # a final sinc filter + self.final_sinc_prob = opt['final_sinc_prob'] + + self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 + self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect + self.pulse_tensor[10, 10] = 1 + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # -------------------------------- Load gt images -------------------------------- # + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. + gt_path = self.paths[index] + # avoid errors caused by high latency in reading files + retry = 3 + while retry > 0: + try: + img_bytes = self.file_client.get(gt_path, 'gt') + except Exception as e: + logger = get_root_logger() + logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') + # change another file to read + index = random.randint(0, self.__len__()) + gt_path = self.paths[index] + time.sleep(1) # sleep 1s for occasional server congestion + else: + break + finally: + retry -= 1 + img_gt = imfrombytes(img_bytes, float32=True) + + # -------------------- augmentation for training: flip, rotation -------------------- # + img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) + + # crop or pad to 400: 400 is hard-coded. You may change it accordingly + h, w = img_gt.shape[0:2] + crop_pad_size = 400 + # pad + if h < crop_pad_size or w < crop_pad_size: + pad_h = max(0, crop_pad_size - h) + pad_w = max(0, crop_pad_size - w) + img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) + # crop + if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: + h, w = img_gt.shape[0:2] + # randomly choose top and left coordinates + top = random.randint(0, h - crop_pad_size) + left = random.randint(0, w - crop_pad_size) + img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] + + # ------------------------ Generate kernels (used in the first degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob']: + # this sinc filter setting is for kernels ranging from [7, 21] + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel = random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + self.betag_range, + self.betap_range, + noise_range=None) + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------ Generate kernels (used in the second degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob2']: + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel2 = random_mixed_kernels( + self.kernel_list2, + self.kernel_prob2, + kernel_size, + self.blur_sigma2, + self.blur_sigma2, [-math.pi, math.pi], + self.betag_range2, + self.betap_range2, + noise_range=None) + + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------------------- sinc kernel ------------------------------------- # + if np.random.uniform() < self.opt['final_sinc_prob']: + kernel_size = random.choice(self.kernel_range) + omega_c = np.random.uniform(np.pi / 3, np.pi) + sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) + sinc_kernel = torch.FloatTensor(sinc_kernel) + else: + sinc_kernel = self.pulse_tensor + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] + kernel = torch.FloatTensor(kernel) + kernel2 = torch.FloatTensor(kernel2) + + return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} + return return_d + + def __len__(self): + return len(self.paths) diff --git a/inference_realesrgan.py b/inference_realesrgan.py index 45b1714..2bee50d 100644 --- a/inference_realesrgan.py +++ b/inference_realesrgan.py @@ -17,9 +17,7 @@ def main(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # set up model - # FIXME: currenly RRDBNet in BasicSR does not support scale argument. Will update later - # model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale) - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32) + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale) loadnet = torch.load(args.model_path) model.load_state_dict(loadnet['params_ema'], strict=True) model.eval() diff --git a/models/realesrgan_model.py b/models/realesrgan_model.py new file mode 100644 index 0000000..180103e --- /dev/null +++ b/models/realesrgan_model.py @@ -0,0 +1,240 @@ +import numpy as np +import random +import torch +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.models.srgan_model import SRGANModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY +from collections import OrderedDict +from torch.nn import functional as F + + +@MODEL_REGISTRY.register() +class RealESRGANModel(SRGANModel): + """RealESRGAN Model""" + + def __init__(self, opt): + super(RealESRGANModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() + self.usm_shaper = USMSharp().cuda() + self.queue_size = opt['queue_size'] + + @torch.no_grad() + def _dequeue_and_enqueue(self): + # training pair pool + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, 'queue size should be divisible by batch size' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + if self.is_train: + # training data synthesis + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_shaper(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt_usm, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, + self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + else: + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True + + def optimize_parameters(self, current_iter): + l1_gt = self.gt_usm + percep_gt = self.gt_usm + gan_gt = self.gt_usm + if self.opt['l1_gt_usm'] is False: + l1_gt = self.gt + if self.opt['percep_gt_usm'] is False: + percep_gt = self.gt + if self.opt['gan_gt_usm'] is False: + gan_gt = self.gt + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, l1_gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(gan_gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + self.log_dict = self.reduce_loss_dict(loss_dict) diff --git a/models/realesrnet_model.py b/models/realesrnet_model.py new file mode 100644 index 0000000..e92e833 --- /dev/null +++ b/models/realesrnet_model.py @@ -0,0 +1,172 @@ +import numpy as np +import random +import torch +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.models.sr_model import SRModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY +from torch.nn import functional as F + + +@MODEL_REGISTRY.register() +class RealESRNetModel(SRModel): + """RealESRNet Model""" + + def __init__(self, opt): + super(RealESRNetModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() + self.usm_shaper = USMSharp().cuda() + self.queue_size = opt['queue_size'] + + @torch.no_grad() + def _dequeue_and_enqueue(self): + # training pair pool + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, 'queue size should be divisible by batch size' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + if self.is_train: + # training data synthesis + self.gt = data['gt'].to(self.device) + # USM the GT images + if self.opt['gt_usm'] is True: + self.gt = self.usm_shaper(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + else: + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True diff --git a/options/train_realesrgan_x4plus.yml b/options/train_realesrgan_x4plus.yml new file mode 100644 index 0000000..940a777 --- /dev/null +++ b/options/train_realesrgan_x4plus.yml @@ -0,0 +1,186 @@ +# general settings +name: train_RealESRGANx4plus_400k_B12G4_fromRealESRNet +model_type: RealESRGANModel +scale: 4 +num_gpu: 4 +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRGANModel ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name, can be arbitrary +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train_realesrnet_x4plus.yml b/options/train_realesrnet_x4plus.yml new file mode 100644 index 0000000..400c580 --- /dev/null +++ b/options/train_realesrnet_x4plus.yml @@ -0,0 +1,144 @@ +# general settings +name: train_RealESRNetx4plus_1000k_B12G4_fromESRGAN +model_type: RealESRNetModel +scale: 4 +num_gpu: 4 +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRNetModel ----------------- # +gt_usm: True # USM the ground-truth + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + +# path +path: + pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1000000] + gamma: 0.5 + + total_iter: 1000000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name, can be arbitrary +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/train.py b/train.py new file mode 100644 index 0000000..bd52322 --- /dev/null +++ b/train.py @@ -0,0 +1,10 @@ +import os.path as osp +from basicsr.train import train_pipeline + +import archs # noqa: F401 +import data # noqa: F401 +import models # noqa: F401 + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir)) + train_pipeline(root_path)