From f5ccd64ce51105ddbc89df91a778499630d68ee3 Mon Sep 17 00:00:00 2001 From: Xintao Date: Fri, 27 Aug 2021 16:14:48 +0800 Subject: [PATCH] support finetune with paired data --- README.md | 3 + Training.md | 106 ++++++++++++ options/finetune_realesrgan_x4plus.yml | 4 +- .../finetune_realesrgan_x4plus_pairdata.yml | 151 ++++++++++++++++++ realesrgan/data/realesrgan_paired_dataset.py | 106 ++++++++++++ realesrgan/models/realesrgan_model.py | 5 +- realesrgan/models/realesrnet_model.py | 5 +- requirements.txt | 1 + scripts/generate_meta_info.py | 3 + scripts/generate_meta_info_pairdata.py | 47 ++++++ setup.cfg | 2 +- 11 files changed, 426 insertions(+), 7 deletions(-) create mode 100644 options/finetune_realesrgan_x4plus_pairdata.yml create mode 100644 realesrgan/data/realesrgan_paired_dataset.py create mode 100644 scripts/generate_meta_info_pairdata.py diff --git a/README.md b/README.md index 4370be8..458382b 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ We extend the powerful ESRGAN to a practical restoration application (namely, Re :question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md). :triangular_flag_on_post: **Updates** +- :white_check_mark: Support finetuning on your own data or paired data (*i.e.*, finetuning ESRGAN). See [here](Training.md#Finetune-Real-ESRGAN-on-your-own-dataset) - :white_check_mark: Integrate [GFPGAN](https://github.com/TencentARC/GFPGAN) to support **face enhancement**. - :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN). Thanks [@AK391](https://github.com/AK391) - :white_check_mark: Support arbitrary scale with `--outscale` (It actually further resizes outputs with `LANCZOS4`). Add *RealESRGAN_x2plus.pth* model. 
@@ -135,8 +136,10 @@ Results are in the `results` folder ## :european_castle: Model Zoo - [RealESRGAN_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth) +- [RealESRGAN_x4plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth) - [RealESRNet_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth) - [RealESRGAN_x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth) +- [RealESRGAN_x2plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x2plus_netD.pth) - [official ESRGAN_x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) ## :computer: Training and Finetuning on your own dataset diff --git a/Training.md b/Training.md index 93e65ba..0c8f595 100644 --- a/Training.md +++ b/Training.md @@ -5,6 +5,9 @@ - [Dataset Preparation](#dataset-preparation) - [Train Real-ESRNet](#Train-Real-ESRNet) - [Train Real-ESRGAN](#Train-Real-ESRGAN) +- [Finetune Real-ESRGAN on your own dataset](#Finetune-Real-ESRGAN-on-your-own-dataset) + - [Generate degraded images on the fly](#Generate-degraded-images-on-the-fly) + - [Use paired training data](#Use-paired-training-data) ## Train Real-ESRGAN @@ -131,3 +134,106 @@ You can merge several folders into one meta_info txt. Here is the example: CUDA_VISIBLE_DEVICES=0,1,2,3 \ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume ``` + +## Finetune Real-ESRGAN on your own dataset + +You can finetune Real-ESRGAN on your own dataset. Typically, the fine-tuning process can be divided into two cases: + +1. [generate degraded images on the fly](#Generate-degraded-images-on-the-fly) +1. 
[use your own **paired** data](#Use-paired-training-data)
+
+### Generate degraded images on the fly
+
+Only high-resolution images are required. The low-quality images are generated with the degradation process in Real-ESRGAN during training.
+
+**Prepare dataset**
+
+See [this section](#dataset-preparation) for more details.
+
+**Download pre-trained models**
+
+Download pre-trained models into `experiments/pretrained_models`.
+
+*RealESRGAN_x4plus.pth*
+
+    ```bash
+    wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
+    ```
+
+*RealESRGAN_x4plus_netD.pth*
+
+    ```bash
+    wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
+    ```
+
+**Finetune**
+
+Modify [options/finetune_realesrgan_x4plus.yml](options/finetune_realesrgan_x4plus.yml) accordingly, especially the `datasets` part:
+    ```yml
+    train:
+        name: DF2K+OST
+        type: RealESRGANDataset
+        dataroot_gt: datasets/DF2K # modify to the root path of your folder
+        meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generated meta info txt
+        io_backend:
+            type: disk
+    ```
+
+We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
+    ```bash
+    CUDA_VISIBLE_DEVICES=0,1,2,3 \
+    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus.yml --launcher pytorch --auto_resume
+    ```
+
+### Use paired training data
+
+You can also finetune RealESRGAN with your own paired data. It is more similar to fine-tuning ESRGAN. 
+ +**Prepare dataset** + +Assume that you already have two folders: + +- gt folder (Ground-truth, high-resolution images): datasets/DF2K/DIV2K_train_HR_sub +- lq folder (Low quality, low-resolution images): datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub + +Then, you can prepare the meta_info txt file using the script [scripts/generate_meta_info_pairdata.py](scripts/generate_meta_info_pairdata.py): + +```bash +python scripts/generate_meta_info_pairdata.py --input datasets/DF2K/DIV2K_train_HR_sub datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub --meta_info datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt +``` + +**Download pre-trained models** + +Download pre-trained models into `experiments/pretrained_models`. + +*RealESRGAN_x4plus.pth* + + ```bash + wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models + ``` + +*RealESRGAN_x4plus_netD.pth* + + ```bash + wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models + ``` + +**Finetune** + +Modify [options/finetune_realesrgan_x4plus_pairdata.yml](options/finetune_realesrgan_x4plus_pairdata.yml) accordingly, especially the `datasets` part: + ```yml + train: + name: DIV2K + type: RealESRGANPairedDataset + dataroot_gt: datasets/DF2K # modify to the root path of your folder + dataroot_lq: datasets/DF2K # modify to the root path of your folder + meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt # modify to the root path of your folder + io_backend: + type: disk + ``` + +We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary. 
+ ```bash + CUDA_VISIBLE_DEVICES=0,1,2,3 \ + python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus_pairdata.yml --launcher pytorch --auto_resume + ``` diff --git a/options/finetune_realesrgan_x4plus.yml b/options/finetune_realesrgan_x4plus.yml index a827e38..c4ff3fc 100644 --- a/options/finetune_realesrgan_x4plus.yml +++ b/options/finetune_realesrgan_x4plus.yml @@ -1,8 +1,8 @@ # general settings -name: finetune_RealESRGANx4plus_400k_B12G4 +name: finetune_RealESRGANx4plus_400k model_type: RealESRGANModel scale: 4 -num_gpu: 4 +num_gpu: auto manual_seed: 0 # ----------------- options for synthesizing training data in RealESRGANModel ----------------- # diff --git a/options/finetune_realesrgan_x4plus_pairdata.yml b/options/finetune_realesrgan_x4plus_pairdata.yml new file mode 100644 index 0000000..b10ca31 --- /dev/null +++ b/options/finetune_realesrgan_x4plus_pairdata.yml @@ -0,0 +1,151 @@ +# general settings +name: finetune_RealESRGANx4plus_400k_pairdata +model_type: RealESRGANModel +scale: 4 +num_gpu: auto +manual_seed: 0 + +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +high_order_degradation: False # do not use the high-order degradation generation process + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: RealESRGANPairedDataset + dataroot_gt: datasets/DF2K + dataroot_lq: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt + io_backend: + type: disk + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + 
num_feat: 64 + num_block: 23 + num_grow_ch: 32 + + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth + param_key_g: params_ema + strict_load_g: true + pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth + param_key_d: params + strict_load_d: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name, can be arbitrary +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/realesrgan/data/realesrgan_paired_dataset.py b/realesrgan/data/realesrgan_paired_dataset.py new file mode 100644 index 0000000..b450c43 --- /dev/null +++ 
b/realesrgan/data/realesrgan_paired_dataset.py @@ -0,0 +1,106 @@ +import os +from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from torch.utils import data as data +from torchvision.transforms.functional import normalize + + +@DATASET_REGISTRY.register() +class RealESRGANPairedDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and + GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. 
+ """ + + def __init__(self, opt): + super(RealESRGANPairedDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + if 'filename_tmpl' in opt: + self.filename_tmpl = opt['filename_tmpl'] + else: + self.filename_tmpl = '{}' + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: + with open(self.opt['meta_info']) as fin: + paths = [line.strip() for line in fin] + self.paths = [] + for path in paths: + gt_path, lq_path = path.split(', ') + gt_path = os.path.join(self.gt_folder, gt_path) + lq_path = os.path.join(self.lq_folder, lq_path) + self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) + else: + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. 
+ gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/realesrgan/models/realesrgan_model.py b/realesrgan/models/realesrgan_model.py index 5b1268e..c1813cf 100644 --- a/realesrgan/models/realesrgan_model.py +++ b/realesrgan/models/realesrgan_model.py @@ -19,7 +19,7 @@ class RealESRGANModel(SRGANModel): super(RealESRGANModel, self).__init__(opt) self.jpeger = DiffJPEG(differentiable=False).cuda() self.usm_sharpener = USMSharp().cuda() - self.queue_size = opt['queue_size'] + self.queue_size = opt.get('queue_size', 180) @torch.no_grad() def _dequeue_and_enqueue(self): @@ -55,7 +55,7 @@ class RealESRGANModel(SRGANModel): @torch.no_grad() def feed_data(self, data): - if self.is_train: + if self.is_train and self.opt.get('high_order_degradation', True): # training data synthesis self.gt = data['gt'].to(self.device) self.gt_usm = self.usm_sharpener(self.gt) @@ -166,6 +166,7 @@ class RealESRGANModel(SRGANModel): self.lq = data['lq'].to(self.device) if 'gt' in data: self.gt = data['gt'].to(self.device) + self.gt_usm = 
self.usm_sharpener(self.gt) def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): # do not use the synthetic process during validation diff --git a/realesrgan/models/realesrnet_model.py b/realesrgan/models/realesrnet_model.py index 1b5651d..2129dd4 100644 --- a/realesrgan/models/realesrnet_model.py +++ b/realesrgan/models/realesrnet_model.py @@ -18,7 +18,7 @@ class RealESRNetModel(SRModel): super(RealESRNetModel, self).__init__(opt) self.jpeger = DiffJPEG(differentiable=False).cuda() self.usm_sharpener = USMSharp().cuda() - self.queue_size = opt['queue_size'] + self.queue_size = opt.get('queue_size', 180) @torch.no_grad() def _dequeue_and_enqueue(self): @@ -54,7 +54,7 @@ class RealESRNetModel(SRModel): @torch.no_grad() def feed_data(self, data): - if self.is_train: + if self.is_train and self.opt.get('high_order_degradation', True): # training data synthesis self.gt = data['gt'].to(self.device) # USM the GT images @@ -164,6 +164,7 @@ class RealESRNetModel(SRModel): self.lq = data['lq'].to(self.device) if 'gt' in data: self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): # do not use the synthetic process during validation diff --git a/requirements.txt b/requirements.txt index f4ed4c7..8352614 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ numpy opencv-python Pillow torch>=1.7 +torchvision tqdm diff --git a/scripts/generate_meta_info.py b/scripts/generate_meta_info.py index 7a6448e..b5aeabd 100644 --- a/scripts/generate_meta_info.py +++ b/scripts/generate_meta_info.py @@ -35,6 +35,9 @@ if __name__ == '__main__': default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt', help='txt path for meta info') args = parser.parse_args() + assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got ' f'{len(args.input)} and {len(args.root)}.') + 
os.makedirs(os.path.dirname(args.meta_info), exist_ok=True) + main(args) diff --git a/scripts/generate_meta_info_pairdata.py b/scripts/generate_meta_info_pairdata.py new file mode 100644 index 0000000..4d4bf1a --- /dev/null +++ b/scripts/generate_meta_info_pairdata.py @@ -0,0 +1,47 @@ +import argparse +import glob +import os + + +def main(args): + txt_file = open(args.meta_info, 'w') + img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*'))) + img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*'))) + + assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got ' + f'{len(img_paths_gt)} and {len(img_paths_lq)}.') + + for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq): + img_name_gt = os.path.relpath(img_path_gt, args.root[0]) + img_name_lq = os.path.relpath(img_path_lq, args.root[1]) + print(f'{img_name_gt}, {img_name_lq}') + txt_file.write(f'{img_name_gt}, {img_name_lq}\n') + + +if __name__ == '__main__': + """Generate meta info (txt file) for paired images. 
+ """ + parser = argparse.ArgumentParser() + parser.add_argument( + '--input', + nargs='+', + default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'], + help='Input folder, should be [gt_folder, lq_folder]') + parser.add_argument('--root', nargs='+', default=[None, None], help='Folder root, will use the ') + parser.add_argument( + '--meta_info', + type=str, + default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt', + help='txt path for meta info') + args = parser.parse_args() + + assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder' + assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder' + os.makedirs(os.path.dirname(args.meta_info), exist_ok=True) + for i in range(2): + if args.input[i].endswith('/'): + args.input[i] = args.input[i][:-1] + if args.root[i] is None: + args.root[i] = os.path.dirname(args.input[i]) + + main(args) diff --git a/setup.cfg b/setup.cfg index 5dcf1ab..50ceaf0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,6 +17,6 @@ line_length = 120 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = realesrgan -known_third_party = PIL,basicsr,cv2,numpy,torch,tqdm +known_third_party = PIL,basicsr,cv2,numpy,torch,torchvision,tqdm no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY