Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64ad194dda | ||
|
|
5745599813 | ||
|
|
3ce0c97e89 | ||
|
|
13186ac2c2 | ||
|
|
4356ba0578 | ||
|
|
32a4fa1772 | ||
|
|
18ebf723f2 | ||
|
|
1f83ce5432 | ||
|
|
bef5e3cabd | ||
|
|
064df9956b | ||
|
|
52eab16d11 | ||
|
|
94ae626008 | ||
|
|
9baa0b3d00 | ||
|
|
f932289af1 | ||
|
|
f59a0c66ec | ||
|
|
935993a040 | ||
|
|
74fcfea286 | ||
|
|
da1e1ee805 | ||
|
|
1d8745eb61 | ||
|
|
c94d2de155 | ||
|
|
f4297a70af | ||
|
|
492a829c14 | ||
|
|
0573f32dd0 | ||
|
|
8454fd2c7a | ||
|
|
ad2ff81725 | ||
|
|
2ed98632c0 | ||
|
|
37d3c81c34 | ||
|
|
8d982a2173 | ||
|
|
ee820df2e2 | ||
|
|
248cbedbce | ||
|
|
e9f056fb63 | ||
|
|
ec767200ed | ||
|
|
1b45b66436 | ||
|
|
9baeba566b | ||
|
|
85bd346714 | ||
|
|
a1713103c3 | ||
|
|
7fcc11f255 | ||
|
|
cc153d278a |
30
.github/workflows/publish-pip.yml
vendored
Normal file
30
.github/workflows/publish-pip.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: PyPI Publish
|
||||
|
||||
on: push
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.8
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.8
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Install PyTorch (cpu)
|
||||
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install dependencies
|
||||
run: pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: rm -rf .eggs && pip install -e .
|
||||
- name: Build for distribution
|
||||
# remove bdist_wheel for pip installation with compiling cuda extensions
|
||||
run: python setup.py sdist
|
||||
- name: Publish distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@master
|
||||
with:
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
6
.github/workflows/pylint.yml
vendored
6
.github/workflows/pylint.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Python Lint
|
||||
name: PyLint
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
@@ -26,5 +26,5 @@ jobs:
|
||||
- name: Lint
|
||||
run: |
|
||||
flake8 .
|
||||
isort --check-only --diff basicsr/ options/ scripts/ tests/ inference/ setup.py
|
||||
yapf -r -d basicsr/ options/ scripts/ tests/ inference/ setup.py
|
||||
isort --check-only --diff realesrgan/ scripts/ inference_realesrgan.py setup.py
|
||||
yapf -r -d realesrgan/ scripts/ inference_realesrgan.py setup.py
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -1,3 +1,12 @@
|
||||
# ignored folders
|
||||
datasets/*
|
||||
experiments/*
|
||||
results/*
|
||||
tb_logger/*
|
||||
wandb/*
|
||||
tmp/*
|
||||
|
||||
version.py
|
||||
.vscode
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
|
||||
29
LICENSE
Normal file
29
LICENSE
Normal file
@@ -0,0 +1,29 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2021, Xintao Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
8
MANIFEST.in
Normal file
8
MANIFEST.in
Normal file
@@ -0,0 +1,8 @@
|
||||
include assets/*
|
||||
include inputs/*
|
||||
include scripts/*.py
|
||||
include inference_realesrgan.py
|
||||
include VERSION
|
||||
include LICENSE
|
||||
include requirements.txt
|
||||
include realesrgan/weights/README.md
|
||||
92
README.md
92
README.md
@@ -1,46 +1,81 @@
|
||||
# Real-ESRGAN
|
||||
|
||||
[**Paper**](https://arxiv.org/abs/2101.04061)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/releases)
|
||||
[](https://pypi.org/project/realesrgan/)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/issues)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
|
||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/publish-pip.yml)
|
||||
|
||||
1. [Colab Demo](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
|
||||
2. [Portable Windows/Linux/MacOS executable files for Intel/AMD/Nvidia GPU](https://github.com/xinntao/Real-ESRGAN/releases). You can find more information [here](#Portable-executable-files).
|
||||
|
||||
Real-ESRGAN aims at developing **Practical Algorithms for General Image Restoration**.<br>
|
||||
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.
|
||||
|
||||
:triangular_flag_on_post: **Updates**
|
||||
|
||||
- :white_check_mark: Support arbitrary scale with `--outscale` (It actually further resizes outputs with `LANCZOS4`). Add *RealESRGAN_x2plus.pth* model.
|
||||
- :white_check_mark: [The inference code](inference_realesrgan.py) supports: 1) **tile** options; 2) images with **alpha channel**; 3) **gray** images; 4) **16-bit** images.
|
||||
- :white_check_mark: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
|
||||
|
||||
### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
|
||||
|
||||
> [[Paper](https://arxiv.org/abs/2101.04061)]   [Project Page]   [Demo] <br>
|
||||
> [[Paper](https://arxiv.org/abs/2107.10833)]   [Project Page]   [Demo] <br>
|
||||
> [Xintao Wang](https://xinntao.github.io/), Liangbin Xie, [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en) <br>
|
||||
> Applied Research Center (ARC), Tencent PCG; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
|
||||
> Applied Research Center (ARC), Tencent PCG<br>
|
||||
> Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/teaser.jpg">
|
||||
</p>
|
||||
|
||||
#### Abstract
|
||||
---
|
||||
|
||||
Though many attempts have been made in blind super-resolution to restore low-resolution images with unknown and complex degradations, they are still far from addressing general real-world degraded images. In this work, we extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data. Specifically, a high-order degradation modeling process is introduced to better simulate complex real-world degradations. We also consider the common ringing and overshoot artifacts in the synthesis process. In addition, we employ a U-Net discriminator with spectral normalization to increase discriminator capability and stabilize the training dynamics. Extensive comparisons have shown its superior visual performance than prior works on various real datasets. We also provide efficient implementations to synthesize training pairs on the fly.
|
||||
We have provided a pretrained model (*RealESRGAN_x4plus.pth*) with upsampling X4.<br>
|
||||
**Note that RealESRGAN may still fail in some cases as the real-world degradations are really too complex.**<br>
|
||||
Moreover, it **may not** perform well on **human faces, text**, *etc*, which will be optimized later.
|
||||
<br>
|
||||
|
||||
#### BibTeX
|
||||
Real-ESRGAN will be a long-term supported project (in my current plan :smiley:). It will be continuously updated
|
||||
in my spare time.
|
||||
|
||||
@Article{wang2021realesrgan,
|
||||
title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
|
||||
author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
|
||||
journal={arXiv:2107.xxxxx},
|
||||
year={2021}
|
||||
}
|
||||
Here is a TODO list in the near future:
|
||||
|
||||
- [ ] optimize for human faces
|
||||
- [ ] optimize for texts
|
||||
- [ ] optimize for animation images
|
||||
- [ ] support more scales
|
||||
- [ ] support controllable restoration strength
|
||||
|
||||
If you have any good ideas or demands, please open an issue/discussion to let me know. <br>
|
||||
If you have some images that Real-ESRGAN could not well restored, please also open an issue/discussion. I will record it (but I cannot guarantee to resolve it:stuck_out_tongue:). If necessary, I will open a page to specially record these real-world cases that need to be solved, but the current technology is difficult to handle well.
|
||||
|
||||
---
|
||||
|
||||
We are cleaning the training codes. It will be finished on 23 or 24, July.
|
||||
### Portable executable files
|
||||
|
||||
---
|
||||
You can download **Windows/Linux/MacOS executable files for Intel/AMD/Nvidia GPU** from https://github.com/xinntao/Real-ESRGAN/releases
|
||||
|
||||
You can download **Windows executable files** from https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN-ncnn-vulkan.zip
|
||||
This executable file is **portable** and includes all the binaries and models required. No CUDA or PyTorch environment is needed.<br>
|
||||
|
||||
You can simply run the following command (the Windows example, more information is in the README.md of each executable files):
|
||||
|
||||
You can simply run the following command:
|
||||
```bash
|
||||
./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
|
||||
```
|
||||
|
||||
Note that it may introduce block artifacts (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
|
||||
We have provided three models:
|
||||
|
||||
This executable file is based on the wonderful [ncnn project](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan).
|
||||
1. realesrgan-x4plus (default)
|
||||
2. realesrnet-x4plus
|
||||
3. esrgan-x4
|
||||
|
||||
You can use the `-n` argument for other models, for example, `./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png -n realesrnet-x4plus`
|
||||
|
||||
Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
|
||||
|
||||
This executable file is based on the wonderful [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui).
|
||||
|
||||
---
|
||||
|
||||
@@ -65,6 +100,7 @@ This executable file is based on the wonderful [ncnn project](https://github.com
|
||||
# We use BasicSR for both training and inference
|
||||
pip install basicsr
|
||||
pip install -r requirements.txt
|
||||
python setup.py develop
|
||||
```
|
||||
|
||||
## :zap: Quick Inference
|
||||
@@ -85,6 +121,26 @@ python inference_realesrgan.py --model_path experiments/pretrained_models/RealES
|
||||
|
||||
Results are in the `results` folder
|
||||
|
||||
## :european_castle: Model Zoo
|
||||
|
||||
- [RealESRGAN-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
|
||||
- [RealESRNet-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth)
|
||||
- [RealESRGAN-x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.0/RealESRGAN_x2plus.pth)
|
||||
- [official ESRGAN-x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth)
|
||||
|
||||
## :computer: Training
|
||||
|
||||
A detailed guide can be found in [Training.md](Training.md).
|
||||
|
||||
## BibTeX
|
||||
|
||||
@Article{wang2021realesrgan,
|
||||
title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
|
||||
author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
|
||||
journal={arXiv:2107.10833},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
## :e-mail: Contact
|
||||
|
||||
If you have any question, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`.
|
||||
|
||||
100
Training.md
Normal file
100
Training.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# :computer: How to Train Real-ESRGAN
|
||||
|
||||
The training codes have been released. <br>
|
||||
Note that the codes have a lot of refactoring. So there may be some bugs/performance drops. Welcome to report issues and I will also retrain the models.
|
||||
|
||||
## Overview
|
||||
|
||||
The training has been divided into two stages. These two stages have the same data synthesis process and training pipeline, except for the loss functions. Specifically,
|
||||
|
||||
1. We first train Real-ESRNet with L1 loss from the pre-trained model ESRGAN.
|
||||
1. We then use the trained Real-ESRNet model as an initialization of the generator, and train the Real-ESRGAN with a combination of L1 loss, perceptual loss and GAN loss.
|
||||
|
||||
## Dataset Preparation
|
||||
|
||||
We use DF2K (DIV2K and Flickr2K) + OST datasets for our training. Only HR images are required. <br>
|
||||
You can download from :
|
||||
|
||||
1. DIV2K: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
|
||||
2. Flickr2K: https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar
|
||||
3. OST: https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip
|
||||
|
||||
For the DF2K dataset, we use a multi-scale strategy, *i.e.*, we downsample HR images to obtain several Ground-Truth images with different scales.
|
||||
|
||||
We then crop DF2K images into sub-images for faster IO and processing.
|
||||
|
||||
You need to prepare a txt file containing the image paths. The following are some examples in `meta_info_DF2Kmultiscale+OST_sub.txt` (As different users may have different sub-images partitions, this file is not suitable for your purpose and you need to prepare your own txt file):
|
||||
|
||||
```txt
|
||||
DF2K_HR_sub/000001_s001.png
|
||||
DF2K_HR_sub/000001_s002.png
|
||||
DF2K_HR_sub/000001_s003.png
|
||||
...
|
||||
```
|
||||
|
||||
## Train Real-ESRNet
|
||||
|
||||
1. Download pre-trained model [ESRGAN](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) into `experiments/pretrained_models`.
|
||||
```bash
|
||||
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models
|
||||
```
|
||||
1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
|
||||
```yml
|
||||
train:
|
||||
name: DF2K+OST
|
||||
type: RealESRGANDataset
|
||||
dataroot_gt: datasets/DF2K # modify to the root path of your folder
|
||||
meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
|
||||
io_backend:
|
||||
type: disk
|
||||
```
|
||||
1. If you want to perform validation during training, uncomment those lines and modify accordingly:
|
||||
```yml
|
||||
# Uncomment these for validation
|
||||
# val:
|
||||
# name: validation
|
||||
# type: PairedImageDataset
|
||||
# dataroot_gt: path_to_gt
|
||||
# dataroot_lq: path_to_lq
|
||||
# io_backend:
|
||||
# type: disk
|
||||
|
||||
...
|
||||
|
||||
# Uncomment these for validation
|
||||
# validation settings
|
||||
# val:
|
||||
# val_freq: !!float 5e3
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
```
|
||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
|
||||
```
|
||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
|
||||
```
|
||||
|
||||
## Train Real-ESRGAN
|
||||
|
||||
1. After the training of Real-ESRNet, you now have the file `experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth`. If you need to specify the pre-trained path to other files, modify the `pretrain_network_g` value in the option file `train_realesrgan_x4plus.yml`.
|
||||
1. Modify the option file `train_realesrgan_x4plus.yml` accordingly. Most modifications are similar to those listed above.
|
||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
|
||||
```
|
||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
|
||||
```
|
||||
1
experiments/pretrained_models/README.md
Normal file
1
experiments/pretrained_models/README.md
Normal file
@@ -0,0 +1 @@
|
||||
# Put downloaded pre-trained models here
|
||||
@@ -1,67 +1,76 @@
|
||||
import argparse
|
||||
import cv2
|
||||
import glob
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from torch.nn import functional as F
|
||||
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth')
|
||||
parser.add_argument('--scale', type=int, default=4)
|
||||
parser.add_argument('--input', type=str, default='inputs', help='input image or folder')
|
||||
parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
|
||||
parser.add_argument(
|
||||
'--model_path',
|
||||
type=str,
|
||||
default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
|
||||
help='Path to the pre-trained model')
|
||||
parser.add_argument('--output', type=str, default='results', help='Output folder')
|
||||
parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
|
||||
parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
|
||||
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
|
||||
parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
|
||||
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
|
||||
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
|
||||
parser.add_argument('--half', action='store_true', help='Use half precision during inference')
|
||||
parser.add_argument(
|
||||
'--alpha_upsampler',
|
||||
type=str,
|
||||
default='realesrgan',
|
||||
help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
|
||||
parser.add_argument(
|
||||
'--ext',
|
||||
type=str,
|
||||
default='auto',
|
||||
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
||||
args = parser.parse_args()
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# set up model
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale)
|
||||
loadnet = torch.load(args.model_path)
|
||||
model.load_state_dict(loadnet['params_ema'], strict=True)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
upsampler = RealESRGANer(
|
||||
scale=args.netscale,
|
||||
model_path=args.model_path,
|
||||
tile=args.tile,
|
||||
tile_pad=args.tile_pad,
|
||||
pre_pad=args.pre_pad,
|
||||
half=args.half)
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
if os.path.isfile(args.input):
|
||||
paths = [args.input]
|
||||
else:
|
||||
paths = sorted(glob.glob(os.path.join(args.input, '*')))
|
||||
|
||||
os.makedirs('results/', exist_ok=True)
|
||||
for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
|
||||
imgname = os.path.splitext(os.path.basename(path))[0]
|
||||
for idx, path in enumerate(paths):
|
||||
imgname, extension = os.path.splitext(os.path.basename(path))
|
||||
print('Testing', idx, imgname)
|
||||
# read image
|
||||
img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
|
||||
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
||||
img = img.unsqueeze(0).to(device)
|
||||
|
||||
if args.scale == 2:
|
||||
mod_scale = 2
|
||||
elif args.scale == 1:
|
||||
mod_scale = 4
|
||||
else:
|
||||
mod_scale = None
|
||||
if mod_scale is not None:
|
||||
h_pad, w_pad = 0, 0
|
||||
_, _, h, w = img.size()
|
||||
if (h % mod_scale != 0):
|
||||
h_pad = (mod_scale - h % mod_scale)
|
||||
if (w % mod_scale != 0):
|
||||
w_pad = (mod_scale - w % mod_scale)
|
||||
img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
h, w = img.shape[0:2]
|
||||
if max(h, w) > 1000 and args.netscale == 4:
|
||||
print('WARNING: The input image is large, try X2 model for better performace.')
|
||||
if max(h, w) < 500 and args.netscale == 2:
|
||||
print('WARNING: The input image is small, try X4 model for better performace.')
|
||||
|
||||
try:
|
||||
# inference
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
# remove extra pad
|
||||
if mod_scale is not None:
|
||||
_, _, h, w = output.size()
|
||||
output = output[:, :, 0:h - h_pad, 0:w - w_pad]
|
||||
# save image
|
||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
output = (output * 255.0).round().astype(np.uint8)
|
||||
cv2.imwrite(f'results/{imgname}_RealESRGAN.png', output)
|
||||
output, img_mode = upsampler.enhance(img, outscale=args.outscale)
|
||||
except Exception as error:
|
||||
print('Error', error)
|
||||
else:
|
||||
if args.ext == 'auto':
|
||||
extension = extension[1:]
|
||||
else:
|
||||
extension = args.ext
|
||||
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||
extension = 'png'
|
||||
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
|
||||
cv2.imwrite(save_path, output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
BIN
inputs/tree_alpha_16bit.png
Normal file
BIN
inputs/tree_alpha_16bit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 373 KiB |
BIN
inputs/wolf_gray.jpg
Normal file
BIN
inputs/wolf_gray.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 52 KiB |
186
options/train_realesrgan_x4plus.yml
Normal file
186
options/train_realesrgan_x4plus.yml
Normal file
@@ -0,0 +1,186 @@
|
||||
# general settings
|
||||
name: train_RealESRGANx4plus_400k_B12G4_fromRealESRNet
|
||||
model_type: RealESRGANModel
|
||||
scale: 4
|
||||
num_gpu: 4
|
||||
manual_seed: 0
|
||||
|
||||
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
||||
# USM the ground-truth
|
||||
l1_gt_usm: True
|
||||
percep_gt_usm: True
|
||||
gan_gt_usm: False
|
||||
|
||||
# the first degradation process
|
||||
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||
resize_range: [0.15, 1.5]
|
||||
gaussian_noise_prob: 0.5
|
||||
noise_range: [1, 30]
|
||||
poisson_scale_range: [0.05, 3]
|
||||
gray_noise_prob: 0.4
|
||||
jpeg_range: [30, 95]
|
||||
|
||||
# the second degradation process
|
||||
second_blur_prob: 0.8
|
||||
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||
resize_range2: [0.3, 1.2]
|
||||
gaussian_noise_prob2: 0.5
|
||||
noise_range2: [1, 25]
|
||||
poisson_scale_range2: [0.05, 2.5]
|
||||
gray_noise_prob2: 0.4
|
||||
jpeg_range2: [30, 95]
|
||||
|
||||
gt_size: 256
|
||||
queue_size: 180
|
||||
|
||||
# dataset and data loader settings
|
||||
datasets:
|
||||
train:
|
||||
name: DF2K+OST
|
||||
type: RealESRGANDataset
|
||||
dataroot_gt: datasets/DF2K
|
||||
meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
blur_kernel_size: 21
|
||||
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob: 0.1
|
||||
blur_sigma: [0.2, 3]
|
||||
betag_range: [0.5, 4]
|
||||
betap_range: [1, 2]
|
||||
|
||||
blur_kernel_size2: 21
|
||||
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob2: 0.1
|
||||
blur_sigma2: [0.2, 1.5]
|
||||
betag_range2: [0.5, 4]
|
||||
betap_range2: [1, 2]
|
||||
|
||||
final_sinc_prob: 0.8
|
||||
|
||||
gt_size: 256
|
||||
use_hflip: True
|
||||
use_rot: False
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 5
|
||||
batch_size_per_gpu: 12
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
# Uncomment these for validation
|
||||
# val:
|
||||
# name: validation
|
||||
# type: PairedImageDataset
|
||||
# dataroot_gt: path_to_gt
|
||||
# dataroot_lq: path_to_lq
|
||||
# io_backend:
|
||||
# type: disk
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: RRDBNet
|
||||
num_in_ch: 3
|
||||
num_out_ch: 3
|
||||
num_feat: 64
|
||||
num_block: 23
|
||||
num_grow_ch: 32
|
||||
|
||||
|
||||
network_d:
|
||||
type: UNetDiscriminatorSN
|
||||
num_in_ch: 3
|
||||
num_feat: 64
|
||||
skip_connection: True
|
||||
|
||||
# path
|
||||
path:
|
||||
# use the pre-trained Real-ESRNet model
|
||||
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/models/net_g_1000000.pth
|
||||
param_key_g: params_ema
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# training settings
|
||||
train:
|
||||
ema_decay: 0.999
|
||||
optim_g:
|
||||
type: Adam
|
||||
lr: !!float 1e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
optim_d:
|
||||
type: Adam
|
||||
lr: !!float 1e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
scheduler:
|
||||
type: MultiStepLR
|
||||
milestones: [400000]
|
||||
gamma: 0.5
|
||||
|
||||
total_iter: 400000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
pixel_opt:
|
||||
type: L1Loss
|
||||
loss_weight: 1.0
|
||||
reduction: mean
|
||||
# perceptual loss (content and style losses)
|
||||
perceptual_opt:
|
||||
type: PerceptualLoss
|
||||
layer_weights:
|
||||
# before relu
|
||||
'conv1_2': 0.1
|
||||
'conv2_2': 0.1
|
||||
'conv3_4': 1
|
||||
'conv4_4': 1
|
||||
'conv5_4': 1
|
||||
vgg_type: vgg19
|
||||
use_input_norm: true
|
||||
perceptual_weight: !!float 1.0
|
||||
style_weight: 0
|
||||
range_norm: false
|
||||
criterion: l1
|
||||
# gan loss
|
||||
gan_opt:
|
||||
type: GANLoss
|
||||
gan_type: vanilla
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
loss_weight: !!float 1e-1
|
||||
|
||||
net_d_iters: 1
|
||||
net_d_init_iters: 0
|
||||
|
||||
# Uncomment these for validation
|
||||
# validation settings
|
||||
# val:
|
||||
# val_freq: !!float 5e3
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
# logging settings
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
||||
use_tb_logger: true
|
||||
wandb:
|
||||
project: ~
|
||||
resume_id: ~
|
||||
|
||||
# dist training settings
|
||||
dist_params:
|
||||
backend: nccl
|
||||
port: 29500
|
||||
144
options/train_realesrnet_x4plus.yml
Normal file
144
options/train_realesrnet_x4plus.yml
Normal file
@@ -0,0 +1,144 @@
|
||||
# general settings
|
||||
name: train_RealESRNetx4plus_1000k_B12G4_fromESRGAN
|
||||
model_type: RealESRNetModel
|
||||
scale: 4
|
||||
num_gpu: 4
|
||||
manual_seed: 0
|
||||
|
||||
# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
|
||||
gt_usm: True # USM the ground-truth
|
||||
|
||||
# the first degradation process
|
||||
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||
resize_range: [0.15, 1.5]
|
||||
gaussian_noise_prob: 0.5
|
||||
noise_range: [1, 30]
|
||||
poisson_scale_range: [0.05, 3]
|
||||
gray_noise_prob: 0.4
|
||||
jpeg_range: [30, 95]
|
||||
|
||||
# the second degradation process
|
||||
second_blur_prob: 0.8
|
||||
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||
resize_range2: [0.3, 1.2]
|
||||
gaussian_noise_prob2: 0.5
|
||||
noise_range2: [1, 25]
|
||||
poisson_scale_range2: [0.05, 2.5]
|
||||
gray_noise_prob2: 0.4
|
||||
jpeg_range2: [30, 95]
|
||||
|
||||
gt_size: 256
|
||||
queue_size: 180
|
||||
|
||||
# dataset and data loader settings
|
||||
datasets:
|
||||
train:
|
||||
name: DF2K+OST
|
||||
type: RealESRGANDataset
|
||||
dataroot_gt: datasets/DF2K
|
||||
meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
blur_kernel_size: 21
|
||||
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob: 0.1
|
||||
blur_sigma: [0.2, 3]
|
||||
betag_range: [0.5, 4]
|
||||
betap_range: [1, 2]
|
||||
|
||||
blur_kernel_size2: 21
|
||||
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob2: 0.1
|
||||
blur_sigma2: [0.2, 1.5]
|
||||
betag_range2: [0.5, 4]
|
||||
betap_range2: [1, 2]
|
||||
|
||||
final_sinc_prob: 0.8
|
||||
|
||||
gt_size: 256
|
||||
use_hflip: True
|
||||
use_rot: False
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 5
|
||||
batch_size_per_gpu: 12
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
# Uncomment these for validation
|
||||
# val:
|
||||
# name: validation
|
||||
# type: PairedImageDataset
|
||||
# dataroot_gt: path_to_gt
|
||||
# dataroot_lq: path_to_lq
|
||||
# io_backend:
|
||||
# type: disk
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: RRDBNet
|
||||
num_in_ch: 3
|
||||
num_out_ch: 3
|
||||
num_feat: 64
|
||||
num_block: 23
|
||||
num_grow_ch: 32
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
|
||||
param_key_g: params_ema
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# training settings
|
||||
train:
|
||||
ema_decay: 0.999
|
||||
optim_g:
|
||||
type: Adam
|
||||
lr: !!float 2e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
scheduler:
|
||||
type: MultiStepLR
|
||||
milestones: [1000000]
|
||||
gamma: 0.5
|
||||
|
||||
total_iter: 1000000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
pixel_opt:
|
||||
type: L1Loss
|
||||
loss_weight: 1.0
|
||||
reduction: mean
|
||||
|
||||
# Uncomment these for validation
|
||||
# validation settings
|
||||
# val:
|
||||
# val_freq: !!float 5e3
|
||||
# save_img: True
|
||||
|
||||
# metrics:
|
||||
# psnr: # metric name, can be arbitrary
|
||||
# type: calculate_psnr
|
||||
# crop_border: 4
|
||||
# test_y_channel: false
|
||||
|
||||
# logging settings
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
||||
use_tb_logger: true
|
||||
wandb:
|
||||
project: ~
|
||||
resume_id: ~
|
||||
|
||||
# dist training settings
|
||||
dist_params:
|
||||
backend: nccl
|
||||
port: 29500
|
||||
6
realesrgan/__init__.py
Normal file
6
realesrgan/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# flake8: noqa
|
||||
from .archs import *
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .utils import *
|
||||
from .version import __gitsha__, __version__
|
||||
10
realesrgan/archs/__init__.py
Normal file
10
realesrgan/archs/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import arch modules for registry
|
||||
# scan all the files that end with '_arch.py' under the archs folder
|
||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
||||
# import all the arch modules
|
||||
_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
|
||||
60
realesrgan/archs/discriminator_arch.py
Normal file
60
realesrgan/archs/discriminator_arch.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class UNetDiscriminatorSN(nn.Module):
|
||||
"""Defines a U-Net discriminator with spectral normalization (SN)"""
|
||||
|
||||
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
|
||||
super(UNetDiscriminatorSN, self).__init__()
|
||||
self.skip_connection = skip_connection
|
||||
norm = spectral_norm
|
||||
|
||||
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
|
||||
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
|
||||
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
|
||||
# upsample
|
||||
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
|
||||
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
|
||||
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
|
||||
|
||||
# extra
|
||||
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
||||
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
||||
|
||||
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
|
||||
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
|
||||
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
|
||||
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
|
||||
|
||||
# upsample
|
||||
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
|
||||
|
||||
if self.skip_connection:
|
||||
x4 = x4 + x2
|
||||
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
|
||||
|
||||
if self.skip_connection:
|
||||
x5 = x5 + x1
|
||||
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
|
||||
|
||||
if self.skip_connection:
|
||||
x6 = x6 + x0
|
||||
|
||||
# extra
|
||||
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
|
||||
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
|
||||
out = self.conv9(out)
|
||||
|
||||
return out
|
||||
10
realesrgan/data/__init__.py
Normal file
10
realesrgan/data/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import dataset modules for registry
|
||||
# scan all the files that end with '_dataset.py' under the data folder
|
||||
data_folder = osp.dirname(osp.abspath(__file__))
|
||||
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
||||
# import all the dataset modules
|
||||
_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
|
||||
175
realesrgan/data/realesrgan_dataset.py
Normal file
175
realesrgan/data/realesrgan_dataset.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
import torch
|
||||
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
||||
from basicsr.data.transforms import augment
|
||||
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
||||
from basicsr.utils.registry import DATASET_REGISTRY
|
||||
from torch.utils import data as data
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class RealESRGANDataset(data.Dataset):
|
||||
"""
|
||||
Dataset used for Real-ESRGAN model.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANDataset, self).__init__()
|
||||
self.opt = opt
|
||||
# file client (io backend)
|
||||
self.file_client = None
|
||||
self.io_backend_opt = opt['io_backend']
|
||||
self.gt_folder = opt['dataroot_gt']
|
||||
|
||||
if self.io_backend_opt['type'] == 'lmdb':
|
||||
self.io_backend_opt['db_paths'] = [self.gt_folder]
|
||||
self.io_backend_opt['client_keys'] = ['gt']
|
||||
if not self.gt_folder.endswith('.lmdb'):
|
||||
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
|
||||
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
||||
self.paths = [line.split('.')[0] for line in fin]
|
||||
else:
|
||||
with open(self.opt['meta_info']) as fin:
|
||||
paths = [line.strip() for line in fin]
|
||||
self.paths = [os.path.join(self.gt_folder, v) for v in paths]
|
||||
|
||||
# blur settings for the first degradation
|
||||
self.blur_kernel_size = opt['blur_kernel_size']
|
||||
self.kernel_list = opt['kernel_list']
|
||||
self.kernel_prob = opt['kernel_prob']
|
||||
self.blur_sigma = opt['blur_sigma']
|
||||
self.betag_range = opt['betag_range']
|
||||
self.betap_range = opt['betap_range']
|
||||
self.sinc_prob = opt['sinc_prob']
|
||||
|
||||
# blur settings for the second degradation
|
||||
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
||||
self.kernel_list2 = opt['kernel_list2']
|
||||
self.kernel_prob2 = opt['kernel_prob2']
|
||||
self.blur_sigma2 = opt['blur_sigma2']
|
||||
self.betag_range2 = opt['betag_range2']
|
||||
self.betap_range2 = opt['betap_range2']
|
||||
self.sinc_prob2 = opt['sinc_prob2']
|
||||
|
||||
# a final sinc filter
|
||||
self.final_sinc_prob = opt['final_sinc_prob']
|
||||
|
||||
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
||||
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
||||
self.pulse_tensor[10, 10] = 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.file_client is None:
|
||||
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
||||
|
||||
# -------------------------------- Load gt images -------------------------------- #
|
||||
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
||||
gt_path = self.paths[index]
|
||||
# avoid errors caused by high latency in reading files
|
||||
retry = 3
|
||||
while retry > 0:
|
||||
try:
|
||||
img_bytes = self.file_client.get(gt_path, 'gt')
|
||||
except Exception as e:
|
||||
logger = get_root_logger()
|
||||
logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
|
||||
# change another file to read
|
||||
index = random.randint(0, self.__len__())
|
||||
gt_path = self.paths[index]
|
||||
time.sleep(1) # sleep 1s for occasional server congestion
|
||||
else:
|
||||
break
|
||||
finally:
|
||||
retry -= 1
|
||||
img_gt = imfrombytes(img_bytes, float32=True)
|
||||
|
||||
# -------------------- augmentation for training: flip, rotation -------------------- #
|
||||
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
|
||||
|
||||
# crop or pad to 400: 400 is hard-coded. You may change it accordingly
|
||||
h, w = img_gt.shape[0:2]
|
||||
crop_pad_size = 400
|
||||
# pad
|
||||
if h < crop_pad_size or w < crop_pad_size:
|
||||
pad_h = max(0, crop_pad_size - h)
|
||||
pad_w = max(0, crop_pad_size - w)
|
||||
img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
|
||||
# crop
|
||||
if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
|
||||
h, w = img_gt.shape[0:2]
|
||||
# randomly choose top and left coordinates
|
||||
top = random.randint(0, h - crop_pad_size)
|
||||
left = random.randint(0, w - crop_pad_size)
|
||||
img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
|
||||
|
||||
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
||||
kernel_size = random.choice(self.kernel_range)
|
||||
if np.random.uniform() < self.opt['sinc_prob']:
|
||||
# this sinc filter setting is for kernels ranging from [7, 21]
|
||||
if kernel_size < 13:
|
||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
||||
else:
|
||||
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
||||
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
||||
else:
|
||||
kernel = random_mixed_kernels(
|
||||
self.kernel_list,
|
||||
self.kernel_prob,
|
||||
kernel_size,
|
||||
self.blur_sigma,
|
||||
self.blur_sigma, [-math.pi, math.pi],
|
||||
self.betag_range,
|
||||
self.betap_range,
|
||||
noise_range=None)
|
||||
# pad kernel
|
||||
pad_size = (21 - kernel_size) // 2
|
||||
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
||||
|
||||
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
||||
kernel_size = random.choice(self.kernel_range)
|
||||
if np.random.uniform() < self.opt['sinc_prob2']:
|
||||
if kernel_size < 13:
|
||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
||||
else:
|
||||
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
||||
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
||||
else:
|
||||
kernel2 = random_mixed_kernels(
|
||||
self.kernel_list2,
|
||||
self.kernel_prob2,
|
||||
kernel_size,
|
||||
self.blur_sigma2,
|
||||
self.blur_sigma2, [-math.pi, math.pi],
|
||||
self.betag_range2,
|
||||
self.betap_range2,
|
||||
noise_range=None)
|
||||
|
||||
# pad kernel
|
||||
pad_size = (21 - kernel_size) // 2
|
||||
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
||||
|
||||
# ------------------------------------- sinc kernel ------------------------------------- #
|
||||
if np.random.uniform() < self.opt['final_sinc_prob']:
|
||||
kernel_size = random.choice(self.kernel_range)
|
||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
||||
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
||||
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
||||
else:
|
||||
sinc_kernel = self.pulse_tensor
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
|
||||
kernel = torch.FloatTensor(kernel)
|
||||
kernel2 = torch.FloatTensor(kernel2)
|
||||
|
||||
return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
|
||||
return return_d
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
10
realesrgan/models/__init__.py
Normal file
10
realesrgan/models/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import model modules for registry
|
||||
# scan all the files that end with '_model.py' under the model folder
|
||||
model_folder = osp.dirname(osp.abspath(__file__))
|
||||
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
||||
# import all the model modules
|
||||
_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
|
||||
242
realesrgan/models/realesrgan_model.py
Normal file
242
realesrgan/models/realesrgan_model.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
||||
from basicsr.data.transforms import paired_random_crop
|
||||
from basicsr.models.srgan_model import SRGANModel
|
||||
from basicsr.utils import DiffJPEG, USMSharp
|
||||
from basicsr.utils.img_process_util import filter2D
|
||||
from basicsr.utils.registry import MODEL_REGISTRY
|
||||
from collections import OrderedDict
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class RealESRGANModel(SRGANModel):
|
||||
"""RealESRGAN Model"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRGANModel, self).__init__(opt)
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||
self.usm_sharpener = USMSharp().cuda()
|
||||
self.queue_size = opt['queue_size']
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self):
|
||||
# training pair pool
|
||||
# initialize
|
||||
b, c, h, w = self.lq.size()
|
||||
if not hasattr(self, 'queue_lr'):
|
||||
assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
_, c, h, w = self.gt.size()
|
||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
self.queue_ptr = 0
|
||||
if self.queue_ptr == self.queue_size: # full
|
||||
# do dequeue and enqueue
|
||||
# shuffle
|
||||
idx = torch.randperm(self.queue_size)
|
||||
self.queue_lr = self.queue_lr[idx]
|
||||
self.queue_gt = self.queue_gt[idx]
|
||||
# get
|
||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
||||
# update
|
||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
||||
|
||||
self.lq = lq_dequeue
|
||||
self.gt = gt_dequeue
|
||||
else:
|
||||
# only do enqueue
|
||||
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
||||
self.queue_ptr = self.queue_ptr + b
|
||||
|
||||
@torch.no_grad()
|
||||
def feed_data(self, data):
|
||||
if self.is_train:
|
||||
# training data synthesis
|
||||
self.gt = data['gt'].to(self.device)
|
||||
self.gt_usm = self.usm_sharpener(self.gt)
|
||||
|
||||
self.kernel1 = data['kernel1'].to(self.device)
|
||||
self.kernel2 = data['kernel2'].to(self.device)
|
||||
self.sinc_kernel = data['sinc_kernel'].to(self.device)
|
||||
|
||||
ori_h, ori_w = self.gt.size()[2:4]
|
||||
|
||||
# ----------------------- The first degradation process ----------------------- #
|
||||
# blur
|
||||
out = filter2D(self.gt_usm, self.kernel1)
|
||||
# random resize
|
||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
||||
if updown_type == 'up':
|
||||
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
||||
elif updown_type == 'down':
|
||||
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
||||
else:
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
||||
# noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
||||
else:
|
||||
out = random_add_poisson_noise_pt(
|
||||
out,
|
||||
scale_range=self.opt['poisson_scale_range'],
|
||||
gray_prob=gray_noise_prob,
|
||||
clip=True,
|
||||
rounds=False)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
|
||||
# ----------------------- The second degradation process ----------------------- #
|
||||
# blur
|
||||
if np.random.uniform() < self.opt['second_blur_prob']:
|
||||
out = filter2D(out, self.kernel2)
|
||||
# random resize
|
||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
||||
if updown_type == 'up':
|
||||
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
||||
elif updown_type == 'down':
|
||||
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
||||
else:
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(
|
||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
||||
# noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
||||
else:
|
||||
out = random_add_poisson_noise_pt(
|
||||
out,
|
||||
scale_range=self.opt['poisson_scale_range2'],
|
||||
gray_prob=gray_noise_prob,
|
||||
clip=True,
|
||||
rounds=False)
|
||||
|
||||
# JPEG compression + the final sinc filter
|
||||
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
||||
# as one operation.
|
||||
# We consider two orders:
|
||||
# 1. [resize back + sinc filter] + JPEG compression
|
||||
# 2. JPEG compression + [resize back + sinc filter]
|
||||
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
||||
if np.random.uniform() < 0.5:
|
||||
# resize back + the final sinc filter
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
||||
out = filter2D(out, self.sinc_kernel)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
else:
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
# resize back + the final sinc filter
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
||||
out = filter2D(out, self.sinc_kernel)
|
||||
|
||||
# clamp and round
|
||||
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
||||
|
||||
# random crop
|
||||
gt_size = self.opt['gt_size']
|
||||
(self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
|
||||
self.opt['scale'])
|
||||
|
||||
# training pair pool
|
||||
self._dequeue_and_enqueue()
|
||||
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
||||
self.gt_usm = self.usm_sharpener(self.gt)
|
||||
else:
|
||||
self.lq = data['lq'].to(self.device)
|
||||
if 'gt' in data:
|
||||
self.gt = data['gt'].to(self.device)
|
||||
|
||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
||||
# do not use the synthetic process during validation
|
||||
self.is_train = False
|
||||
super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
||||
self.is_train = True
|
||||
|
||||
def optimize_parameters(self, current_iter):
|
||||
l1_gt = self.gt_usm
|
||||
percep_gt = self.gt_usm
|
||||
gan_gt = self.gt_usm
|
||||
if self.opt['l1_gt_usm'] is False:
|
||||
l1_gt = self.gt
|
||||
if self.opt['percep_gt_usm'] is False:
|
||||
percep_gt = self.gt
|
||||
if self.opt['gan_gt_usm'] is False:
|
||||
gan_gt = self.gt
|
||||
|
||||
# optimize net_g
|
||||
for p in self.net_d.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.optimizer_g.zero_grad()
|
||||
self.output = self.net_g(self.lq)
|
||||
|
||||
l_g_total = 0
|
||||
loss_dict = OrderedDict()
|
||||
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
|
||||
# pixel loss
|
||||
if self.cri_pix:
|
||||
l_g_pix = self.cri_pix(self.output, l1_gt)
|
||||
l_g_total += l_g_pix
|
||||
loss_dict['l_g_pix'] = l_g_pix
|
||||
# perceptual loss
|
||||
if self.cri_perceptual:
|
||||
l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
|
||||
if l_g_percep is not None:
|
||||
l_g_total += l_g_percep
|
||||
loss_dict['l_g_percep'] = l_g_percep
|
||||
if l_g_style is not None:
|
||||
l_g_total += l_g_style
|
||||
loss_dict['l_g_style'] = l_g_style
|
||||
# gan loss
|
||||
fake_g_pred = self.net_d(self.output)
|
||||
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
||||
l_g_total += l_g_gan
|
||||
loss_dict['l_g_gan'] = l_g_gan
|
||||
|
||||
l_g_total.backward()
|
||||
self.optimizer_g.step()
|
||||
|
||||
# optimize net_d
|
||||
for p in self.net_d.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
self.optimizer_d.zero_grad()
|
||||
# real
|
||||
real_d_pred = self.net_d(gan_gt)
|
||||
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
|
||||
loss_dict['l_d_real'] = l_d_real
|
||||
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
|
||||
l_d_real.backward()
|
||||
# fake
|
||||
fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
|
||||
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
|
||||
loss_dict['l_d_fake'] = l_d_fake
|
||||
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
|
||||
l_d_fake.backward()
|
||||
self.optimizer_d.step()
|
||||
|
||||
if self.ema_decay > 0:
|
||||
self.model_ema(decay=self.ema_decay)
|
||||
|
||||
self.log_dict = self.reduce_loss_dict(loss_dict)
|
||||
172
realesrgan/models/realesrnet_model.py
Normal file
172
realesrgan/models/realesrnet_model.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
||||
from basicsr.data.transforms import paired_random_crop
|
||||
from basicsr.models.sr_model import SRModel
|
||||
from basicsr.utils import DiffJPEG, USMSharp
|
||||
from basicsr.utils.img_process_util import filter2D
|
||||
from basicsr.utils.registry import MODEL_REGISTRY
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class RealESRNetModel(SRModel):
|
||||
"""RealESRNet Model"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(RealESRNetModel, self).__init__(opt)
|
||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||
self.usm_sharpener = USMSharp().cuda()
|
||||
self.queue_size = opt['queue_size']
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self):
|
||||
# training pair pool
|
||||
# initialize
|
||||
b, c, h, w = self.lq.size()
|
||||
if not hasattr(self, 'queue_lr'):
|
||||
assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
_, c, h, w = self.gt.size()
|
||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
||||
self.queue_ptr = 0
|
||||
if self.queue_ptr == self.queue_size: # full
|
||||
# do dequeue and enqueue
|
||||
# shuffle
|
||||
idx = torch.randperm(self.queue_size)
|
||||
self.queue_lr = self.queue_lr[idx]
|
||||
self.queue_gt = self.queue_gt[idx]
|
||||
# get
|
||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
||||
# update
|
||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
||||
|
||||
self.lq = lq_dequeue
|
||||
self.gt = gt_dequeue
|
||||
else:
|
||||
# only do enqueue
|
||||
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
||||
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
||||
self.queue_ptr = self.queue_ptr + b
|
||||
|
||||
@torch.no_grad()
|
||||
def feed_data(self, data):
|
||||
if self.is_train:
|
||||
# training data synthesis
|
||||
self.gt = data['gt'].to(self.device)
|
||||
# USM the GT images
|
||||
if self.opt['gt_usm'] is True:
|
||||
self.gt = self.usm_sharpener(self.gt)
|
||||
|
||||
self.kernel1 = data['kernel1'].to(self.device)
|
||||
self.kernel2 = data['kernel2'].to(self.device)
|
||||
self.sinc_kernel = data['sinc_kernel'].to(self.device)
|
||||
|
||||
ori_h, ori_w = self.gt.size()[2:4]
|
||||
|
||||
# ----------------------- The first degradation process ----------------------- #
|
||||
# blur
|
||||
out = filter2D(self.gt, self.kernel1)
|
||||
# random resize
|
||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
||||
if updown_type == 'up':
|
||||
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
||||
elif updown_type == 'down':
|
||||
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
||||
else:
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
||||
# noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
||||
else:
|
||||
out = random_add_poisson_noise_pt(
|
||||
out,
|
||||
scale_range=self.opt['poisson_scale_range'],
|
||||
gray_prob=gray_noise_prob,
|
||||
clip=True,
|
||||
rounds=False)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
|
||||
# ----------------------- The second degradation process ----------------------- #
|
||||
# blur
|
||||
if np.random.uniform() < self.opt['second_blur_prob']:
|
||||
out = filter2D(out, self.kernel2)
|
||||
# random resize
|
||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
||||
if updown_type == 'up':
|
||||
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
||||
elif updown_type == 'down':
|
||||
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
||||
else:
|
||||
scale = 1
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(
|
||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
||||
# noise
|
||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
||||
out = random_add_gaussian_noise_pt(
|
||||
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
||||
else:
|
||||
out = random_add_poisson_noise_pt(
|
||||
out,
|
||||
scale_range=self.opt['poisson_scale_range2'],
|
||||
gray_prob=gray_noise_prob,
|
||||
clip=True,
|
||||
rounds=False)
|
||||
|
||||
# JPEG compression + the final sinc filter
|
||||
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
||||
# as one operation.
|
||||
# We consider two orders:
|
||||
# 1. [resize back + sinc filter] + JPEG compression
|
||||
# 2. JPEG compression + [resize back + sinc filter]
|
||||
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
||||
if np.random.uniform() < 0.5:
|
||||
# resize back + the final sinc filter
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
||||
out = filter2D(out, self.sinc_kernel)
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
else:
|
||||
# JPEG compression
|
||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
||||
out = torch.clamp(out, 0, 1)
|
||||
out = self.jpeger(out, quality=jpeg_p)
|
||||
# resize back + the final sinc filter
|
||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
||||
out = filter2D(out, self.sinc_kernel)
|
||||
|
||||
# clamp and round
|
||||
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
||||
|
||||
# random crop
|
||||
gt_size = self.opt['gt_size']
|
||||
self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
|
||||
|
||||
# training pair pool
|
||||
self._dequeue_and_enqueue()
|
||||
else:
|
||||
self.lq = data['lq'].to(self.device)
|
||||
if 'gt' in data:
|
||||
self.gt = data['gt'].to(self.device)
|
||||
|
||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
||||
# do not use the synthetic process during validation
|
||||
self.is_train = False
|
||||
super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
||||
self.is_train = True
|
||||
11
realesrgan/train.py
Normal file
11
realesrgan/train.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# flake8: noqa
|
||||
import os.path as osp
|
||||
from basicsr.train import train_pipeline
|
||||
|
||||
import realesrgan.archs
|
||||
import realesrgan.data
|
||||
import realesrgan.models
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
||||
train_pipeline(root_path)
|
||||
231
realesrgan/utils.py
Normal file
231
realesrgan/utils.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
from torch.nn import functional as F
|
||||
from urllib.parse import urlparse
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class RealESRGANer():
|
||||
|
||||
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
|
||||
self.scale = scale
|
||||
self.tile_size = tile
|
||||
self.tile_pad = tile_pad
|
||||
self.pre_pad = pre_pad
|
||||
self.mod_scale = None
|
||||
self.half = half
|
||||
|
||||
# initialize model
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
||||
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
else:
|
||||
keyname = 'params'
|
||||
model.load_state_dict(loadnet[keyname], strict=True)
|
||||
model.eval()
|
||||
self.model = model.to(self.device)
|
||||
if self.half:
|
||||
self.model = self.model.half()
|
||||
|
||||
def pre_process(self, img):
|
||||
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||
self.img = img.unsqueeze(0).to(self.device)
|
||||
if self.half:
|
||||
self.img = self.img.half()
|
||||
|
||||
# pre_pad
|
||||
if self.pre_pad != 0:
|
||||
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
||||
# mod pad
|
||||
if self.scale == 2:
|
||||
self.mod_scale = 2
|
||||
elif self.scale == 1:
|
||||
self.mod_scale = 4
|
||||
if self.mod_scale is not None:
|
||||
self.mod_pad_h, self.mod_pad_w = 0, 0
|
||||
_, _, h, w = self.img.size()
|
||||
if (h % self.mod_scale != 0):
|
||||
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
|
||||
if (w % self.mod_scale != 0):
|
||||
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
|
||||
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
||||
|
||||
def process(self):
|
||||
self.output = self.model(self.img)
|
||||
|
||||
def tile_process(self):
|
||||
"""Modified from: https://github.com/ata4/esrgan-launcher
|
||||
"""
|
||||
batch, channel, height, width = self.img.shape
|
||||
output_height = height * self.scale
|
||||
output_width = width * self.scale
|
||||
output_shape = (batch, channel, output_height, output_width)
|
||||
|
||||
# start with black image
|
||||
self.output = self.img.new_zeros(output_shape)
|
||||
tiles_x = math.ceil(width / self.tile_size)
|
||||
tiles_y = math.ceil(height / self.tile_size)
|
||||
|
||||
# loop over all tiles
|
||||
for y in range(tiles_y):
|
||||
for x in range(tiles_x):
|
||||
# extract tile from input image
|
||||
ofs_x = x * self.tile_size
|
||||
ofs_y = y * self.tile_size
|
||||
# input tile area on total image
|
||||
input_start_x = ofs_x
|
||||
input_end_x = min(ofs_x + self.tile_size, width)
|
||||
input_start_y = ofs_y
|
||||
input_end_y = min(ofs_y + self.tile_size, height)
|
||||
|
||||
# input tile area on total image with padding
|
||||
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
||||
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
||||
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
||||
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
||||
|
||||
# input tile dimensions
|
||||
input_tile_width = input_end_x - input_start_x
|
||||
input_tile_height = input_end_y - input_start_y
|
||||
tile_idx = y * tiles_x + x + 1
|
||||
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
|
||||
|
||||
# upscale tile
|
||||
try:
|
||||
with torch.no_grad():
|
||||
output_tile = self.model(input_tile)
|
||||
except Exception as error:
|
||||
print('Error', error)
|
||||
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
|
||||
|
||||
# output tile area on total image
|
||||
output_start_x = input_start_x * self.scale
|
||||
output_end_x = input_end_x * self.scale
|
||||
output_start_y = input_start_y * self.scale
|
||||
output_end_y = input_end_y * self.scale
|
||||
|
||||
# output tile area without padding
|
||||
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
||||
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
||||
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
||||
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
||||
|
||||
# put tile into output image
|
||||
self.output[:, :, output_start_y:output_end_y,
|
||||
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
|
||||
output_start_x_tile:output_end_x_tile]
|
||||
|
||||
def post_process(self):
|
||||
# remove extra pad
|
||||
if self.mod_scale is not None:
|
||||
_, _, h, w = self.output.size()
|
||||
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
|
||||
# remove prepad
|
||||
if self.pre_pad != 0:
|
||||
_, _, h, w = self.output.size()
|
||||
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
|
||||
return self.output
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
|
||||
h_input, w_input = img.shape[0:2]
|
||||
# img: numpy
|
||||
img = img.astype(np.float32)
|
||||
if np.max(img) > 255: # 16-bit image
|
||||
max_range = 65535
|
||||
print('\tInput is a 16-bit image')
|
||||
else:
|
||||
max_range = 255
|
||||
img = img / max_range
|
||||
if len(img.shape) == 2: # gray image
|
||||
img_mode = 'L'
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
||||
elif img.shape[2] == 4: # RGBA image with alpha channel
|
||||
img_mode = 'RGBA'
|
||||
alpha = img[:, :, 3]
|
||||
img = img[:, :, 0:3]
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
if alpha_upsampler == 'realesrgan':
|
||||
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
||||
else:
|
||||
img_mode = 'RGB'
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# ------------------- process image (without the alpha channel) ------------------- #
|
||||
self.pre_process(img)
|
||||
if self.tile_size > 0:
|
||||
self.tile_process()
|
||||
else:
|
||||
self.process()
|
||||
output_img = self.post_process()
|
||||
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
||||
if img_mode == 'L':
|
||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# ------------------- process the alpha channel if necessary ------------------- #
|
||||
if img_mode == 'RGBA':
|
||||
if alpha_upsampler == 'realesrgan':
|
||||
self.pre_process(alpha)
|
||||
if self.tile_size > 0:
|
||||
self.tile_process()
|
||||
else:
|
||||
self.process()
|
||||
output_alpha = self.post_process()
|
||||
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
||||
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
h, w = alpha.shape[0:2]
|
||||
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# merge the alpha channel
|
||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
||||
output_img[:, :, 3] = output_alpha
|
||||
|
||||
# ------------------------------ return ------------------------------ #
|
||||
if max_range == 65535: # 16-bit image
|
||||
output = (output_img * 65535.0).round().astype(np.uint16)
|
||||
else:
|
||||
output = (output_img * 255.0).round().astype(np.uint8)
|
||||
|
||||
if outscale is not None and outscale != float(self.scale):
|
||||
output = cv2.resize(
|
||||
output, (
|
||||
int(w_input * outscale),
|
||||
int(h_input * outscale),
|
||||
), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
return output, img_mode
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
||||
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""
|
||||
if model_dir is None:
|
||||
hub_dir = get_dir()
|
||||
model_dir = os.path.join(hub_dir, 'checkpoints')
|
||||
|
||||
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
|
||||
|
||||
parts = urlparse(url)
|
||||
filename = os.path.basename(parts.path)
|
||||
if file_name is not None:
|
||||
filename = file_name
|
||||
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
|
||||
if not os.path.exists(cached_file):
|
||||
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
||||
return cached_file
|
||||
3
realesrgan/weights/README.md
Normal file
3
realesrgan/weights/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Weights
|
||||
|
||||
Put the downloaded weights to this folder.
|
||||
@@ -1,4 +1,4 @@
|
||||
basicsr
|
||||
cv2
|
||||
numpy
|
||||
opencv-python
|
||||
torch>=1.7
|
||||
|
||||
17
scripts/pytorch2onnx.py
Normal file
17
scripts/pytorch2onnx.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
import torch.onnx
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
|
||||
# An instance of your model
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32)
|
||||
model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
|
||||
# set the train mode to false since we will only run the forward pass.
|
||||
model.train(False)
|
||||
model.cpu().eval()
|
||||
|
||||
# An example input you would normally provide to your model's forward() method
|
||||
x = torch.rand(1, 3, 64, 64)
|
||||
|
||||
# Export the model
|
||||
with torch.no_grad():
|
||||
torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
|
||||
@@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
|
||||
line_length = 120
|
||||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = basicsr # modify it!
|
||||
known_first_party = realesrgan
|
||||
known_third_party = basicsr,cv2,numpy,torch
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
113
setup.py
Normal file
113
setup.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
version_file = 'realesrgan/version.py'
|
||||
|
||||
|
||||
def readme():
|
||||
with open('README.md', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
return content
|
||||
|
||||
|
||||
def get_git_hash():
|
||||
|
||||
def _minimal_ext_cmd(cmd):
|
||||
# construct minimal environment
|
||||
env = {}
|
||||
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
|
||||
v = os.environ.get(k)
|
||||
if v is not None:
|
||||
env[k] = v
|
||||
# LANGUAGE is used on win32
|
||||
env['LANGUAGE'] = 'C'
|
||||
env['LANG'] = 'C'
|
||||
env['LC_ALL'] = 'C'
|
||||
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
|
||||
return out
|
||||
|
||||
try:
|
||||
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
|
||||
sha = out.strip().decode('ascii')
|
||||
except OSError:
|
||||
sha = 'unknown'
|
||||
|
||||
return sha
|
||||
|
||||
|
||||
def get_hash():
|
||||
if os.path.exists('.git'):
|
||||
sha = get_git_hash()[:7]
|
||||
elif os.path.exists(version_file):
|
||||
try:
|
||||
from facexlib.version import __version__
|
||||
sha = __version__.split('+')[-1]
|
||||
except ImportError:
|
||||
raise ImportError('Unable to get git version')
|
||||
else:
|
||||
sha = 'unknown'
|
||||
|
||||
return sha
|
||||
|
||||
|
||||
def write_version_py():
|
||||
content = """# GENERATED VERSION FILE
|
||||
# TIME: {}
|
||||
__version__ = '{}'
|
||||
__gitsha__ = '{}'
|
||||
version_info = ({})
|
||||
"""
|
||||
sha = get_hash()
|
||||
with open('VERSION', 'r') as f:
|
||||
SHORT_VERSION = f.read().strip()
|
||||
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
|
||||
|
||||
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
|
||||
with open(version_file, 'w') as f:
|
||||
f.write(version_file_str)
|
||||
|
||||
|
||||
def get_version():
|
||||
with open(version_file, 'r') as f:
|
||||
exec(compile(f.read(), version_file, 'exec'))
|
||||
return locals()['__version__']
|
||||
|
||||
|
||||
def get_requirements(filename='requirements.txt'):
|
||||
here = os.path.dirname(os.path.realpath(__file__))
|
||||
with open(os.path.join(here, filename), 'r') as f:
|
||||
requires = [line.replace('\n', '') for line in f.readlines()]
|
||||
return requires
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
write_version_py()
|
||||
setup(
|
||||
name='realesrgan',
|
||||
version=get_version(),
|
||||
description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
|
||||
long_description=readme(),
|
||||
long_description_content_type='text/markdown',
|
||||
author='Xintao Wang',
|
||||
author_email='xintao.wang@outlook.com',
|
||||
keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
|
||||
url='https://github.com/xinntao/Real-ESRGAN',
|
||||
include_package_data=True,
|
||||
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Operating System :: OS Independent',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
],
|
||||
license='BSD-3-Clause License',
|
||||
setup_requires=['cython', 'numpy'],
|
||||
install_requires=get_requirements(),
|
||||
zip_safe=False)
|
||||
Reference in New Issue
Block a user