improve codes comments
This commit is contained in:
@@ -3,4 +3,4 @@ from .archs import *
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .utils import *
|
||||
from .version import __gitsha__, __version__
|
||||
from .version import __version__
|
||||
|
||||
@@ -6,15 +6,23 @@ from torch.nn.utils import spectral_norm
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class UNetDiscriminatorSN(nn.Module):
|
||||
"""Defines a U-Net discriminator with spectral normalization (SN)"""
|
||||
"""Defines a U-Net discriminator with spectral normalization (SN)
|
||||
|
||||
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
||||
|
||||
Arg:
|
||||
num_in_ch (int): Channel number of inputs. Default: 3.
|
||||
num_feat (int): Channel number of base intermediate features. Default: 64.
|
||||
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
|
||||
super(UNetDiscriminatorSN, self).__init__()
|
||||
self.skip_connection = skip_connection
|
||||
norm = spectral_norm
|
||||
|
||||
# the first convolution
|
||||
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# downsample
|
||||
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
|
||||
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
|
||||
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
|
||||
@@ -22,14 +30,13 @@ class UNetDiscriminatorSN(nn.Module):
|
||||
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
|
||||
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
|
||||
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
|
||||
|
||||
# extra
|
||||
# extra convolutions
|
||||
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
||||
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
||||
|
||||
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
# downsample
|
||||
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
|
||||
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
|
||||
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
|
||||
@@ -52,7 +59,7 @@ class UNetDiscriminatorSN(nn.Module):
|
||||
if self.skip_connection:
|
||||
x6 = x6 + x0
|
||||
|
||||
# extra
|
||||
# extra convolutions
|
||||
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
|
||||
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
|
||||
out = self.conv9(out)
|
||||
|
||||
@@ -15,18 +15,31 @@ from torch.utils import data as data
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class RealESRGANDataset(data.Dataset):
|
||||
"""
|
||||
Dataset used for Real-ESRGAN model.
|
||||
"""Dataset used for Real-ESRGAN model:
|
||||
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
||||
|
||||
It loads gt (Ground-Truth) images, and augments them.
|
||||
It also generates blur kernels and sinc kernels for generating low-quality images.
|
||||
Note that the low-quality images are processed in tensors on GPUS for faster processing.
|
||||
|
||||
Args:
|
||||
opt (dict): Config for train datasets. It contains the following keys:
|
||||
dataroot_gt (str): Data root path for gt.
|
||||
meta_info (str): Path for meta information file.
|
||||
io_backend (dict): IO backend type and other kwarg.
|
||||
use_hflip (bool): Use horizontal flips.
|
||||
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
|
||||
Please see more options in the codes.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANDataset, self).__init__()
|
||||
self.opt = opt
|
||||
# file client (io backend)
|
||||
self.file_client = None
|
||||
self.io_backend_opt = opt['io_backend']
|
||||
self.gt_folder = opt['dataroot_gt']
|
||||
|
||||
# file client (lmdb io backend)
|
||||
if self.io_backend_opt['type'] == 'lmdb':
|
||||
self.io_backend_opt['db_paths'] = [self.gt_folder]
|
||||
self.io_backend_opt['client_keys'] = ['gt']
|
||||
@@ -35,18 +48,20 @@ class RealESRGANDataset(data.Dataset):
|
||||
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
||||
self.paths = [line.split('.')[0] for line in fin]
|
||||
else:
|
||||
# disk backend with meta_info
|
||||
# Each line in the meta_info describes the relative path to an image
|
||||
with open(self.opt['meta_info']) as fin:
|
||||
paths = [line.strip() for line in fin]
|
||||
paths = [line.strip().split(' ')[0] for line in fin]
|
||||
self.paths = [os.path.join(self.gt_folder, v) for v in paths]
|
||||
|
||||
# blur settings for the first degradation
|
||||
self.blur_kernel_size = opt['blur_kernel_size']
|
||||
self.kernel_list = opt['kernel_list']
|
||||
self.kernel_prob = opt['kernel_prob']
|
||||
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
|
||||
self.blur_sigma = opt['blur_sigma']
|
||||
self.betag_range = opt['betag_range']
|
||||
self.betap_range = opt['betap_range']
|
||||
self.sinc_prob = opt['sinc_prob']
|
||||
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
|
||||
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
|
||||
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
|
||||
|
||||
# blur settings for the second degradation
|
||||
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
||||
@@ -61,6 +76,7 @@ class RealESRGANDataset(data.Dataset):
|
||||
self.final_sinc_prob = opt['final_sinc_prob']
|
||||
|
||||
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
||||
# TODO: kernel range is now hard-coded, should be in the configure file
|
||||
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
||||
self.pulse_tensor[10, 10] = 1
|
||||
|
||||
@@ -89,10 +105,11 @@ class RealESRGANDataset(data.Dataset):
|
||||
retry -= 1
|
||||
img_gt = imfrombytes(img_bytes, float32=True)
|
||||
|
||||
# -------------------- augmentation for training: flip, rotation -------------------- #
|
||||
# -------------------- Do augmentation for training: flip, rotation -------------------- #
|
||||
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
|
||||
|
||||
# crop or pad to 400: 400 is hard-coded. You may change it accordingly
|
||||
# crop or pad to 400
|
||||
# TODO: 400 is hard-coded. You may change it accordingly
|
||||
h, w = img_gt.shape[0:2]
|
||||
crop_pad_size = 400
|
||||
# pad
|
||||
@@ -154,7 +171,7 @@ class RealESRGANDataset(data.Dataset):
|
||||
pad_size = (21 - kernel_size) // 2
|
||||
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
||||
|
||||
# ------------------------------------- sinc kernel ------------------------------------- #
|
||||
# ------------------------------------- the final sinc kernel ------------------------------------- #
|
||||
if np.random.uniform() < self.opt['final_sinc_prob']:
|
||||
kernel_size = random.choice(self.kernel_range)
|
||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
||||
|
||||
@@ -11,8 +11,7 @@ from torchvision.transforms.functional import normalize
|
||||
class RealESRGANPairedDataset(data.Dataset):
|
||||
"""Paired image dataset for image restoration.
|
||||
|
||||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
|
||||
GT image pairs.
|
||||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
|
||||
|
||||
There are three modes:
|
||||
1. 'lmdb': Use lmdb files.
|
||||
@@ -28,8 +27,8 @@ class RealESRGANPairedDataset(data.Dataset):
|
||||
dataroot_lq (str): Data root path for lq.
|
||||
meta_info (str): Path for meta information file.
|
||||
io_backend (dict): IO backend type and other kwarg.
|
||||
filename_tmpl (str): Template for each filename. Note that the
|
||||
template excludes the file extension. Default: '{}'.
|
||||
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
|
||||
Default: '{}'.
|
||||
gt_size (int): Cropped patched size for gt patches.
|
||||
use_hflip (bool): Use horizontal flips.
|
||||
use_rot (bool): Use rotation (use vertical flip and transposing h
|
||||
@@ -42,25 +41,25 @@ class RealESRGANPairedDataset(data.Dataset):
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANPairedDataset, self).__init__()
|
||||
self.opt = opt
|
||||
# file client (io backend)
|
||||
self.file_client = None
|
||||
self.io_backend_opt = opt['io_backend']
|
||||
# mean and std for normalizing the input images
|
||||
self.mean = opt['mean'] if 'mean' in opt else None
|
||||
self.std = opt['std'] if 'std' in opt else None
|
||||
|
||||
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
|
||||
if 'filename_tmpl' in opt:
|
||||
self.filename_tmpl = opt['filename_tmpl']
|
||||
else:
|
||||
self.filename_tmpl = '{}'
|
||||
self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
|
||||
|
||||
# file client (lmdb io backend)
|
||||
if self.io_backend_opt['type'] == 'lmdb':
|
||||
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
|
||||
self.io_backend_opt['client_keys'] = ['lq', 'gt']
|
||||
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
|
||||
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
|
||||
# disk backend with meta_info
|
||||
# Each line in the meta_info describes the relative path to an image
|
||||
with open(self.opt['meta_info']) as fin:
|
||||
paths = [line.strip() for line in fin]
|
||||
paths = [line.strip().split(' ')[0] for line in fin]
|
||||
self.paths = []
|
||||
for path in paths:
|
||||
gt_path, lq_path = path.split(', ')
|
||||
@@ -68,6 +67,9 @@ class RealESRGANPairedDataset(data.Dataset):
|
||||
lq_path = os.path.join(self.lq_folder, lq_path)
|
||||
self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
|
||||
else:
|
||||
# disk backend
|
||||
# it will scan the whole folder to get meta info
|
||||
# it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
|
||||
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
@@ -13,35 +13,45 @@ from torch.nn import functional as F
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class RealESRGANModel(SRGANModel):
|
||||
"""RealESRGAN Model"""
|
||||
"""RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
||||
|
||||
It mainly performs:
|
||||
1. randomly synthesize LQ images in GPU tensors
|
||||
2. optimize the networks with GAN training.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANModel, self).__init__(opt)
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||
self.usm_sharpener = USMSharp().cuda()
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
||||
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
||||
self.queue_size = opt.get('queue_size', 180)
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self):
|
||||
# training pair pool
|
||||
"""It is the training pair pool for increasing the diversity in a batch.
|
||||
|
||||
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
||||
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
||||
to increase the degradation diversity in a batch.
|
||||
"""
|
||||
# initialize
|
||||
b, c, h, w = self.lq.size()
|
||||
if not hasattr(self, 'queue_lr'):
|
||||
assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
||||
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
_, c, h, w = self.gt.size()
|
||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
self.queue_ptr = 0
|
||||
if self.queue_ptr == self.queue_size: # full
|
||||
if self.queue_ptr == self.queue_size: # the pool is full
|
||||
# do dequeue and enqueue
|
||||
# shuffle
|
||||
idx = torch.randperm(self.queue_size)
|
||||
self.queue_lr = self.queue_lr[idx]
|
||||
self.queue_gt = self.queue_gt[idx]
|
||||
# get
|
||||
# get first b samples
|
||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
||||
# update
|
||||
# update the queue
|
||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
||||
|
||||
@@ -55,6 +65,8 @@ class RealESRGANModel(SRGANModel):
|
||||
|
||||
@torch.no_grad()
|
||||
def feed_data(self, data):
|
||||
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
|
||||
"""
|
||||
if self.is_train and self.opt.get('high_order_degradation', True):
|
||||
# training data synthesis
|
||||
self.gt = data['gt'].to(self.device)
|
||||
@@ -79,7 +91,7 @@ class RealESRGANModel(SRGANModel):
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
||||
# noise
|
||||
# add noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
@@ -93,7 +105,7 @@ class RealESRGANModel(SRGANModel):
|
||||
rounds=False)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
|
||||
# ----------------------- The second degradation process ----------------------- #
|
||||
@@ -111,7 +123,7 @@ class RealESRGANModel(SRGANModel):
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(
|
||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
||||
# noise
|
||||
# add noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
@@ -162,7 +174,9 @@ class RealESRGANModel(SRGANModel):
|
||||
self._dequeue_and_enqueue()
|
||||
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
||||
self.gt_usm = self.usm_sharpener(self.gt)
|
||||
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
||||
else:
|
||||
# for paired training or validation
|
||||
self.lq = data['lq'].to(self.device)
|
||||
if 'gt' in data:
|
||||
self.gt = data['gt'].to(self.device)
|
||||
@@ -175,6 +189,7 @@ class RealESRGANModel(SRGANModel):
|
||||
self.is_train = True
|
||||
|
||||
def optimize_parameters(self, current_iter):
|
||||
# usm sharpening
|
||||
l1_gt = self.gt_usm
|
||||
percep_gt = self.gt_usm
|
||||
gan_gt = self.gt_usm
|
||||
|
||||
@@ -12,35 +12,46 @@ from torch.nn import functional as F
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class RealESRNetModel(SRModel):
|
||||
"""RealESRNet Model"""
|
||||
"""RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
||||
|
||||
It is trained without GAN losses.
|
||||
It mainly performs:
|
||||
1. randomly synthesize LQ images in GPU tensors
|
||||
2. optimize the networks with GAN training.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRNetModel, self).__init__(opt)
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||
self.usm_sharpener = USMSharp().cuda()
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
||||
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
||||
self.queue_size = opt.get('queue_size', 180)
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self):
|
||||
# training pair pool
|
||||
"""It is the training pair pool for increasing the diversity in a batch.
|
||||
|
||||
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
||||
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
||||
to increase the degradation diversity in a batch.
|
||||
"""
|
||||
# initialize
|
||||
b, c, h, w = self.lq.size()
|
||||
if not hasattr(self, 'queue_lr'):
|
||||
assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
||||
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
_, c, h, w = self.gt.size()
|
||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
self.queue_ptr = 0
|
||||
if self.queue_ptr == self.queue_size: # full
|
||||
if self.queue_ptr == self.queue_size: # the pool is full
|
||||
# do dequeue and enqueue
|
||||
# shuffle
|
||||
idx = torch.randperm(self.queue_size)
|
||||
self.queue_lr = self.queue_lr[idx]
|
||||
self.queue_gt = self.queue_gt[idx]
|
||||
# get
|
||||
# get first b samples
|
||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
||||
# update
|
||||
# update the queue
|
||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
||||
|
||||
@@ -54,10 +65,12 @@ class RealESRNetModel(SRModel):
|
||||
|
||||
@torch.no_grad()
|
||||
def feed_data(self, data):
|
||||
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
|
||||
"""
|
||||
if self.is_train and self.opt.get('high_order_degradation', True):
|
||||
# training data synthesis
|
||||
self.gt = data['gt'].to(self.device)
|
||||
# USM the GT images
|
||||
# USM sharpen the GT images
|
||||
if self.opt['gt_usm'] is True:
|
||||
self.gt = self.usm_sharpener(self.gt)
|
||||
|
||||
@@ -80,7 +93,7 @@ class RealESRNetModel(SRModel):
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
||||
# noise
|
||||
# add noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
@@ -94,7 +107,7 @@ class RealESRNetModel(SRModel):
|
||||
rounds=False)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
|
||||
# ----------------------- The second degradation process ----------------------- #
|
||||
@@ -112,7 +125,7 @@ class RealESRNetModel(SRModel):
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(
|
||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
||||
# noise
|
||||
# add noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
@@ -160,7 +173,9 @@ class RealESRNetModel(SRModel):
|
||||
|
||||
# training pair pool
|
||||
self._dequeue_and_enqueue()
|
||||
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
||||
else:
|
||||
# for paired training or validation
|
||||
self.lq = data['lq'].to(self.device)
|
||||
if 'gt' in data:
|
||||
self.gt = data['gt'].to(self.device)
|
||||
|
||||
@@ -12,6 +12,19 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class RealESRGANer():
|
||||
"""A helper class for upsampling images with RealESRGAN.
|
||||
|
||||
Args:
|
||||
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
||||
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
||||
model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
|
||||
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
|
||||
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
||||
0 denotes for do not use tile. Default: 0.
|
||||
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
||||
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
||||
half (float): Whether to use half precision during inference. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
|
||||
self.scale = scale
|
||||
@@ -26,10 +39,12 @@ class RealESRGANer():
|
||||
if model is None:
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
||||
|
||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
# prefer to use params_ema
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
else:
|
||||
@@ -41,6 +56,8 @@ class RealESRGANer():
|
||||
self.model = self.model.half()
|
||||
|
||||
def pre_process(self, img):
|
||||
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
|
||||
"""
|
||||
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||
self.img = img.unsqueeze(0).to(self.device)
|
||||
if self.half:
|
||||
@@ -49,7 +66,7 @@ class RealESRGANer():
|
||||
# pre_pad
|
||||
if self.pre_pad != 0:
|
||||
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
||||
# mod pad
|
||||
# mod pad for divisible borders
|
||||
if self.scale == 2:
|
||||
self.mod_scale = 2
|
||||
elif self.scale == 1:
|
||||
@@ -64,10 +81,14 @@ class RealESRGANer():
|
||||
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
||||
|
||||
def process(self):
|
||||
# model inference
|
||||
self.output = self.model(self.img)
|
||||
|
||||
def tile_process(self):
|
||||
"""Modified from: https://github.com/ata4/esrgan-launcher
|
||||
"""It will first crop input images to tiles, and then process each tile.
|
||||
Finally, all the processed tiles are merged into one images.
|
||||
|
||||
Modified from: https://github.com/ata4/esrgan-launcher
|
||||
"""
|
||||
batch, channel, height, width = self.img.shape
|
||||
output_height = height * self.scale
|
||||
@@ -188,7 +209,7 @@ class RealESRGANer():
|
||||
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
||||
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
else: # use the cv2 resize for alpha channel
|
||||
h, w = alpha.shape[0:2]
|
||||
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
@@ -213,7 +234,9 @@ class RealESRGANer():
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
||||
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""Load file form http url, will download models if necessary.
|
||||
|
||||
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""
|
||||
if model_dir is None:
|
||||
hub_dir = get_dir()
|
||||
|
||||
Reference in New Issue
Block a user