support finetune with paired data

This commit is contained in:
Xintao
2021-08-27 16:14:48 +08:00
parent 194c2c14b3
commit f5ccd64ce5
11 changed files with 426 additions and 7 deletions

View File

@@ -16,6 +16,7 @@ We extend the powerful ESRGAN to a practical restoration application (namely, Re
:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md).
:triangular_flag_on_post: **Updates**
- :white_check_mark: Support finetuning on your own data or paired data (*i.e.*, finetuning ESRGAN). See [here](Training.md#Finetune-Real-ESRGAN-on-your-own-dataset)
- :white_check_mark: Integrate [GFPGAN](https://github.com/TencentARC/GFPGAN) to support **face enhancement**.
- :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN). Thanks [@AK391](https://github.com/AK391)
- :white_check_mark: Support arbitrary scale with `--outscale` (It actually further resizes outputs with `LANCZOS4`). Add *RealESRGAN_x2plus.pth* model.
@@ -135,8 +136,10 @@ Results are in the `results` folder
## :european_castle: Model Zoo
- [RealESRGAN_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
- [RealESRGAN_x4plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth)
- [RealESRNet_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth)
- [RealESRGAN_x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth)
- [RealESRGAN_x2plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x2plus_netD.pth)
- [official ESRGAN_x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth)
## :computer: Training and Finetuning on your own dataset

View File

@@ -5,6 +5,9 @@
- [Dataset Preparation](#dataset-preparation)
- [Train Real-ESRNet](#Train-Real-ESRNet)
- [Train Real-ESRGAN](#Train-Real-ESRGAN)
- [Finetune Real-ESRGAN on your own dataset](#Finetune-Real-ESRGAN-on-your-own-dataset)
- [Generate degraded images on the fly](#Generate-degraded-images-on-the-fly)
- [Use paired training data](#Use-paired-training-data)
## Train Real-ESRGAN
@@ -131,3 +134,106 @@ You can merge several folders into one meta_info txt. Here is the example:
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
```
## Finetune Real-ESRGAN on your own dataset
You can finetune Real-ESRGAN on your own dataset. Typically, the fine-tuning process can be divided into two cases:
1. [generate degraded images on the fly](#Generate-degraded-images-on-the-fly)
1. [use your own **paired** data(#Use-paired-training-data)
### Generate degraded images on the fly
Only high-resolution images are required. The low-quality images are generated with the degradation process in Real-ESRGAN during trainig.
**Prepare dataset**
See [this section](#dataset-preparation) for more details.
**Download pre-trained models**
Download pre-trained models into `experiments/pretrained_models`.
*RealESRGAN_x4plus.pth*
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
```
*RealESRGAN_x4plus_netD.pth*
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
```
**Finetune**
Modify [options/finetune_realesrgan_x4plus.yml](options/finetune_realesrgan_x4plus.yml) accordingly, especially the `datasets` part:
```yml
train:
name: DF2K+OST
type: RealESRGANDataset
dataroot_gt: datasets/DF2K # modify to the root path of your folder
meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
io_backend:
type: disk
```
We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus.yml --launcher pytorch --auto_resume
```
### Use paired training data
You can also finetune RealESRGAN with your own paired data. It is more similar to fine-tuning ESRGAN.
**Prepare dataset**
Assume that you already have two folders:
- gt folder (Ground-truth, high-resolution images): datasets/DF2K/DIV2K_train_HR_sub
- lq folder (Low quality, low-resolution images): datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub
Then, you can prepare the meta_info txt file using the script [scripts/generate_meta_info_pairdata.py](scripts/generate_meta_info_pairdata.py):
```bash
python scripts/generate_meta_info_pairdata.py --input datasets/DF2K/DIV2K_train_HR_sub datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub --meta_info datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
```
**Download pre-trained models**
Download pre-trained models into `experiments/pretrained_models`.
*RealESRGAN_x4plus.pth*
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
```
*RealESRGAN_x4plus_netD.pth*
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
```
**Finetune**
Modify [options/finetune_realesrgan_x4plus_pairdata.yml](options/finetune_realesrgan_x4plus_pairdata.yml) accordingly, especially the `datasets` part:
```yml
train:
name: DIV2K
type: RealESRGANPairedDataset
dataroot_gt: datasets/DF2K # modify to the root path of your folder
dataroot_lq: datasets/DF2K # modify to the root path of your folder
meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt # modify to the root path of your folder
io_backend:
type: disk
```
We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus_pairdata.yml --launcher pytorch --auto_resume
```

View File

@@ -1,8 +1,8 @@
# general settings
name: finetune_RealESRGANx4plus_400k_B12G4
name: finetune_RealESRGANx4plus_400k
model_type: RealESRGANModel
scale: 4
num_gpu: 4
num_gpu: auto
manual_seed: 0
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #

View File

@@ -0,0 +1,151 @@
# general settings
name: finetune_RealESRGANx4plus_400k_pairdata
model_type: RealESRGANModel
scale: 4
num_gpu: auto
manual_seed: 0
# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False
high_order_degradation: False # do not use the high-order degradation generation process
# dataset and data loader settings
datasets:
train:
name: DIV2K
type: RealESRGANPairedDataset
dataroot_gt: datasets/DF2K
dataroot_lq: datasets/DF2K
meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
io_backend:
type: disk
gt_size: 256
use_hflip: True
use_rot: False
# data loader
use_shuffle: true
num_worker_per_gpu: 5
batch_size_per_gpu: 12
dataset_enlarge_ratio: 1
prefetch_mode: ~
# Uncomment these for validation
# val:
# name: validation
# type: PairedImageDataset
# dataroot_gt: path_to_gt
# dataroot_lq: path_to_lq
# io_backend:
# type: disk
# network structures
network_g:
type: RRDBNet
num_in_ch: 3
num_out_ch: 3
num_feat: 64
num_block: 23
num_grow_ch: 32
network_d:
type: UNetDiscriminatorSN
num_in_ch: 3
num_feat: 64
skip_connection: True
# path
path:
# use the pre-trained Real-ESRNet model
pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
param_key_g: params_ema
strict_load_g: true
pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
param_key_d: params
strict_load_d: true
resume_state: ~
# training settings
train:
ema_decay: 0.999
optim_g:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [400000]
gamma: 0.5
total_iter: 400000
warmup_iter: -1 # no warm up
# losses
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
# perceptual loss (content and style losses)
perceptual_opt:
type: PerceptualLoss
layer_weights:
# before relu
'conv1_2': 0.1
'conv2_2': 0.1
'conv3_4': 1
'conv4_4': 1
'conv5_4': 1
vgg_type: vgg19
use_input_norm: true
perceptual_weight: !!float 1.0
style_weight: 0
range_norm: false
criterion: l1
# gan loss
gan_opt:
type: GANLoss
gan_type: vanilla
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1e-1
net_d_iters: 1
net_d_init_iters: 0
# Uncomment these for validation
# validation settings
# val:
# val_freq: !!float 5e3
# save_img: True
# metrics:
# psnr: # metric name, can be arbitrary
# type: calculate_psnr
# crop_border: 4
# test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@@ -0,0 +1,106 @@
import os
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torch.utils import data as data
from torchvision.transforms.functional import normalize
@DATASET_REGISTRY.register()
class RealESRGANPairedDataset(data.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
GT image pairs.
There are three modes:
1. 'lmdb': Use lmdb files.
If opt['io_backend'] == lmdb.
2. 'meta_info': Use meta information file to generate paths.
If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
3. 'folder': Scan folders to generate paths.
The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Default: '{}'.
gt_size (int): Cropped patched size for gt patches.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h
and w for implementation).
scale (bool): Scale, which will be added automatically.
phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(RealESRGANPairedDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
if 'filename_tmpl' in opt:
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
with open(self.opt['meta_info']) as fin:
paths = [line.strip() for line in fin]
self.paths = []
for path in paths:
gt_path, lq_path = path.split(', ')
gt_path = os.path.join(self.gt_folder, gt_path)
lq_path = os.path.join(self.lq_folder, lq_path)
self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
else:
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)

View File

@@ -19,7 +19,7 @@ class RealESRGANModel(SRGANModel):
super(RealESRGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()
self.usm_sharpener = USMSharp().cuda()
self.queue_size = opt['queue_size']
self.queue_size = opt.get('queue_size', 180)
@torch.no_grad()
def _dequeue_and_enqueue(self):
@@ -55,7 +55,7 @@ class RealESRGANModel(SRGANModel):
@torch.no_grad()
def feed_data(self, data):
if self.is_train:
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
@@ -166,6 +166,7 @@ class RealESRGANModel(SRGANModel):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation

View File

@@ -18,7 +18,7 @@ class RealESRNetModel(SRModel):
super(RealESRNetModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()
self.usm_sharpener = USMSharp().cuda()
self.queue_size = opt['queue_size']
self.queue_size = opt.get('queue_size', 180)
@torch.no_grad()
def _dequeue_and_enqueue(self):
@@ -54,7 +54,7 @@ class RealESRNetModel(SRModel):
@torch.no_grad()
def feed_data(self, data):
if self.is_train:
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
# USM the GT images
@@ -164,6 +164,7 @@ class RealESRNetModel(SRModel):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation

View File

@@ -5,4 +5,5 @@ numpy
opencv-python
Pillow
torch>=1.7
torchvision
tqdm

View File

@@ -35,6 +35,9 @@ if __name__ == '__main__':
default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt',
help='txt path for meta info')
args = parser.parse_args()
assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got '
f'{len(args.input)} and {len(args.root)}.')
os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
main(args)

View File

@@ -0,0 +1,47 @@
import argparse
import glob
import os
def main(args):
txt_file = open(args.meta_info, 'w')
img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got '
f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
img_name_gt = os.path.relpath(img_path_gt, args.root[0])
img_name_lq = os.path.relpath(img_path_lq, args.root[1])
print(f'{img_name_gt}, {img_name_lq}')
txt_file.write(f'{img_name_gt}, {img_name_lq}\n')
if __name__ == '__main__':
"""Generate meta info (txt file) for paired images.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
'--input',
nargs='+',
default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'],
help='Input folder, should be [gt_folder, lq_folder]')
parser.add_argument('--root', nargs='+', default=[None, None], help='Folder root, will use the ')
parser.add_argument(
'--meta_info',
type=str,
default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt',
help='txt path for meta info')
args = parser.parse_args()
assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder'
assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder'
os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
for i in range(2):
if args.input[i].endswith('/'):
args.input[i] = args.input[i][:-1]
if args.root[i] is None:
args.root[i] = os.path.dirname(args.input[i])
main(args)

View File

@@ -17,6 +17,6 @@ line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = realesrgan
known_third_party = PIL,basicsr,cv2,numpy,torch,tqdm
known_third_party = PIL,basicsr,cv2,numpy,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY