Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ed98632c0 | ||
|
|
37d3c81c34 | ||
|
|
8d982a2173 | ||
|
|
ee820df2e2 | ||
|
|
248cbedbce | ||
|
|
e9f056fb63 | ||
|
|
ec767200ed | ||
|
|
1b45b66436 | ||
|
|
9baeba566b | ||
|
|
85bd346714 | ||
|
|
a1713103c3 | ||
|
|
7fcc11f255 | ||
|
|
cc153d278a |
6
.github/workflows/pylint.yml
vendored
6
.github/workflows/pylint.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Python Lint
|
||||
name: PyLint
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
@@ -26,5 +26,5 @@ jobs:
|
||||
- name: Lint
|
||||
run: |
|
||||
flake8 .
|
||||
isort --check-only --diff basicsr/ options/ scripts/ tests/ inference/ setup.py
|
||||
yapf -r -d basicsr/ options/ scripts/ tests/ inference/ setup.py
|
||||
isort --check-only --diff data/ models/ inference_realesrgan.py
|
||||
yapf -r -d data/ models/ inference_realesrgan.py
|
||||
|
||||
68
README.md
68
README.md
@@ -1,46 +1,67 @@
|
||||
# Real-ESRGAN
|
||||
|
||||
[**Paper**](https://arxiv.org/abs/2101.04061)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/releases)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/issues)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
|
||||
|
||||
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
|
||||
2. [Portable Windows executable file](https://github.com/xinntao/Real-ESRGAN/releases). You can find more information [here](#Portable-executable-files).
|
||||
|
||||
Real-ESRGAN aims at developing **Practical Algorithms for General Image Restoration**.<br>
|
||||
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.
|
||||
|
||||
:triangular_flag_on_post: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
|
||||
|
||||
### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
|
||||
|
||||
> [[Paper](https://arxiv.org/abs/2101.04061)]   [Project Page]   [Demo] <br>
|
||||
> [[Paper](https://arxiv.org/abs/2107.10833)]   [Project Page]   [Demo] <br>
|
||||
> [Xintao Wang](https://xinntao.github.io/), Liangbin Xie, [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en) <br>
|
||||
> Applied Research Center (ARC), Tencent PCG; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
|
||||
> Applied Research Center (ARC), Tencent PCG<br>
|
||||
> Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/teaser.jpg">
|
||||
</p>
|
||||
|
||||
#### Abstract
|
||||
---
|
||||
|
||||
Though many attempts have been made in blind super-resolution to restore low-resolution images with unknown and complex degradations, they are still far from addressing general real-world degraded images. In this work, we extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data. Specifically, a high-order degradation modeling process is introduced to better simulate complex real-world degradations. We also consider the common ringing and overshoot artifacts in the synthesis process. In addition, we employ a U-Net discriminator with spectral normalization to increase discriminator capability and stabilize the training dynamics. Extensive comparisons have shown its superior visual performance than prior works on various real datasets. We also provide efficient implementations to synthesize training pairs on the fly.
|
||||
We have provided a pretrained model (*RealESRGAN_x4plus.pth*) with upsampling X4.<br>
|
||||
**Note that RealESRGAN may still fail in some cases as the real-world degradations are really too complex.**<br>
|
||||
Moreover, it **may not** perform well on **human faces, text**, *etc*, which will be optimized later.
|
||||
<br>
|
||||
|
||||
#### BibTeX
|
||||
Real-ESRGAN will be a long-term supported project (in my current plan :smiley:). It will be continuously updated
|
||||
in my spare time.
|
||||
|
||||
@Article{wang2021realesrgan,
|
||||
title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
|
||||
author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
|
||||
journal={arXiv:2107.xxxxx},
|
||||
year={2021}
|
||||
}
|
||||
Here is a TODO list in the near future:
|
||||
|
||||
- [ ] optimize for human faces
|
||||
- [ ] optimize for texts
|
||||
- [ ] optimize for animation images
|
||||
- [ ] support more scales
|
||||
- [ ] support controllable restoration strength
|
||||
|
||||
If you have any good ideas or demands, please open an issue/discussion to let me know. <br>
|
||||
If you have some images that Real-ESRGAN could not well restored, please also open an issue/discussion. I will record it (but I cannot guarantee to resolve it:stuck_out_tongue:). If necessary, I will open a page to specially record these real-world cases that need to be solved, but the current technology is difficult to handle well.
|
||||
|
||||
---
|
||||
|
||||
We are cleaning the training codes. It will be finished on 23 or 24, July.
|
||||
|
||||
---
|
||||
### Portable executable files
|
||||
|
||||
You can download **Windows executable files** from https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN-ncnn-vulkan.zip
|
||||
|
||||
This executable file is **portable** and includes all the binaries and models required. No CUDA or PyTorch environment is needed.<br>
|
||||
|
||||
You can simply run the following command:
|
||||
|
||||
```bash
|
||||
./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
|
||||
```
|
||||
|
||||
Note that it may introduce block artifacts (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
|
||||
Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
|
||||
|
||||
This executable file is based on the wonderful [ncnn project](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan).
|
||||
This executable file is based on the wonderful [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui).
|
||||
|
||||
---
|
||||
|
||||
@@ -85,6 +106,19 @@ python inference_realesrgan.py --model_path experiments/pretrained_models/RealES
|
||||
|
||||
Results are in the `results` folder
|
||||
|
||||
## :computer: Training
|
||||
|
||||
A detailed guide can be found in [Training.md](Training.md).
|
||||
|
||||
## BibTeX
|
||||
|
||||
@Article{wang2021realesrgan,
|
||||
title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
|
||||
author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
|
||||
journal={arXiv:2107.10833},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
## :e-mail: Contact
|
||||
|
||||
If you have any question, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`.
|
||||
|
||||
97
Training.md
Normal file
97
Training.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# :computer: How to Train Real-ESRGAN
|
||||
|
||||
The training codes have been released. <br>
|
||||
Note that the codes have a lot of refactoring. So there may be some bugs/performance drops. Welcome to report issues and I will also retrain the models.
|
||||
|
||||
## Overview
|
||||
|
||||
The training has been divided into two stages. These two stages have the same data synthesis process and training pipeline, except for the loss functions. Specifically,
|
||||
|
||||
1. We first train Real-ESRNet with L1 loss from the pre-trained model ESRGAN.
|
||||
1. We then use the trained Real-ESRNet model as an initialization of the generator, and train the Real-ESRGAN with a combination of L1 loss, perceptual loss and GAN loss.
|
||||
|
||||
## Dataset Preparation
|
||||
|
||||
We use DF2K (DIV2K and Flickr2K) + OST datasets for our training. Only HR images are required. <br>
|
||||
You can download from :
|
||||
|
||||
1. DIV2K: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
|
||||
2. Flickr2K: https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar
|
||||
3. OST: https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip
|
||||
|
||||
For the DF2K dataset, we use a multi-scale strategy, *i.e.*, we downsample HR images to obtain several Ground-Truth images with different scales.
|
||||
|
||||
We then crop DF2K images into sub-images for faster IO and processing.
|
||||
|
||||
You need to prepare a txt file containing the image paths. The following are some examples in `meta_info_DF2Kmultiscale+OST_sub.txt` (As different users may have different sub-images partitions, this file is not suitable for your purpose and you need to prepare your own txt file):
|
||||
|
||||
```txt
|
||||
DF2K_HR_sub/000001_s001.png
|
||||
DF2K_HR_sub/000001_s002.png
|
||||
DF2K_HR_sub/000001_s003.png
|
||||
...
|
||||
```
|
||||
|
||||
## Train Real-ESRNet
|
||||
|
||||
1. Download pre-trained model [ESRGAN](https://drive.google.com/file/d/1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT/view?usp=sharing) into `experiments/pretrained_models`.
|
||||
1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
|
||||
```yml
|
||||
train:
|
||||
name: DF2K+OST
|
||||
type: RealESRGANDataset
|
||||
dataroot_gt: datasets/DF2K # modify to the root path of your folder
|
||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
|
||||
io_backend:
|
||||
type: disk
|
||||
```
|
||||
1. If you want to perform validation during training, uncomment those lines and modify accordingly:
|
||||
```yml
|
||||
# Uncomment these for validation
|
||||
# val:
|
||||
# name: validation
|
||||
# type: PairedImageDataset
|
||||
# dataroot_gt: path_to_gt
|
||||
# dataroot_lq: path_to_lq
|
||||
# io_backend:
|
||||
# type: disk
|
||||
|
||||
...
|
||||
|
||||
# Uncomment these for validation
|
||||
# validation settings
|
||||
# val:
|
||||
# val_freq: !!float 5e3
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
```
|
||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
|
||||
```
|
||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
|
||||
```
|
||||
|
||||
## Train Real-ESRGAN
|
||||
|
||||
1. After the training of Real-ESRNet, you now have the file `experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth`. If you need to specify the pre-trained path to other files, modify the `pretrain_network_g` value in the option file `train_realesrgan_x4plus.yml`.
|
||||
1. Modify the option file `train_realesrgan_x4plus.yml` accordingly. Most modifications are similar to those listed above.
|
||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
|
||||
```
|
||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
|
||||
```
|
||||
10
archs/__init__.py
Normal file
10
archs/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import arch modules for registry
|
||||
# scan all the files that end with '_arch.py' under the archs folder
|
||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
||||
# import all the arch modules
|
||||
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
|
||||
60
archs/discriminator_arch.py
Normal file
60
archs/discriminator_arch.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class UNetDiscriminatorSN(nn.Module):
|
||||
"""Defines a U-Net discriminator with spectral normalization (SN)"""
|
||||
|
||||
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
|
||||
super(UNetDiscriminatorSN, self).__init__()
|
||||
self.skip_connection = skip_connection
|
||||
norm = spectral_norm
|
||||
|
||||
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
|
||||
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
|
||||
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
|
||||
# upsample
|
||||
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
|
||||
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
|
||||
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
|
||||
|
||||
# extra
|
||||
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
||||
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
||||
|
||||
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
|
||||
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
|
||||
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
|
||||
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
|
||||
|
||||
# upsample
|
||||
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
|
||||
|
||||
if self.skip_connection:
|
||||
x4 = x4 + x2
|
||||
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
|
||||
|
||||
if self.skip_connection:
|
||||
x5 = x5 + x1
|
||||
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
|
||||
|
||||
if self.skip_connection:
|
||||
x6 = x6 + x0
|
||||
|
||||
# extra
|
||||
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
|
||||
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
|
||||
out = self.conv9(out)
|
||||
|
||||
return out
|
||||
10
data/__init__.py
Normal file
10
data/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import dataset modules for registry
|
||||
# scan all the files that end with '_dataset.py' under the data folder
|
||||
data_folder = osp.dirname(osp.abspath(__file__))
|
||||
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
||||
# import all the dataset modules
|
||||
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
|
||||
175
data/realesrgan_dataset.py
Normal file
175
data/realesrgan_dataset.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
import torch
|
||||
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
||||
from basicsr.data.transforms import augment
|
||||
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
||||
from basicsr.utils.registry import DATASET_REGISTRY
|
||||
from torch.utils import data as data
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class RealESRGANDataset(data.Dataset):
|
||||
"""
|
||||
Dataset used for Real-ESRGAN model.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANDataset, self).__init__()
|
||||
self.opt = opt
|
||||
# file client (io backend)
|
||||
self.file_client = None
|
||||
self.io_backend_opt = opt['io_backend']
|
||||
self.gt_folder = opt['dataroot_gt']
|
||||
|
||||
if self.io_backend_opt['type'] == 'lmdb':
|
||||
self.io_backend_opt['db_paths'] = [self.gt_folder]
|
||||
self.io_backend_opt['client_keys'] = ['gt']
|
||||
if not self.gt_folder.endswith('.lmdb'):
|
||||
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
|
||||
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
||||
self.paths = [line.split('.')[0] for line in fin]
|
||||
else:
|
||||
with open(self.opt['meta_info']) as fin:
|
||||
paths = [line.strip() for line in fin]
|
||||
self.paths = [os.path.join(self.gt_folder, v) for v in paths]
|
||||
|
||||
# blur settings for the first degradation
|
||||
self.blur_kernel_size = opt['blur_kernel_size']
|
||||
self.kernel_list = opt['kernel_list']
|
||||
self.kernel_prob = opt['kernel_prob']
|
||||
self.blur_sigma = opt['blur_sigma']
|
||||
self.betag_range = opt['betag_range']
|
||||
self.betap_range = opt['betap_range']
|
||||
self.sinc_prob = opt['sinc_prob']
|
||||
|
||||
# blur settings for the second degradation
|
||||
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
||||
self.kernel_list2 = opt['kernel_list2']
|
||||
self.kernel_prob2 = opt['kernel_prob2']
|
||||
self.blur_sigma2 = opt['blur_sigma2']
|
||||
self.betag_range2 = opt['betag_range2']
|
||||
self.betap_range2 = opt['betap_range2']
|
||||
self.sinc_prob2 = opt['sinc_prob2']
|
||||
|
||||
# a final sinc filter
|
||||
self.final_sinc_prob = opt['final_sinc_prob']
|
||||
|
||||
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
||||
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
||||
self.pulse_tensor[10, 10] = 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.file_client is None:
|
||||
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
||||
|
||||
# -------------------------------- Load gt images -------------------------------- #
|
||||
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
||||
gt_path = self.paths[index]
|
||||
# avoid errors caused by high latency in reading files
|
||||
retry = 3
|
||||
while retry > 0:
|
||||
try:
|
||||
img_bytes = self.file_client.get(gt_path, 'gt')
|
||||
except Exception as e:
|
||||
logger = get_root_logger()
|
||||
logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
|
||||
# change another file to read
|
||||
index = random.randint(0, self.__len__())
|
||||
gt_path = self.paths[index]
|
||||
time.sleep(1) # sleep 1s for occasional server congestion
|
||||
else:
|
||||
break
|
||||
finally:
|
||||
retry -= 1
|
||||
img_gt = imfrombytes(img_bytes, float32=True)
|
||||
|
||||
# -------------------- augmentation for training: flip, rotation -------------------- #
|
||||
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
|
||||
|
||||
# crop or pad to 400: 400 is hard-coded. You may change it accordingly
|
||||
h, w = img_gt.shape[0:2]
|
||||
crop_pad_size = 400
|
||||
# pad
|
||||
if h < crop_pad_size or w < crop_pad_size:
|
||||
pad_h = max(0, crop_pad_size - h)
|
||||
pad_w = max(0, crop_pad_size - w)
|
||||
img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
|
||||
# crop
|
||||
if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
|
||||
h, w = img_gt.shape[0:2]
|
||||
# randomly choose top and left coordinates
|
||||
top = random.randint(0, h - crop_pad_size)
|
||||
left = random.randint(0, w - crop_pad_size)
|
||||
img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
|
||||
|
||||
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
||||
kernel_size = random.choice(self.kernel_range)
|
||||
if np.random.uniform() < self.opt['sinc_prob']:
|
||||
# this sinc filter setting is for kernels ranging from [7, 21]
|
||||
if kernel_size < 13:
|
||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
||||
else:
|
||||
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
||||
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
||||
else:
|
||||
kernel = random_mixed_kernels(
|
||||
self.kernel_list,
|
||||
self.kernel_prob,
|
||||
kernel_size,
|
||||
self.blur_sigma,
|
||||
self.blur_sigma, [-math.pi, math.pi],
|
||||
self.betag_range,
|
||||
self.betap_range,
|
||||
noise_range=None)
|
||||
# pad kernel
|
||||
pad_size = (21 - kernel_size) // 2
|
||||
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
||||
|
||||
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
||||
kernel_size = random.choice(self.kernel_range)
|
||||
if np.random.uniform() < self.opt['sinc_prob2']:
|
||||
if kernel_size < 13:
|
||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
||||
else:
|
||||
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
||||
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
||||
else:
|
||||
kernel2 = random_mixed_kernels(
|
||||
self.kernel_list2,
|
||||
self.kernel_prob2,
|
||||
kernel_size,
|
||||
self.blur_sigma2,
|
||||
self.blur_sigma2, [-math.pi, math.pi],
|
||||
self.betag_range2,
|
||||
self.betap_range2,
|
||||
noise_range=None)
|
||||
|
||||
# pad kernel
|
||||
pad_size = (21 - kernel_size) // 2
|
||||
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
||||
|
||||
# ------------------------------------- sinc kernel ------------------------------------- #
|
||||
if np.random.uniform() < self.opt['final_sinc_prob']:
|
||||
kernel_size = random.choice(self.kernel_range)
|
||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
||||
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
||||
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
||||
else:
|
||||
sinc_kernel = self.pulse_tensor
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
|
||||
kernel = torch.FloatTensor(kernel)
|
||||
kernel2 = torch.FloatTensor(kernel2)
|
||||
|
||||
return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
|
||||
return return_d
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
1
experiments/pretrained_models/README.md
Normal file
1
experiments/pretrained_models/README.md
Normal file
@@ -0,0 +1 @@
|
||||
# Put downloaded pre-trained models here
|
||||
10
models/__init__.py
Normal file
10
models/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import model modules for registry
|
||||
# scan all the files that end with '_model.py' under the model folder
|
||||
model_folder = osp.dirname(osp.abspath(__file__))
|
||||
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
||||
# import all the model modules
|
||||
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
|
||||
240
models/realesrgan_model.py
Normal file
240
models/realesrgan_model.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
||||
from basicsr.data.transforms import paired_random_crop
|
||||
from basicsr.models.srgan_model import SRGANModel
|
||||
from basicsr.utils import DiffJPEG, USMSharp
|
||||
from basicsr.utils.img_process_util import filter2D
|
||||
from basicsr.utils.registry import MODEL_REGISTRY
|
||||
from collections import OrderedDict
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class RealESRGANModel(SRGANModel):
|
||||
"""RealESRGAN Model"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANModel, self).__init__(opt)
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||
self.usm_shaper = USMSharp().cuda()
|
||||
self.queue_size = opt['queue_size']
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self):
|
||||
# training pair pool
|
||||
# initialize
|
||||
b, c, h, w = self.lq.size()
|
||||
if not hasattr(self, 'queue_lr'):
|
||||
assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
_, c, h, w = self.gt.size()
|
||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
self.queue_ptr = 0
|
||||
if self.queue_ptr == self.queue_size: # full
|
||||
# do dequeue and enqueue
|
||||
# shuffle
|
||||
idx = torch.randperm(self.queue_size)
|
||||
self.queue_lr = self.queue_lr[idx]
|
||||
self.queue_gt = self.queue_gt[idx]
|
||||
# get
|
||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
||||
# update
|
||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
||||
|
||||
self.lq = lq_dequeue
|
||||
self.gt = gt_dequeue
|
||||
else:
|
||||
# only do enqueue
|
||||
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
||||
self.queue_ptr = self.queue_ptr + b
|
||||
|
||||
@torch.no_grad()
|
||||
def feed_data(self, data):
|
||||
if self.is_train:
|
||||
# training data synthesis
|
||||
self.gt = data['gt'].to(self.device)
|
||||
self.gt_usm = self.usm_shaper(self.gt)
|
||||
|
||||
self.kernel1 = data['kernel1'].to(self.device)
|
||||
self.kernel2 = data['kernel2'].to(self.device)
|
||||
self.sinc_kernel = data['sinc_kernel'].to(self.device)
|
||||
|
||||
ori_h, ori_w = self.gt.size()[2:4]
|
||||
|
||||
# ----------------------- The first degradation process ----------------------- #
|
||||
# blur
|
||||
out = filter2D(self.gt_usm, self.kernel1)
|
||||
# random resize
|
||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
||||
if updown_type == 'up':
|
||||
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
||||
elif updown_type == 'down':
|
||||
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
||||
else:
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
||||
# noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
||||
else:
|
||||
out = random_add_poisson_noise_pt(
|
||||
out,
|
||||
scale_range=self.opt['poisson_scale_range'],
|
||||
gray_prob=gray_noise_prob,
|
||||
clip=True,
|
||||
rounds=False)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
|
||||
# ----------------------- The second degradation process ----------------------- #
|
||||
# blur
|
||||
if np.random.uniform() < self.opt['second_blur_prob']:
|
||||
out = filter2D(out, self.kernel2)
|
||||
# random resize
|
||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
||||
if updown_type == 'up':
|
||||
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
||||
elif updown_type == 'down':
|
||||
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
||||
else:
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(
|
||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
||||
# noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
||||
else:
|
||||
out = random_add_poisson_noise_pt(
|
||||
out,
|
||||
scale_range=self.opt['poisson_scale_range2'],
|
||||
gray_prob=gray_noise_prob,
|
||||
clip=True,
|
||||
rounds=False)
|
||||
|
||||
# JPEG compression + the final sinc filter
|
||||
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
||||
# as one operation.
|
||||
# We consider two orders:
|
||||
# 1. [resize back + sinc filter] + JPEG compression
|
||||
# 2. JPEG compression + [resize back + sinc filter]
|
||||
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
||||
if np.random.uniform() < 0.5:
|
||||
# resize back + the final sinc filter
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
||||
out = filter2D(out, self.sinc_kernel)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
else:
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
# resize back + the final sinc filter
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
||||
out = filter2D(out, self.sinc_kernel)
|
||||
|
||||
# clamp and round
|
||||
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
||||
|
||||
# random crop
|
||||
gt_size = self.opt['gt_size']
|
||||
(self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
|
||||
self.opt['scale'])
|
||||
|
||||
# training pair pool
|
||||
self._dequeue_and_enqueue()
|
||||
else:
|
||||
self.lq = data['lq'].to(self.device)
|
||||
if 'gt' in data:
|
||||
self.gt = data['gt'].to(self.device)
|
||||
|
||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
||||
# do not use the synthetic process during validation
|
||||
self.is_train = False
|
||||
super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
||||
self.is_train = True
|
||||
|
||||
def optimize_parameters(self, current_iter):
|
||||
l1_gt = self.gt_usm
|
||||
percep_gt = self.gt_usm
|
||||
gan_gt = self.gt_usm
|
||||
if self.opt['l1_gt_usm'] is False:
|
||||
l1_gt = self.gt
|
||||
if self.opt['percep_gt_usm'] is False:
|
||||
percep_gt = self.gt
|
||||
if self.opt['gan_gt_usm'] is False:
|
||||
gan_gt = self.gt
|
||||
|
||||
# optimize net_g
|
||||
for p in self.net_d.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.optimizer_g.zero_grad()
|
||||
self.output = self.net_g(self.lq)
|
||||
|
||||
l_g_total = 0
|
||||
loss_dict = OrderedDict()
|
||||
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
|
||||
# pixel loss
|
||||
if self.cri_pix:
|
||||
l_g_pix = self.cri_pix(self.output, l1_gt)
|
||||
l_g_total += l_g_pix
|
||||
loss_dict['l_g_pix'] = l_g_pix
|
||||
# perceptual loss
|
||||
if self.cri_perceptual:
|
||||
l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
|
||||
if l_g_percep is not None:
|
||||
l_g_total += l_g_percep
|
||||
loss_dict['l_g_percep'] = l_g_percep
|
||||
if l_g_style is not None:
|
||||
l_g_total += l_g_style
|
||||
loss_dict['l_g_style'] = l_g_style
|
||||
# gan loss
|
||||
fake_g_pred = self.net_d(self.output)
|
||||
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
||||
l_g_total += l_g_gan
|
||||
loss_dict['l_g_gan'] = l_g_gan
|
||||
|
||||
l_g_total.backward()
|
||||
self.optimizer_g.step()
|
||||
|
||||
# optimize net_d
|
||||
for p in self.net_d.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
self.optimizer_d.zero_grad()
|
||||
# real
|
||||
real_d_pred = self.net_d(gan_gt)
|
||||
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
|
||||
loss_dict['l_d_real'] = l_d_real
|
||||
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
|
||||
l_d_real.backward()
|
||||
# fake
|
||||
fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
|
||||
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
|
||||
loss_dict['l_d_fake'] = l_d_fake
|
||||
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
|
||||
l_d_fake.backward()
|
||||
self.optimizer_d.step()
|
||||
|
||||
if self.ema_decay > 0:
|
||||
self.model_ema(decay=self.ema_decay)
|
||||
|
||||
self.log_dict = self.reduce_loss_dict(loss_dict)
|
||||
172
models/realesrnet_model.py
Normal file
172
models/realesrnet_model.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
||||
from basicsr.data.transforms import paired_random_crop
|
||||
from basicsr.models.sr_model import SRModel
|
||||
from basicsr.utils import DiffJPEG, USMSharp
|
||||
from basicsr.utils.img_process_util import filter2D
|
||||
from basicsr.utils.registry import MODEL_REGISTRY
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class RealESRNetModel(SRModel):
|
||||
"""RealESRNet Model"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRNetModel, self).__init__(opt)
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||
self.usm_shaper = USMSharp().cuda()
|
||||
self.queue_size = opt['queue_size']
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self):
|
||||
# training pair pool
|
||||
# initialize
|
||||
b, c, h, w = self.lq.size()
|
||||
if not hasattr(self, 'queue_lr'):
|
||||
assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
_, c, h, w = self.gt.size()
|
||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
self.queue_ptr = 0
|
||||
if self.queue_ptr == self.queue_size: # full
|
||||
# do dequeue and enqueue
|
||||
# shuffle
|
||||
idx = torch.randperm(self.queue_size)
|
||||
self.queue_lr = self.queue_lr[idx]
|
||||
self.queue_gt = self.queue_gt[idx]
|
||||
# get
|
||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
||||
# update
|
||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
||||
|
||||
self.lq = lq_dequeue
|
||||
self.gt = gt_dequeue
|
||||
else:
|
||||
# only do enqueue
|
||||
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
||||
self.queue_ptr = self.queue_ptr + b
|
||||
|
||||
@torch.no_grad()
|
||||
def feed_data(self, data):
|
||||
if self.is_train:
|
||||
# training data synthesis
|
||||
self.gt = data['gt'].to(self.device)
|
||||
# USM the GT images
|
||||
if self.opt['gt_usm'] is True:
|
||||
self.gt = self.usm_shaper(self.gt)
|
||||
|
||||
self.kernel1 = data['kernel1'].to(self.device)
|
||||
self.kernel2 = data['kernel2'].to(self.device)
|
||||
self.sinc_kernel = data['sinc_kernel'].to(self.device)
|
||||
|
||||
ori_h, ori_w = self.gt.size()[2:4]
|
||||
|
||||
# ----------------------- The first degradation process ----------------------- #
|
||||
# blur
|
||||
out = filter2D(self.gt, self.kernel1)
|
||||
# random resize
|
||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
||||
if updown_type == 'up':
|
||||
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
||||
elif updown_type == 'down':
|
||||
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
||||
else:
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
||||
# noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
||||
else:
|
||||
out = random_add_poisson_noise_pt(
|
||||
out,
|
||||
scale_range=self.opt['poisson_scale_range'],
|
||||
gray_prob=gray_noise_prob,
|
||||
clip=True,
|
||||
rounds=False)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
|
||||
# ----------------------- The second degradation process ----------------------- #
|
||||
# blur
|
||||
if np.random.uniform() < self.opt['second_blur_prob']:
|
||||
out = filter2D(out, self.kernel2)
|
||||
# random resize
|
||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
||||
if updown_type == 'up':
|
||||
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
||||
elif updown_type == 'down':
|
||||
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
||||
else:
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(
|
||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
||||
# noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
||||
else:
|
||||
out = random_add_poisson_noise_pt(
|
||||
out,
|
||||
scale_range=self.opt['poisson_scale_range2'],
|
||||
gray_prob=gray_noise_prob,
|
||||
clip=True,
|
||||
rounds=False)
|
||||
|
||||
# JPEG compression + the final sinc filter
|
||||
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
||||
# as one operation.
|
||||
# We consider two orders:
|
||||
# 1. [resize back + sinc filter] + JPEG compression
|
||||
# 2. JPEG compression + [resize back + sinc filter]
|
||||
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
||||
if np.random.uniform() < 0.5:
|
||||
# resize back + the final sinc filter
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
||||
out = filter2D(out, self.sinc_kernel)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
else:
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
# resize back + the final sinc filter
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
||||
out = filter2D(out, self.sinc_kernel)
|
||||
|
||||
# clamp and round
|
||||
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
||||
|
||||
# random crop
|
||||
gt_size = self.opt['gt_size']
|
||||
self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
|
||||
|
||||
# training pair pool
|
||||
self._dequeue_and_enqueue()
|
||||
else:
|
||||
self.lq = data['lq'].to(self.device)
|
||||
if 'gt' in data:
|
||||
self.gt = data['gt'].to(self.device)
|
||||
|
||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
||||
# do not use the synthetic process during validation
|
||||
self.is_train = False
|
||||
super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
||||
self.is_train = True
|
||||
186
options/train_realesrgan_x4plus.yml
Normal file
186
options/train_realesrgan_x4plus.yml
Normal file
@@ -0,0 +1,186 @@
|
||||
# general settings
|
||||
name: train_RealESRGANx4plus_400k_B12G4_fromRealESRNet
|
||||
model_type: RealESRGANModel
|
||||
scale: 4
|
||||
num_gpu: 4
|
||||
manual_seed: 0
|
||||
|
||||
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
||||
# USM the ground-truth
|
||||
l1_gt_usm: True
|
||||
percep_gt_usm: True
|
||||
gan_gt_usm: False
|
||||
|
||||
# the first degradation process
|
||||
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||
resize_range: [0.15, 1.5]
|
||||
gaussian_noise_prob: 0.5
|
||||
noise_range: [1, 30]
|
||||
poisson_scale_range: [0.05, 3]
|
||||
gray_noise_prob: 0.4
|
||||
jpeg_range: [30, 95]
|
||||
|
||||
# the second degradation process
|
||||
second_blur_prob: 0.8
|
||||
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||
resize_range2: [0.3, 1.2]
|
||||
gaussian_noise_prob2: 0.5
|
||||
noise_range2: [1, 25]
|
||||
poisson_scale_range2: [0.05, 2.5]
|
||||
gray_noise_prob2: 0.4
|
||||
jpeg_range2: [30, 95]
|
||||
|
||||
gt_size: 256
|
||||
queue_size: 180
|
||||
|
||||
# dataset and data loader settings
|
||||
datasets:
|
||||
train:
|
||||
name: DF2K+OST
|
||||
type: RealESRGANDataset
|
||||
dataroot_gt: datasets/DF2K
|
||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
blur_kernel_size: 21
|
||||
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob: 0.1
|
||||
blur_sigma: [0.2, 3]
|
||||
betag_range: [0.5, 4]
|
||||
betap_range: [1, 2]
|
||||
|
||||
blur_kernel_size2: 21
|
||||
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob2: 0.1
|
||||
blur_sigma2: [0.2, 1.5]
|
||||
betag_range2: [0.5, 4]
|
||||
betap_range2: [1, 2]
|
||||
|
||||
final_sinc_prob: 0.8
|
||||
|
||||
gt_size: 256
|
||||
use_hflip: True
|
||||
use_rot: False
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 5
|
||||
batch_size_per_gpu: 12
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
# Uncomment these for validation
|
||||
# val:
|
||||
# name: validation
|
||||
# type: PairedImageDataset
|
||||
# dataroot_gt: path_to_gt
|
||||
# dataroot_lq: path_to_lq
|
||||
# io_backend:
|
||||
# type: disk
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: RRDBNet
|
||||
num_in_ch: 3
|
||||
num_out_ch: 3
|
||||
num_feat: 64
|
||||
num_block: 23
|
||||
num_grow_ch: 32
|
||||
|
||||
|
||||
network_d:
|
||||
type: UNetDiscriminatorSN
|
||||
num_in_ch: 3
|
||||
num_feat: 64
|
||||
skip_connection: True
|
||||
|
||||
# path
|
||||
path:
|
||||
# use the pre-trained Real-ESRNet model
|
||||
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth
|
||||
param_key_g: params_ema
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# training settings
|
||||
train:
|
||||
ema_decay: 0.999
|
||||
optim_g:
|
||||
type: Adam
|
||||
lr: !!float 1e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
optim_d:
|
||||
type: Adam
|
||||
lr: !!float 1e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
scheduler:
|
||||
type: MultiStepLR
|
||||
milestones: [400000]
|
||||
gamma: 0.5
|
||||
|
||||
total_iter: 400000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
pixel_opt:
|
||||
type: L1Loss
|
||||
loss_weight: 1.0
|
||||
reduction: mean
|
||||
# perceptual loss (content and style losses)
|
||||
perceptual_opt:
|
||||
type: PerceptualLoss
|
||||
layer_weights:
|
||||
# before relu
|
||||
'conv1_2': 0.1
|
||||
'conv2_2': 0.1
|
||||
'conv3_4': 1
|
||||
'conv4_4': 1
|
||||
'conv5_4': 1
|
||||
vgg_type: vgg19
|
||||
use_input_norm: true
|
||||
perceptual_weight: !!float 1.0
|
||||
style_weight: 0
|
||||
range_norm: false
|
||||
criterion: l1
|
||||
# gan loss
|
||||
gan_opt:
|
||||
type: GANLoss
|
||||
gan_type: vanilla
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
loss_weight: !!float 1e-1
|
||||
|
||||
net_d_iters: 1
|
||||
net_d_init_iters: 0
|
||||
|
||||
# Uncomment these for validation
|
||||
# validation settings
|
||||
# val:
|
||||
# val_freq: !!float 5e3
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
# logging settings
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
||||
use_tb_logger: true
|
||||
wandb:
|
||||
project: ~
|
||||
resume_id: ~
|
||||
|
||||
# dist training settings
|
||||
dist_params:
|
||||
backend: nccl
|
||||
port: 29500
|
||||
144
options/train_realesrnet_x4plus.yml
Normal file
144
options/train_realesrnet_x4plus.yml
Normal file
@@ -0,0 +1,144 @@
|
||||
# general settings
|
||||
name: train_RealESRNetx4plus_1000k_B12G4_fromESRGAN
|
||||
model_type: RealESRNetModel
|
||||
scale: 4
|
||||
num_gpu: 4
|
||||
manual_seed: 0
|
||||
|
||||
# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
|
||||
gt_usm: True # USM the ground-truth
|
||||
|
||||
# the first degradation process
|
||||
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||
resize_range: [0.15, 1.5]
|
||||
gaussian_noise_prob: 0.5
|
||||
noise_range: [1, 30]
|
||||
poisson_scale_range: [0.05, 3]
|
||||
gray_noise_prob: 0.4
|
||||
jpeg_range: [30, 95]
|
||||
|
||||
# the second degradation process
|
||||
second_blur_prob: 0.8
|
||||
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||
resize_range2: [0.3, 1.2]
|
||||
gaussian_noise_prob2: 0.5
|
||||
noise_range2: [1, 25]
|
||||
poisson_scale_range2: [0.05, 2.5]
|
||||
gray_noise_prob2: 0.4
|
||||
jpeg_range2: [30, 95]
|
||||
|
||||
gt_size: 256
|
||||
queue_size: 180
|
||||
|
||||
# dataset and data loader settings
|
||||
datasets:
|
||||
train:
|
||||
name: DF2K+OST
|
||||
type: RealESRGANDataset
|
||||
dataroot_gt: datasets/DF2K
|
||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
blur_kernel_size: 21
|
||||
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob: 0.1
|
||||
blur_sigma: [0.2, 3]
|
||||
betag_range: [0.5, 4]
|
||||
betap_range: [1, 2]
|
||||
|
||||
blur_kernel_size2: 21
|
||||
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob2: 0.1
|
||||
blur_sigma2: [0.2, 1.5]
|
||||
betag_range2: [0.5, 4]
|
||||
betap_range2: [1, 2]
|
||||
|
||||
final_sinc_prob: 0.8
|
||||
|
||||
gt_size: 256
|
||||
use_hflip: True
|
||||
use_rot: False
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 5
|
||||
batch_size_per_gpu: 12
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
# Uncomment these for validation
|
||||
# val:
|
||||
# name: validation
|
||||
# type: PairedImageDataset
|
||||
# dataroot_gt: path_to_gt
|
||||
# dataroot_lq: path_to_lq
|
||||
# io_backend:
|
||||
# type: disk
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: RRDBNet
|
||||
num_in_ch: 3
|
||||
num_out_ch: 3
|
||||
num_feat: 64
|
||||
num_block: 23
|
||||
num_grow_ch: 32
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
|
||||
param_key_g: params_ema
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# training settings
|
||||
train:
|
||||
ema_decay: 0.999
|
||||
optim_g:
|
||||
type: Adam
|
||||
lr: !!float 2e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
scheduler:
|
||||
type: MultiStepLR
|
||||
milestones: [1000000]
|
||||
gamma: 0.5
|
||||
|
||||
total_iter: 1000000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
pixel_opt:
|
||||
type: L1Loss
|
||||
loss_weight: 1.0
|
||||
reduction: mean
|
||||
|
||||
# Uncomment these for validation
|
||||
# validation settings
|
||||
# val:
|
||||
# val_freq: !!float 5e3
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
# logging settings
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
||||
use_tb_logger: true
|
||||
wandb:
|
||||
project: ~
|
||||
resume_id: ~
|
||||
|
||||
# dist training settings
|
||||
dist_params:
|
||||
backend: nccl
|
||||
port: 29500
|
||||
10
train.py
Normal file
10
train.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import os.path as osp
|
||||
from basicsr.train import train_pipeline
|
||||
|
||||
import archs # noqa: F401
|
||||
import data # noqa: F401
|
||||
import models # noqa: F401
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_path = osp.abspath(osp.join(__file__, osp.pardir))
|
||||
train_pipeline(root_path)
|
||||
Reference in New Issue
Block a user