improve codes comments

This commit is contained in:
Xintao
2021-11-23 00:52:00 +08:00
parent c9023b3d7a
commit 35ee6f781e
20 changed files with 194 additions and 102 deletions

View File

@@ -15,18 +15,31 @@ from torch.utils import data as data
@DATASET_REGISTRY.register()
class RealESRGANDataset(data.Dataset):
"""
Dataset used for Real-ESRGAN model.
"""Dataset used for Real-ESRGAN model:
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It loads gt (Ground-Truth) images, and augments them.
It also generates blur kernels and sinc kernels for generating low-quality images.
Note that the low-quality images are processed in tensors on GPUS for faster processing.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
Please see more options in the codes.
"""
def __init__(self, opt):
super(RealESRGANDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.gt_folder]
self.io_backend_opt['client_keys'] = ['gt']
@@ -35,18 +48,20 @@ class RealESRGANDataset(data.Dataset):
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
# disk backend with meta_info
# Each line in the meta_info describes the relative path to an image
with open(self.opt['meta_info']) as fin:
paths = [line.strip() for line in fin]
paths = [line.strip().split(' ')[0] for line in fin]
self.paths = [os.path.join(self.gt_folder, v) for v in paths]
# blur settings for the first degradation
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob']
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
self.blur_sigma = opt['blur_sigma']
self.betag_range = opt['betag_range']
self.betap_range = opt['betap_range']
self.sinc_prob = opt['sinc_prob']
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
# blur settings for the second degradation
self.blur_kernel_size2 = opt['blur_kernel_size2']
@@ -61,6 +76,7 @@ class RealESRGANDataset(data.Dataset):
self.final_sinc_prob = opt['final_sinc_prob']
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
# TODO: kernel range is now hard-coded, should be in the configure file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1
@@ -89,10 +105,11 @@ class RealESRGANDataset(data.Dataset):
retry -= 1
img_gt = imfrombytes(img_bytes, float32=True)
# -------------------- augmentation for training: flip, rotation -------------------- #
# -------------------- Do augmentation for training: flip, rotation -------------------- #
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
# crop or pad to 400: 400 is hard-coded. You may change it accordingly
# crop or pad to 400
# TODO: 400 is hard-coded. You may change it accordingly
h, w = img_gt.shape[0:2]
crop_pad_size = 400
# pad
@@ -154,7 +171,7 @@ class RealESRGANDataset(data.Dataset):
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- sinc kernel ------------------------------------- #
# ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.opt['final_sinc_prob']:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)

View File

@@ -11,8 +11,7 @@ from torchvision.transforms.functional import normalize
class RealESRGANPairedDataset(data.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
GT image pairs.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
There are three modes:
1. 'lmdb': Use lmdb files.
@@ -28,8 +27,8 @@ class RealESRGANPairedDataset(data.Dataset):
dataroot_lq (str): Data root path for lq.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Default: '{}'.
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
Default: '{}'.
gt_size (int): Cropped patched size for gt patches.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h
@@ -42,25 +41,25 @@ class RealESRGANPairedDataset(data.Dataset):
def __init__(self, opt):
super(RealESRGANPairedDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
# mean and std for normalizing the input images
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
if 'filename_tmpl' in opt:
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
# disk backend with meta_info
# Each line in the meta_info describes the relative path to an image
with open(self.opt['meta_info']) as fin:
paths = [line.strip() for line in fin]
paths = [line.strip().split(' ')[0] for line in fin]
self.paths = []
for path in paths:
gt_path, lq_path = path.split(', ')
@@ -68,6 +67,9 @@ class RealESRGANPairedDataset(data.Dataset):
lq_path = os.path.join(self.lq_folder, lq_path)
self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
else:
# disk backend
# it will scan the whole folder to get meta info
# it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index):