support finetune with paired data
@@ -16,6 +16,7 @@ We extend the powerful ESRGAN to a practical restoration application (namely, Re

:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md).

:triangular_flag_on_post: **Updates**

- :white_check_mark: Support finetuning on your own data or paired data (*i.e.*, finetuning ESRGAN). See [here](Training.md#Finetune-Real-ESRGAN-on-your-own-dataset).
- :white_check_mark: Integrate [GFPGAN](https://github.com/TencentARC/GFPGAN) to support **face enhancement**.
- :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN). Thanks [@AK391](https://github.com/AK391).
- :white_check_mark: Support arbitrary scale with `--outscale` (it actually further resizes outputs with `LANCZOS4`). Add the *RealESRGAN_x2plus.pth* model.
@@ -135,8 +136,10 @@ Results are in the `results` folder

## :european_castle: Model Zoo

- [RealESRGAN_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
- [RealESRGAN_x4plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth)
- [RealESRNet_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth)
- [RealESRGAN_x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth)
- [RealESRGAN_x2plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x2plus_netD.pth)
- [official ESRGAN_x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth)

## :computer: Training and Finetuning on your own dataset

106 Training.md
@@ -5,6 +5,9 @@

- [Dataset Preparation](#dataset-preparation)
- [Train Real-ESRNet](#Train-Real-ESRNet)
- [Train Real-ESRGAN](#Train-Real-ESRGAN)
- [Finetune Real-ESRGAN on your own dataset](#Finetune-Real-ESRGAN-on-your-own-dataset)
    - [Generate degraded images on the fly](#Generate-degraded-images-on-the-fly)
    - [Use paired training data](#Use-paired-training-data)

## Train Real-ESRGAN
@@ -131,3 +134,106 @@ You can merge several folders into one meta_info txt. Here is the example:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
```

## Finetune Real-ESRGAN on your own dataset

You can finetune Real-ESRGAN on your own dataset. Typically, the finetuning process can be divided into two cases:

1. [Generate degraded images on the fly](#Generate-degraded-images-on-the-fly)
1. [Use your own **paired** data](#Use-paired-training-data)

### Generate degraded images on the fly

Only high-resolution images are required. The low-quality images are generated with the degradation process in Real-ESRGAN during training.

**Prepare dataset**

See [this section](#dataset-preparation) for more details.

**Download pre-trained models**

Download the pre-trained models into `experiments/pretrained_models`.

*RealESRGAN_x4plus.pth*

```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
```

*RealESRGAN_x4plus_netD.pth*

```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
```

**Finetune**

Modify [options/finetune_realesrgan_x4plus.yml](options/finetune_realesrgan_x4plus.yml) accordingly, especially the `datasets` part:

```yml
train:
    name: DF2K+OST
    type: RealESRGANDataset
    dataroot_gt: datasets/DF2K  # modify to the root path of your folder
    meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt  # modify to your own generated meta_info txt
    io_backend:
        type: disk
```
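Before launching, you may want to confirm that the entries in your meta_info txt actually exist under `dataroot_gt`. A minimal sketch (`check_meta_info` is a hypothetical helper, not part of the repo):

```python
import os


def check_meta_info(meta_info_path, dataroot_gt):
    """Return the meta_info entries whose files are missing under dataroot_gt."""
    missing = []
    with open(meta_info_path) as fin:
        for line in fin:
            name = line.strip()
            if name and not os.path.isfile(os.path.join(dataroot_gt, name)):
                missing.append(name)
    return missing
```

An empty result means every listed sub-image is present; otherwise training would fail mid-epoch when the loader hits a missing file.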

We use four GPUs for training. We use the `--auto_resume` argument to automatically resume training if necessary.

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus.yml --launcher pytorch --auto_resume
```

### Use paired training data

You can also finetune Real-ESRGAN with your own paired data. This is more similar to finetuning ESRGAN.

**Prepare dataset**

Assume that you already have two folders:

- gt folder (ground-truth, high-resolution images): datasets/DF2K/DIV2K_train_HR_sub
- lq folder (low-quality, low-resolution images): datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub

Then, you can prepare the meta_info txt file using the script [scripts/generate_meta_info_pairdata.py](scripts/generate_meta_info_pairdata.py):

```bash
python scripts/generate_meta_info_pairdata.py --input datasets/DF2K/DIV2K_train_HR_sub datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub --meta_info datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
```
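Each line of the generated txt holds a gt name and its lq counterpart separated by `', '`. A minimal sketch of the same pairing logic (`pair_meta_lines` is a hypothetical helper, not part of the repo; it assumes both folders contain identically sorted filenames):

```python
import glob
import os


def pair_meta_lines(gt_folder, lq_folder, gt_root, lq_root):
    """Build 'gt_name, lq_name' lines by pairing sorted gt/lq file lists."""
    gt_paths = sorted(glob.glob(os.path.join(gt_folder, '*')))
    lq_paths = sorted(glob.glob(os.path.join(lq_folder, '*')))
    assert len(gt_paths) == len(lq_paths), 'gt and lq folders must have the same length'
    # names are recorded relative to each root, matching the script's relpath logic
    return [f'{os.path.relpath(g, gt_root)}, {os.path.relpath(l, lq_root)}'
            for g, l in zip(gt_paths, lq_paths)]
```

Pairing by sorted order means the gt and lq files must use matching names; otherwise images get silently mismatched.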

**Download pre-trained models**

Download the pre-trained models into `experiments/pretrained_models`.

*RealESRGAN_x4plus.pth*

```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
```

*RealESRGAN_x4plus_netD.pth*

```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
```

**Finetune**

Modify [options/finetune_realesrgan_x4plus_pairdata.yml](options/finetune_realesrgan_x4plus_pairdata.yml) accordingly, especially the `datasets` part:

```yml
train:
    name: DIV2K
    type: RealESRGANPairedDataset
    dataroot_gt: datasets/DF2K  # modify to the root path of your gt folder
    dataroot_lq: datasets/DF2K  # modify to the root path of your lq folder
    meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt  # modify to your own meta_info txt
    io_backend:
        type: disk
```

We use four GPUs for training. We use the `--auto_resume` argument to automatically resume training if necessary.

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus_pairdata.yml --launcher pytorch --auto_resume
```
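The effective batch size of this command is `nproc_per_node` times `batch_size_per_gpu` (12 in options/finetune_realesrgan_x4plus_pairdata.yml); a quick check of the arithmetic:

```python
num_gpu = 4              # CUDA_VISIBLE_DEVICES=0,1,2,3 and --nproc_per_node=4
batch_size_per_gpu = 12  # batch_size_per_gpu in the yml config
effective_batch = num_gpu * batch_size_per_gpu
print(effective_batch)   # 48
```

If you train on fewer GPUs, keeping the effective batch size comparable (by raising `batch_size_per_gpu`) is a common rule of thumb, memory permitting.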

@@ -1,8 +1,8 @@
# general settings
-name: finetune_RealESRGANx4plus_400k_B12G4
+name: finetune_RealESRGANx4plus_400k
model_type: RealESRGANModel
scale: 4
-num_gpu: 4
+num_gpu: auto
manual_seed: 0

# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #

151 options/finetune_realesrgan_x4plus_pairdata.yml (Normal file)
@@ -0,0 +1,151 @@

# general settings
name: finetune_RealESRGANx4plus_400k_pairdata
model_type: RealESRGANModel
scale: 4
num_gpu: auto
manual_seed: 0

# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False

high_order_degradation: False  # do not use the high-order degradation generation process

# dataset and data loader settings
datasets:
  train:
    name: DIV2K
    type: RealESRGANPairedDataset
    dataroot_gt: datasets/DF2K
    dataroot_lq: datasets/DF2K
    meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
    io_backend:
      type: disk

    gt_size: 256
    use_hflip: True
    use_rot: False

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 5
    batch_size_per_gpu: 12
    dataset_enlarge_ratio: 1
    prefetch_mode: ~

  # Uncomment these for validation
  # val:
  #   name: validation
  #   type: PairedImageDataset
  #   dataroot_gt: path_to_gt
  #   dataroot_lq: path_to_lq
  #   io_backend:
  #     type: disk

# network structures
network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 23
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3
  num_feat: 64
  skip_connection: True

# path
path:
  # use the pre-trained Real-ESRNet model
  pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
  param_key_g: params_ema
  strict_load_g: true
  pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
  param_key_d: params
  strict_load_d: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: MultiStepLR
    milestones: [400000]
    gamma: 0.5

  total_iter: 400000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1.0
    style_weight: 0
    range_norm: false
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1e-1

  net_d_iters: 1
  net_d_init_iters: 0

# Uncomment these for validation
# validation settings
# val:
#   val_freq: !!float 5e3
#   save_img: True

#   metrics:
#     psnr: # metric name, can be arbitrary
#       type: calculate_psnr
#       crop_border: 4
#       test_y_channel: false

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
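The scheduler above halves the learning rate at each milestone. A plain-Python sketch mirroring `MultiStepLR`-style decay (note that since `milestones: [400000]` equals `total_iter`, the rate effectively stays at 1e-4 for the whole run):

```python
import bisect


def lr_at(iteration, base_lr=1e-4, milestones=(400000,), gamma=0.5):
    """Learning rate after `iteration` steps under a MultiStepLR-style schedule."""
    # each milestone that has been reached multiplies the lr by gamma once
    n_decays = bisect.bisect_right(list(milestones), iteration)
    return base_lr * gamma ** n_decays
```

For a finetuning run you may prefer earlier milestones (e.g. `[200000, 300000]`) so that the decay actually takes effect before training ends.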

106 realesrgan/data/realesrgan_paired_dataset.py (Normal file)
@@ -0,0 +1,106 @@

import os
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torch.utils import data as data
from torchvision.transforms.functional import normalize


@DATASET_REGISTRY.register()
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
    GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
            # each line of the meta_info file is "gt_name, lq_name"
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip() for line in fin]
            self.paths = []
            for path in paths:
                gt_path, lq_path = path.split(', ')
                gt_path = os.path.join(self.gt_folder, gt_path)
                lq_path = os.path.join(self.lq_folder, lq_path)
                self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
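The `meta_info` branch of `__init__` simply splits each line on `', '` and joins the names with the folder roots. Isolated as a standalone sketch (`paths_from_meta_info` is a hypothetical helper for illustration):

```python
import os


def paths_from_meta_info(lines, gt_folder, lq_folder):
    """Build the gt/lq path dicts the dataset iterates over."""
    paths = []
    for line in lines:
        gt_name, lq_name = line.strip().split(', ')
        paths.append({'gt_path': os.path.join(gt_folder, gt_name),
                      'lq_path': os.path.join(lq_folder, lq_name)})
    return paths
```

Because the split is on the literal `', '`, filenames containing commas would break this mode; stick to plain names like `0001_s001.png`.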

@@ -19,7 +19,7 @@ class RealESRGANModel(SRGANModel):
        super(RealESRGANModel, self).__init__(opt)
        self.jpeger = DiffJPEG(differentiable=False).cuda()
        self.usm_sharpener = USMSharp().cuda()
-       self.queue_size = opt['queue_size']
+       self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):

@@ -55,7 +55,7 @@ class RealESRGANModel(SRGANModel):

    @torch.no_grad()
    def feed_data(self, data):
-       if self.is_train:
+       if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            self.gt_usm = self.usm_sharpener(self.gt)

@@ -166,6 +166,7 @@ class RealESRGANModel(SRGANModel):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)
+           self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
@@ -18,7 +18,7 @@ class RealESRNetModel(SRModel):
        super(RealESRNetModel, self).__init__(opt)
        self.jpeger = DiffJPEG(differentiable=False).cuda()
        self.usm_sharpener = USMSharp().cuda()
-       self.queue_size = opt['queue_size']
+       self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):

@@ -54,7 +54,7 @@ class RealESRNetModel(SRModel):

    @torch.no_grad()
    def feed_data(self, data):
-       if self.is_train:
+       if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            # USM the GT images

@@ -164,6 +164,7 @@ class RealESRNetModel(SRModel):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)
+           self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
@@ -5,4 +5,5 @@ numpy
opencv-python
Pillow
torch>=1.7
+torchvision
tqdm
@@ -35,6 +35,9 @@ if __name__ == '__main__':
        default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt',
        help='txt path for meta info')
    args = parser.parse_args()

+   assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got '
+                                              f'{len(args.input)} and {len(args.root)}.')
+   os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)

    main(args)
47 scripts/generate_meta_info_pairdata.py (Normal file)
@@ -0,0 +1,47 @@

import argparse
import glob
import os


def main(args):
    txt_file = open(args.meta_info, 'w')
    img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
    img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))

    assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got '
                                                    f'{len(img_paths_gt)} and {len(img_paths_lq)}.')

    for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
        img_name_gt = os.path.relpath(img_path_gt, args.root[0])
        img_name_lq = os.path.relpath(img_path_lq, args.root[1])
        print(f'{img_name_gt}, {img_name_lq}')
        txt_file.write(f'{img_name_gt}, {img_name_lq}\n')


if __name__ == '__main__':
    """Generate meta info (txt file) for paired images."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        nargs='+',
        default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'],
        help='Input folder, should be [gt_folder, lq_folder]')
    parser.add_argument(
        '--root',
        nargs='+',
        default=[None, None],
        help='Folder root, will use the parent folder of input if not provided')
    parser.add_argument(
        '--meta_info',
        type=str,
        default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt',
        help='txt path for meta info')
    args = parser.parse_args()

    assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder'
    assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder'
    os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
    for i in range(2):
        if args.input[i].endswith('/'):
            args.input[i] = args.input[i][:-1]
        if args.root[i] is None:
            args.root[i] = os.path.dirname(args.input[i])

    main(args)
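When `--root` is not given, the script falls back to the parent directory of each input folder, so the names written to the meta_info file keep the last folder component. For example (POSIX paths assumed):

```python
import os

# the default gt input folder from the script
inp = 'datasets/DF2K/DIV2K_train_HR_sub'
root = os.path.dirname(inp)  # root defaults to the parent folder: 'datasets/DF2K'
# a file inside the input folder, recorded relative to the root
rel = os.path.relpath(os.path.join(inp, '0001_s001.png'), root)
print(rel)  # 'DIV2K_train_HR_sub/0001_s001.png'
```

This matches the `dataroot_gt: datasets/DF2K` setting in the finetune yml: the dataset joins `dataroot_gt` with each meta_info name to recover the full path.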

@@ -17,6 +17,6 @@ line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = realesrgan
-known_third_party = PIL,basicsr,cv2,numpy,torch,tqdm
+known_third_party = PIL,basicsr,cv2,numpy,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY