improve codes comments
This commit is contained in:
9
.github/workflows/no-response.yml
vendored
9
.github/workflows/no-response.yml
vendored
@@ -1,12 +1,11 @@
|
||||
name: No Response
|
||||
|
||||
# TODO: it seems not to work
|
||||
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
|
||||
|
||||
# **What it does**: Closes issues that don't have enough information to be
|
||||
# actionable.
|
||||
# **Why we have it**: To remove the need for maintainers to remember to check
|
||||
# back on issues periodically to see if contributors have
|
||||
# responded.
|
||||
# **What it does**: Closes issues that don't have enough information to be actionable.
|
||||
# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
|
||||
# to see if contributors have responded.
|
||||
# **Who does it impact**: Everyone that works on docs or docs-internal.
|
||||
|
||||
on:
|
||||
|
||||
@@ -8,6 +8,8 @@ from realesrgan import RealESRGANer
|
||||
|
||||
|
||||
def main():
|
||||
"""Inference demo for Real-ESRGAN.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
|
||||
parser.add_argument(
|
||||
@@ -53,7 +55,7 @@ def main():
|
||||
pre_pad=args.pre_pad,
|
||||
half=args.half)
|
||||
|
||||
if args.face_enhance:
|
||||
if args.face_enhance: # Use GFPGAN for face enhancement
|
||||
from gfpgan import GFPGANer
|
||||
face_enhancer = GFPGANer(
|
||||
model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
|
||||
@@ -78,6 +80,7 @@ def main():
|
||||
else:
|
||||
img_mode = None
|
||||
|
||||
# give warnings for too large/small images
|
||||
h, w = img.shape[0:2]
|
||||
if max(h, w) > 1000 and args.netscale == 4:
|
||||
import warnings
|
||||
|
||||
@@ -90,7 +90,6 @@ network_g:
|
||||
num_block: 23
|
||||
num_grow_ch: 32
|
||||
|
||||
|
||||
network_d:
|
||||
type: UNetDiscriminatorSN
|
||||
num_in_ch: 3
|
||||
@@ -169,7 +168,7 @@ train:
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# psnr: # metric name
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
@@ -52,7 +52,6 @@ network_g:
|
||||
num_block: 23
|
||||
num_grow_ch: 32
|
||||
|
||||
|
||||
network_d:
|
||||
type: UNetDiscriminatorSN
|
||||
num_in_ch: 3
|
||||
@@ -131,7 +130,7 @@ train:
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# psnr: # metric name
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
@@ -91,7 +91,6 @@ network_g:
|
||||
num_grow_ch: 32
|
||||
scale: 2
|
||||
|
||||
|
||||
network_d:
|
||||
type: UNetDiscriminatorSN
|
||||
num_in_ch: 3
|
||||
@@ -167,7 +166,7 @@ train:
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# psnr: # metric name
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
@@ -90,7 +90,6 @@ network_g:
|
||||
num_block: 23
|
||||
num_grow_ch: 32
|
||||
|
||||
|
||||
network_d:
|
||||
type: UNetDiscriminatorSN
|
||||
num_in_ch: 3
|
||||
@@ -166,7 +165,7 @@ train:
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# psnr: # metric name
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
@@ -125,7 +125,7 @@ train:
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# psnr: # metric name
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
@@ -124,7 +124,7 @@ train:
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# psnr: # metric name
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
@@ -3,4 +3,4 @@ from .archs import *
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .utils import *
|
||||
from .version import __gitsha__, __version__
|
||||
from .version import __version__
|
||||
|
||||
@@ -6,15 +6,23 @@ from torch.nn.utils import spectral_norm
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class UNetDiscriminatorSN(nn.Module):
|
||||
"""Defines a U-Net discriminator with spectral normalization (SN)"""
|
||||
"""Defines a U-Net discriminator with spectral normalization (SN)
|
||||
|
||||
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
||||
|
||||
Arg:
|
||||
num_in_ch (int): Channel number of inputs. Default: 3.
|
||||
num_feat (int): Channel number of base intermediate features. Default: 64.
|
||||
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
|
||||
super(UNetDiscriminatorSN, self).__init__()
|
||||
self.skip_connection = skip_connection
|
||||
norm = spectral_norm
|
||||
|
||||
# the first convolution
|
||||
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# downsample
|
||||
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
|
||||
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
|
||||
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
|
||||
@@ -22,14 +30,13 @@ class UNetDiscriminatorSN(nn.Module):
|
||||
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
|
||||
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
|
||||
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
|
||||
|
||||
# extra
|
||||
# extra convolutions
|
||||
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
||||
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
||||
|
||||
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
# downsample
|
||||
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
|
||||
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
|
||||
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
|
||||
@@ -52,7 +59,7 @@ class UNetDiscriminatorSN(nn.Module):
|
||||
if self.skip_connection:
|
||||
x6 = x6 + x0
|
||||
|
||||
# extra
|
||||
# extra convolutions
|
||||
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
|
||||
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
|
||||
out = self.conv9(out)
|
||||
|
||||
@@ -15,18 +15,31 @@ from torch.utils import data as data
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class RealESRGANDataset(data.Dataset):
|
||||
"""
|
||||
Dataset used for Real-ESRGAN model.
|
||||
"""Dataset used for Real-ESRGAN model:
|
||||
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
||||
|
||||
It loads gt (Ground-Truth) images, and augments them.
|
||||
It also generates blur kernels and sinc kernels for generating low-quality images.
|
||||
Note that the low-quality images are processed in tensors on GPUS for faster processing.
|
||||
|
||||
Args:
|
||||
opt (dict): Config for train datasets. It contains the following keys:
|
||||
dataroot_gt (str): Data root path for gt.
|
||||
meta_info (str): Path for meta information file.
|
||||
io_backend (dict): IO backend type and other kwarg.
|
||||
use_hflip (bool): Use horizontal flips.
|
||||
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
|
||||
Please see more options in the codes.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANDataset, self).__init__()
|
||||
self.opt = opt
|
||||
# file client (io backend)
|
||||
self.file_client = None
|
||||
self.io_backend_opt = opt['io_backend']
|
||||
self.gt_folder = opt['dataroot_gt']
|
||||
|
||||
# file client (lmdb io backend)
|
||||
if self.io_backend_opt['type'] == 'lmdb':
|
||||
self.io_backend_opt['db_paths'] = [self.gt_folder]
|
||||
self.io_backend_opt['client_keys'] = ['gt']
|
||||
@@ -35,18 +48,20 @@ class RealESRGANDataset(data.Dataset):
|
||||
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
||||
self.paths = [line.split('.')[0] for line in fin]
|
||||
else:
|
||||
# disk backend with meta_info
|
||||
# Each line in the meta_info describes the relative path to an image
|
||||
with open(self.opt['meta_info']) as fin:
|
||||
paths = [line.strip() for line in fin]
|
||||
paths = [line.strip().split(' ')[0] for line in fin]
|
||||
self.paths = [os.path.join(self.gt_folder, v) for v in paths]
|
||||
|
||||
# blur settings for the first degradation
|
||||
self.blur_kernel_size = opt['blur_kernel_size']
|
||||
self.kernel_list = opt['kernel_list']
|
||||
self.kernel_prob = opt['kernel_prob']
|
||||
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
|
||||
self.blur_sigma = opt['blur_sigma']
|
||||
self.betag_range = opt['betag_range']
|
||||
self.betap_range = opt['betap_range']
|
||||
self.sinc_prob = opt['sinc_prob']
|
||||
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
|
||||
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
|
||||
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
|
||||
|
||||
# blur settings for the second degradation
|
||||
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
||||
@@ -61,6 +76,7 @@ class RealESRGANDataset(data.Dataset):
|
||||
self.final_sinc_prob = opt['final_sinc_prob']
|
||||
|
||||
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
||||
# TODO: kernel range is now hard-coded, should be in the configure file
|
||||
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
||||
self.pulse_tensor[10, 10] = 1
|
||||
|
||||
@@ -89,10 +105,11 @@ class RealESRGANDataset(data.Dataset):
|
||||
retry -= 1
|
||||
img_gt = imfrombytes(img_bytes, float32=True)
|
||||
|
||||
# -------------------- augmentation for training: flip, rotation -------------------- #
|
||||
# -------------------- Do augmentation for training: flip, rotation -------------------- #
|
||||
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
|
||||
|
||||
# crop or pad to 400: 400 is hard-coded. You may change it accordingly
|
||||
# crop or pad to 400
|
||||
# TODO: 400 is hard-coded. You may change it accordingly
|
||||
h, w = img_gt.shape[0:2]
|
||||
crop_pad_size = 400
|
||||
# pad
|
||||
@@ -154,7 +171,7 @@ class RealESRGANDataset(data.Dataset):
|
||||
pad_size = (21 - kernel_size) // 2
|
||||
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
||||
|
||||
# ------------------------------------- sinc kernel ------------------------------------- #
|
||||
# ------------------------------------- the final sinc kernel ------------------------------------- #
|
||||
if np.random.uniform() < self.opt['final_sinc_prob']:
|
||||
kernel_size = random.choice(self.kernel_range)
|
||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
||||
|
||||
@@ -11,8 +11,7 @@ from torchvision.transforms.functional import normalize
|
||||
class RealESRGANPairedDataset(data.Dataset):
|
||||
"""Paired image dataset for image restoration.
|
||||
|
||||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
|
||||
GT image pairs.
|
||||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
|
||||
|
||||
There are three modes:
|
||||
1. 'lmdb': Use lmdb files.
|
||||
@@ -28,8 +27,8 @@ class RealESRGANPairedDataset(data.Dataset):
|
||||
dataroot_lq (str): Data root path for lq.
|
||||
meta_info (str): Path for meta information file.
|
||||
io_backend (dict): IO backend type and other kwarg.
|
||||
filename_tmpl (str): Template for each filename. Note that the
|
||||
template excludes the file extension. Default: '{}'.
|
||||
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
|
||||
Default: '{}'.
|
||||
gt_size (int): Cropped patched size for gt patches.
|
||||
use_hflip (bool): Use horizontal flips.
|
||||
use_rot (bool): Use rotation (use vertical flip and transposing h
|
||||
@@ -42,25 +41,25 @@ class RealESRGANPairedDataset(data.Dataset):
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANPairedDataset, self).__init__()
|
||||
self.opt = opt
|
||||
# file client (io backend)
|
||||
self.file_client = None
|
||||
self.io_backend_opt = opt['io_backend']
|
||||
# mean and std for normalizing the input images
|
||||
self.mean = opt['mean'] if 'mean' in opt else None
|
||||
self.std = opt['std'] if 'std' in opt else None
|
||||
|
||||
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
|
||||
if 'filename_tmpl' in opt:
|
||||
self.filename_tmpl = opt['filename_tmpl']
|
||||
else:
|
||||
self.filename_tmpl = '{}'
|
||||
self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
|
||||
|
||||
# file client (lmdb io backend)
|
||||
if self.io_backend_opt['type'] == 'lmdb':
|
||||
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
|
||||
self.io_backend_opt['client_keys'] = ['lq', 'gt']
|
||||
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
|
||||
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
|
||||
# disk backend with meta_info
|
||||
# Each line in the meta_info describes the relative path to an image
|
||||
with open(self.opt['meta_info']) as fin:
|
||||
paths = [line.strip() for line in fin]
|
||||
paths = [line.strip().split(' ')[0] for line in fin]
|
||||
self.paths = []
|
||||
for path in paths:
|
||||
gt_path, lq_path = path.split(', ')
|
||||
@@ -68,6 +67,9 @@ class RealESRGANPairedDataset(data.Dataset):
|
||||
lq_path = os.path.join(self.lq_folder, lq_path)
|
||||
self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
|
||||
else:
|
||||
# disk backend
|
||||
# it will scan the whole folder to get meta info
|
||||
# it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
|
||||
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
@@ -13,35 +13,45 @@ from torch.nn import functional as F
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class RealESRGANModel(SRGANModel):
|
||||
"""RealESRGAN Model"""
|
||||
"""RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
||||
|
||||
It mainly performs:
|
||||
1. randomly synthesize LQ images in GPU tensors
|
||||
2. optimize the networks with GAN training.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANModel, self).__init__(opt)
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||
self.usm_sharpener = USMSharp().cuda()
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
||||
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
||||
self.queue_size = opt.get('queue_size', 180)
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self):
|
||||
# training pair pool
|
||||
"""It is the training pair pool for increasing the diversity in a batch.
|
||||
|
||||
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
||||
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
||||
to increase the degradation diversity in a batch.
|
||||
"""
|
||||
# initialize
|
||||
b, c, h, w = self.lq.size()
|
||||
if not hasattr(self, 'queue_lr'):
|
||||
assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
||||
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
_, c, h, w = self.gt.size()
|
||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
self.queue_ptr = 0
|
||||
if self.queue_ptr == self.queue_size: # full
|
||||
if self.queue_ptr == self.queue_size: # the pool is full
|
||||
# do dequeue and enqueue
|
||||
# shuffle
|
||||
idx = torch.randperm(self.queue_size)
|
||||
self.queue_lr = self.queue_lr[idx]
|
||||
self.queue_gt = self.queue_gt[idx]
|
||||
# get
|
||||
# get first b samples
|
||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
||||
# update
|
||||
# update the queue
|
||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
||||
|
||||
@@ -55,6 +65,8 @@ class RealESRGANModel(SRGANModel):
|
||||
|
||||
@torch.no_grad()
|
||||
def feed_data(self, data):
|
||||
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
|
||||
"""
|
||||
if self.is_train and self.opt.get('high_order_degradation', True):
|
||||
# training data synthesis
|
||||
self.gt = data['gt'].to(self.device)
|
||||
@@ -79,7 +91,7 @@ class RealESRGANModel(SRGANModel):
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
||||
# noise
|
||||
# add noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
@@ -93,7 +105,7 @@ class RealESRGANModel(SRGANModel):
|
||||
rounds=False)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
|
||||
# ----------------------- The second degradation process ----------------------- #
|
||||
@@ -111,7 +123,7 @@ class RealESRGANModel(SRGANModel):
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(
|
||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
||||
# noise
|
||||
# add noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
@@ -162,7 +174,9 @@ class RealESRGANModel(SRGANModel):
|
||||
self._dequeue_and_enqueue()
|
||||
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
||||
self.gt_usm = self.usm_sharpener(self.gt)
|
||||
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
||||
else:
|
||||
# for paired training or validation
|
||||
self.lq = data['lq'].to(self.device)
|
||||
if 'gt' in data:
|
||||
self.gt = data['gt'].to(self.device)
|
||||
@@ -175,6 +189,7 @@ class RealESRGANModel(SRGANModel):
|
||||
self.is_train = True
|
||||
|
||||
def optimize_parameters(self, current_iter):
|
||||
# usm sharpening
|
||||
l1_gt = self.gt_usm
|
||||
percep_gt = self.gt_usm
|
||||
gan_gt = self.gt_usm
|
||||
|
||||
@@ -12,35 +12,46 @@ from torch.nn import functional as F
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class RealESRNetModel(SRModel):
|
||||
"""RealESRNet Model"""
|
||||
"""RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
||||
|
||||
It is trained without GAN losses.
|
||||
It mainly performs:
|
||||
1. randomly synthesize LQ images in GPU tensors
|
||||
2. optimize the networks with GAN training.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRNetModel, self).__init__(opt)
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||
self.usm_sharpener = USMSharp().cuda()
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
||||
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
||||
self.queue_size = opt.get('queue_size', 180)
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self):
|
||||
# training pair pool
|
||||
"""It is the training pair pool for increasing the diversity in a batch.
|
||||
|
||||
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
||||
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
||||
to increase the degradation diversity in a batch.
|
||||
"""
|
||||
# initialize
|
||||
b, c, h, w = self.lq.size()
|
||||
if not hasattr(self, 'queue_lr'):
|
||||
assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
||||
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
_, c, h, w = self.gt.size()
|
||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
self.queue_ptr = 0
|
||||
if self.queue_ptr == self.queue_size: # full
|
||||
if self.queue_ptr == self.queue_size: # the pool is full
|
||||
# do dequeue and enqueue
|
||||
# shuffle
|
||||
idx = torch.randperm(self.queue_size)
|
||||
self.queue_lr = self.queue_lr[idx]
|
||||
self.queue_gt = self.queue_gt[idx]
|
||||
# get
|
||||
# get first b samples
|
||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
||||
# update
|
||||
# update the queue
|
||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
||||
|
||||
@@ -54,10 +65,12 @@ class RealESRNetModel(SRModel):
|
||||
|
||||
@torch.no_grad()
|
||||
def feed_data(self, data):
|
||||
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
|
||||
"""
|
||||
if self.is_train and self.opt.get('high_order_degradation', True):
|
||||
# training data synthesis
|
||||
self.gt = data['gt'].to(self.device)
|
||||
# USM the GT images
|
||||
# USM sharpen the GT images
|
||||
if self.opt['gt_usm'] is True:
|
||||
self.gt = self.usm_sharpener(self.gt)
|
||||
|
||||
@@ -80,7 +93,7 @@ class RealESRNetModel(SRModel):
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
||||
# noise
|
||||
# add noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
@@ -94,7 +107,7 @@ class RealESRNetModel(SRModel):
|
||||
rounds=False)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
|
||||
# ----------------------- The second degradation process ----------------------- #
|
||||
@@ -112,7 +125,7 @@ class RealESRNetModel(SRModel):
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(
|
||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
||||
# noise
|
||||
# add noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
@@ -160,7 +173,9 @@ class RealESRNetModel(SRModel):
|
||||
|
||||
# training pair pool
|
||||
self._dequeue_and_enqueue()
|
||||
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
||||
else:
|
||||
# for paired training or validation
|
||||
self.lq = data['lq'].to(self.device)
|
||||
if 'gt' in data:
|
||||
self.gt = data['gt'].to(self.device)
|
||||
|
||||
@@ -12,6 +12,19 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class RealESRGANer():
|
||||
"""A helper class for upsampling images with RealESRGAN.
|
||||
|
||||
Args:
|
||||
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
||||
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
||||
model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
|
||||
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
|
||||
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
||||
0 denotes for do not use tile. Default: 0.
|
||||
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
||||
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
||||
half (float): Whether to use half precision during inference. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
|
||||
self.scale = scale
|
||||
@@ -26,10 +39,12 @@ class RealESRGANer():
|
||||
if model is None:
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
||||
|
||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
# prefer to use params_ema
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
else:
|
||||
@@ -41,6 +56,8 @@ class RealESRGANer():
|
||||
self.model = self.model.half()
|
||||
|
||||
def pre_process(self, img):
|
||||
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
|
||||
"""
|
||||
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||
self.img = img.unsqueeze(0).to(self.device)
|
||||
if self.half:
|
||||
@@ -49,7 +66,7 @@ class RealESRGANer():
|
||||
# pre_pad
|
||||
if self.pre_pad != 0:
|
||||
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
||||
# mod pad
|
||||
# mod pad for divisible borders
|
||||
if self.scale == 2:
|
||||
self.mod_scale = 2
|
||||
elif self.scale == 1:
|
||||
@@ -64,10 +81,14 @@ class RealESRGANer():
|
||||
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
||||
|
||||
def process(self):
|
||||
# model inference
|
||||
self.output = self.model(self.img)
|
||||
|
||||
def tile_process(self):
|
||||
"""Modified from: https://github.com/ata4/esrgan-launcher
|
||||
"""It will first crop input images to tiles, and then process each tile.
|
||||
Finally, all the processed tiles are merged into one images.
|
||||
|
||||
Modified from: https://github.com/ata4/esrgan-launcher
|
||||
"""
|
||||
batch, channel, height, width = self.img.shape
|
||||
output_height = height * self.scale
|
||||
@@ -188,7 +209,7 @@ class RealESRGANer():
|
||||
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
||||
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
else: # use the cv2 resize for alpha channel
|
||||
h, w = alpha.shape[0:2]
|
||||
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
@@ -213,7 +234,9 @@ class RealESRGANer():
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
||||
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""Load file form http url, will download models if necessary.
|
||||
|
||||
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""
|
||||
if model_dir is None:
|
||||
hub_dir = get_dir()
|
||||
|
||||
@@ -14,34 +14,24 @@ def main(args):
|
||||
|
||||
opt (dict): Configuration dict. It contains:
|
||||
n_thread (int): Thread number.
|
||||
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
|
||||
A higher value means a smaller size and longer compression time.
|
||||
Use 0 for faster CPU decompression. Default: 3, same in cv2.
|
||||
|
||||
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
|
||||
and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
|
||||
input_folder (str): Path to the input folder.
|
||||
save_folder (str): Path to save folder.
|
||||
crop_size (int): Crop size.
|
||||
step (int): Step for overlapped sliding window.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower
|
||||
than thresh_size will be dropped.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
|
||||
|
||||
Usage:
|
||||
For each folder, run this script.
|
||||
Typically, there are four folders to be processed for DIV2K dataset.
|
||||
DIV2K_train_HR
|
||||
DIV2K_train_LR_bicubic/X2
|
||||
DIV2K_train_LR_bicubic/X3
|
||||
DIV2K_train_LR_bicubic/X4
|
||||
After process, each sub_folder should have the same number of
|
||||
subimages.
|
||||
Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
|
||||
After process, each sub_folder should have the same number of subimages.
|
||||
Remember to modify opt configurations according to your settings.
|
||||
"""
|
||||
|
||||
opt = {}
|
||||
opt['n_thread'] = args.n_thread
|
||||
opt['compression_level'] = args.compression_level
|
||||
|
||||
# HR images
|
||||
opt['input_folder'] = args.input
|
||||
opt['save_folder'] = args.output
|
||||
opt['crop_size'] = args.crop_size
|
||||
@@ -68,6 +58,7 @@ def extract_subimages(opt):
|
||||
print(f'Folder {save_folder} already exists. Exit.')
|
||||
sys.exit(1)
|
||||
|
||||
# scan all images
|
||||
img_list = list(scandir(input_folder, full_path=True))
|
||||
|
||||
pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
|
||||
@@ -88,8 +79,7 @@ def worker(path, opt):
|
||||
opt (dict): Configuration dict. It contains:
|
||||
crop_size (int): Crop size.
|
||||
step (int): Step for overlapped sliding window.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower
|
||||
than thresh_size will be dropped.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
|
||||
save_folder (str): Path to save folder.
|
||||
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ def main(args):
|
||||
for img_path in img_paths:
|
||||
status = True
|
||||
if args.check:
|
||||
# read the image once for check, as some images may have errors
|
||||
try:
|
||||
img = cv2.imread(img_path)
|
||||
except Exception as error:
|
||||
@@ -20,6 +21,7 @@ def main(args):
|
||||
status = False
|
||||
print(f'Img is None: {img_path}')
|
||||
if status:
|
||||
# get the relative path
|
||||
img_name = os.path.relpath(img_path, root)
|
||||
print(img_name)
|
||||
txt_file.write(f'{img_name}\n')
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
|
||||
def main(args):
|
||||
txt_file = open(args.meta_info, 'w')
|
||||
# sca images
|
||||
img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
|
||||
img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
|
||||
|
||||
@@ -12,6 +13,7 @@ def main(args):
|
||||
f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
|
||||
|
||||
for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
|
||||
# get the relative paths
|
||||
img_name_gt = os.path.relpath(img_path_gt, args.root[0])
|
||||
img_name_lq = os.path.relpath(img_path_lq, args.root[1])
|
||||
print(f'{img_name_gt}, {img_name_lq}')
|
||||
@@ -19,7 +21,7 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""Generate meta info (txt file) for paired images.
|
||||
"""This script is used to generate meta info (txt file) for paired images.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
||||
@@ -5,7 +5,6 @@ from PIL import Image
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
# For DF2K, we consider the following three scales,
|
||||
# and the smallest image whose shortest edge is 400
|
||||
scale_list = [0.75, 0.5, 1 / 3]
|
||||
@@ -37,6 +36,9 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""Generate multi-scale versions for GT images with LANCZOS resampling.
|
||||
It is now used for DF2K dataset (DIV2K + Flickr 2K)
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
|
||||
parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')
|
||||
|
||||
@@ -1,17 +1,36 @@
|
||||
import argparse
|
||||
import torch
|
||||
import torch.onnx
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
|
||||
# An instance of your model
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
|
||||
# set the train mode to false since we will only run the forward pass.
|
||||
model.train(False)
|
||||
model.cpu().eval()
|
||||
|
||||
# An example input you would normally provide to your model's forward() method
|
||||
x = torch.rand(1, 3, 64, 64)
|
||||
def main(args):
|
||||
# An instance of the model
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
if args.params:
|
||||
keyname = 'params'
|
||||
else:
|
||||
keyname = 'params_ema'
|
||||
model.load_state_dict(torch.load(args.input)[keyname])
|
||||
# set the train mode to false since we will only run the forward pass.
|
||||
model.train(False)
|
||||
model.cpu().eval()
|
||||
|
||||
# Export the model
|
||||
with torch.no_grad():
|
||||
torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
|
||||
# An example input
|
||||
x = torch.rand(1, 3, 64, 64)
|
||||
# Export the model
|
||||
with torch.no_grad():
|
||||
torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
|
||||
print(torch_out.shape)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""Convert pytorch model to onnx models"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
|
||||
parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
|
||||
parser.add_argument('--params', action='store_false', help='Use params instead of params_ema')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user