add inference codes

This commit is contained in:
Xintao
2021-07-23 02:50:08 +08:00
parent 4dc033d62b
commit 9a1dd23287
14 changed files with 148 additions and 319 deletions

21
LICENSE
View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2020 Xintao Wang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

134
README.md
View File

@@ -1,72 +1,90 @@
# ProjectTemplate-Python # Real-ESRGAN
[English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/ProjectTemplate-Python) **|** [Gitee码云](https://gitee.com/xinntao/ProjectTemplate-Python) [**Paper**](https://arxiv.org/abs/2101.04061)
## File Modification ### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
1. Setup *pre-commit* hook > [[Paper](https://arxiv.org/abs/2101.04061)] &emsp; [[Project Page](https://xinntao.github.io/projects/gfpgan)] &emsp; [Demo] <br>
1. If necessary, modify `.pre-commit-config.yaml` > [Xintao Wang](https://xinntao.github.io/), Liangbin Xie, [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en) <br>
1. In the repository root path, run > Applied Research Center (ARC), Tencent PCG; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
> pre-commit install
1. Modify the `.gitignore` file
1. Modify the `LICENSE` file
This repository uses the *MIT* license, you may change it to other licenses
1. Modify the *setup* files
1. `setup.cfg`
1. `setup.py`, especially the `basicsr` keyword
1. Modify the `requirements.txt` files
1. Modify the `VERSION` file
## GitHub Workflows #### Abstract
1. [pylint](./github/workflows/pylint.yml) Though many attempts have been made in blind super-resolution to restore low-resolution images with unknown and complex degradations, they are still far from addressing general real-world degraded images. In this work, we extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data. Specifically, a high-order degradation modeling process is introduced to better simulate complex real-world degradations. We also consider the common ringing and overshoot artifacts in the synthesis process. In addition, we employ a U-Net discriminator with spectral normalization to increase discriminator capability and stabilize the training dynamics. Extensive comparisons have shown its superior visual performance than prior works on various real datasets. We also provide efficient implementations to synthesize training pairs on the fly.
1. [gitee-repo-mirror](./github/workflow/gitee-repo-mirror.yml) - Support Gitee码云
1. Clone GitHub repo in the [Gitee](https://gitee.com/) website
1. Modify [gitee-repo-mirror](./github/workflow/gitee-repo-mirror.yml)
1. In Github *Settings* -> *Secrets*, add `SSH_PRIVATE_KEY`
## Other Procedures #### BibTeX
1. The `description`, `website`, `topics` in the main page @Article{wang2021realesrgan,
1. Support Chinese documents, for example, `README_CN.md` title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
journal={arXiv:2107.xxxxx},
year={2021}
}
## Emoji ---
[Emoji cheat-sheet](https://github.com/ikatyang/emoji-cheat-sheet) We are cleaning the training codes. It will be finished on 23 or 24, July.
| Emoji | Meaning | ---
| :--- | :---: |
| :rocket: | Used for [BasicSR](https://github.com/xinntao/BasicSR) Logo |
| :sparkles: | Features |
| :zap: | HOWTOs |
| :wrench: | Installation / Usage |
| :hourglass_flowing_sand: | TODO list |
| :turtle: | Dataset preparation |
| :computer: | Commands |
| :european_castle: | Model zoo |
| :memo: | Designs |
| :scroll: | License and acknowledgement |
| :earth_asia: | Citations |
| :e-mail: | Contact |
| :m: | Models |
| :arrow_double_down: | Download |
| :file_folder: | Datasets |
| :chart_with_upwards_trend: | Curves|
| :eyes: | Screenshot |
| :books: |References |
## Useful Image Links You can download **Windows executable files** from https://https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN-ncnn-vulkan.zip
<img src="https://colab.research.google.com/assets/colab-badge.svg" height="28" alt="google colab logo"> Google Colab Logo <br> You can simply run the following command:
<img src="https://upload.wikimedia.org/wikipedia/commons/8/8d/Windows_darkblue_2012.svg" height="28" alt="google colab logo"> Windows Logo <br> ```bash
<img src="https://upload.wikimedia.org/wikipedia/commons/3/3a/Logo-ubuntu_no%28r%29-black_orange-hex.svg" alt="Ubuntu" height="24"> Ubuntu Logo <br> realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
```
## Other Useful Tips This executable file is based on the wonderful [ncnn project](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan)
1. `More` drop-down menu ---
<details>
<summary>More</summary> <p align="center">
<ul> <img src="assets/teaser.jpg">
<li>Nov 19, 2020. Set up ProjectTemplate-Python.</li> </p>
</ul>
</details> ---
## :wrench: Dependencies and Installation
- Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
- [PyTorch >= 1.7](https://pytorch.org/)
### Installation
1. Clone repo
```bash
git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN
```
1. Install dependent packages
```bash
# Install basicsr - https://github.com/xinntao/BasicSR
# We use BasicSR for both training and inference
pip install basicsr
pip install -r requirements.txt
```
## :zap: Quick Inference
Download pre-trained models: [RealESRGAN_x4plus.pth](https://https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
Download pretrained models:
```bash
wget https://https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
```
Inference!
```bash
python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input inputs
```
Results are in the `results` folder
## :e-mail: Contact
If you have any question, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`.

View File

@@ -1,72 +0,0 @@
# ProjectTemplate-Python
[English](README.md) **|** [简体中文](README_CN.md) &emsp; [GitHub](https://github.com/xinntao/ProjectTemplate-Python) **|** [Gitee码云](https://gitee.com/xinntao/ProjectTemplate-Python)
## 文件修改
1. 设置 *pre-commit* hook.
1. 若需要, 修改 `.pre-commit-config.yaml`
1. 在文件夹根目录, 运行
> pre-commit install
1. 修改 `.gitignore` 文件
1. 修改 `LICENSE` 文件
本仓库使用 *MIT* 许可, 根据需要可以修改成其他许可
1. 修改 *setup* 文件
1. `setup.cfg`
1. `setup.py`, 特别是其中包含的关键字 `basicsr`
1. 修改 `requirements.txt` 文件
1. 修改 `VERSION` 文件
## GitHub Workflows
1. [pylint](./github/workflows/pylint.yml)
1. [gitee-repo-mirror](./github/workflow/gitee-repo-mirror.yml) - 支持 Gitee码云
1. 在 [Gitee](https://gitee.com/) 网站克隆 Github 仓库
1. 修改 [gitee-repo-mirror](./github/workflow/gitee-repo-mirror.yml) 文件
1. 在 Github 中的 *Settings* -> *Secrets*`SSH_PRIVATE_KEY`
## 其他流程
1. 主页上的 `description`, `website`, `topics`
1. 支持中文文档, 比如, `README_CN.md`
## Emoji
[Emoji cheat-sheet](https://github.com/ikatyang/emoji-cheat-sheet)
| Emoji | Meaning |
| :--- | :---: |
| :rocket: | Used for [BasicSR](https://github.com/xinntao/BasicSR) Logo |
| :sparkles: | Features |
| :zap: | HOWTOs |
| :wrench: | Installation / Usage |
| :hourglass_flowing_sand: | TODO list |
| :turtle: | Dataset preparation |
| :computer: | Commands |
| :european_castle: | Model zoo |
| :memo: | Designs |
| :scroll: | License and acknowledgement |
| :earth_asia: | Citations |
| :e-mail: | Contact |
| :m: | Models |
| :arrow_double_down: | Download |
| :file_folder: | Datasets |
| :chart_with_upwards_trend: | Curves|
| :eyes: | Screenshot |
| :books: |References |
## 有用的图像链接
<img src="https://colab.research.google.com/assets/colab-badge.svg" height="28" alt="google colab logo"> Google Colab Logo <br>
<img src="https://upload.wikimedia.org/wikipedia/commons/8/8d/Windows_darkblue_2012.svg" height="28" alt="google colab logo"> Windows Logo <br>
<img src="https://upload.wikimedia.org/wikipedia/commons/3/3a/Logo-ubuntu_no%28r%29-black_orange-hex.svg" alt="Ubuntu" height="24"> Ubuntu Logo <br>
## 其他有用的技巧
1. `More` 下拉菜单
<details>
<summary>More</summary>
<ul>
<li>Nov 19, 2020. Set up ProjectTemplate-Python.</li>
</ul>
</details>

View File

@@ -1 +0,0 @@
0.0.0

BIN
assets/teaser.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 392 KiB

68
inference_realesrgan.py Normal file
View File

@@ -0,0 +1,68 @@
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from torch.nn import functional as F
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth')
parser.add_argument('--scale', type=int, default=4)
parser.add_argument('--input', type=str, default='inputs', help='input image or folder')
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# set up model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale)
loadnet = torch.load(args.model_path)
model.load_state_dict(loadnet['params_ema'], strict=True)
model.eval()
model = model.to(device)
os.makedirs('results/', exist_ok=True)
for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
imgname = os.path.splitext(os.path.basename(path))[0]
print('Testing', idx, imgname)
# read image
img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img = img.unsqueeze(0).to(device)
if args.scale == 2:
mod_scale = 2
elif args.scale == 1:
mod_scale = 4
else:
mod_scale = None
if mod_scale is not None:
h_pad, w_pad = 0, 0
_, _, h, w = img.size()
if (h % mod_scale != 0):
h_pad = (mod_scale - h % mod_scale)
if (w % mod_scale != 0):
w_pad = (mod_scale - w % mod_scale)
img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
try:
# inference
with torch.no_grad():
output = model(img)
# remove extra pad
if mod_scale is not None:
_, _, h, w = output.size()
output = output[:, :, 0:h - h_pad, 0:w - w_pad]
# save image
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)
cv2.imwrite(f'results/{imgname}_RealESRGAN.png', output)
except Exception as error:
print('Error', error)
if __name__ == '__main__':
main()

BIN
inputs/00003.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 161 KiB

BIN
inputs/0014.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.4 KiB

BIN
inputs/0030.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

BIN
inputs/ADE_val_00000114.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

BIN
inputs/OST_009.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 701 KiB

View File

@@ -1 +1,4 @@
basicsr
cv2
numpy numpy
torch>=1.7

View File

@@ -17,6 +17,6 @@ line_length = 120
multi_line_output = 0 multi_line_output = 0
known_standard_library = pkg_resources,setuptools known_standard_library = pkg_resources,setuptools
known_first_party = basicsr # modify it! known_first_party = basicsr # modify it!
known_third_party = torch known_third_party = basicsr,cv2,numpy,torch
no_lines_before = STDLIB,LOCALFOLDER no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY

166
setup.py
View File

@@ -1,166 +0,0 @@
#!/usr/bin/env python
from setuptools import find_packages, setup
import os
import subprocess
import sys
import time
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
version_file = 'basicsr/version.py'
def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
return content
def get_git_hash():
def _minimal_ext_cmd(cmd):
# construct minimal environment
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
return out
try:
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
sha = out.strip().decode('ascii')
except OSError:
sha = 'unknown'
return sha
def get_hash():
if os.path.exists('.git'):
sha = get_git_hash()[:7]
elif os.path.exists(version_file):
try:
from basicsr.version import __version__
sha = __version__.split('+')[-1]
except ImportError:
raise ImportError('Unable to get git version')
else:
sha = 'unknown'
return sha
def write_version_py():
content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
sha = get_hash()
with open('VERSION', 'r') as f:
SHORT_VERSION = f.read().strip()
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
with open(version_file, 'w') as f:
f.write(version_file_str)
def get_version():
with open(version_file, 'r') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
def make_cuda_ext(name, module, sources, sources_cuda=None):
if sources_cuda is None:
sources_cuda = []
define_macros = []
extra_compile_args = {'cxx': []}
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('WITH_CUDA', None)]
extension = CUDAExtension
extra_compile_args['nvcc'] = [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
sources += sources_cuda
else:
print(f'Compiling {name} without CUDA')
extension = CppExtension
return extension(
name=f'{module}.{name}',
sources=[os.path.join(*module.split('.'), p) for p in sources],
define_macros=define_macros,
extra_compile_args=extra_compile_args)
def get_requirements(filename='requirements.txt'):
here = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(here, filename), 'r') as f:
requires = [line.replace('\n', '') for line in f.readlines()]
return requires
if __name__ == '__main__':
if '--cuda_ext' in sys.argv:
ext_modules = [
make_cuda_ext(
name='deform_conv_ext',
module='basicsr.ops.dcn',
sources=['src/deform_conv_ext.cpp'],
sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
make_cuda_ext(
name='fused_act_ext',
module='basicsr.ops.fused_act',
sources=['src/fused_bias_act.cpp'],
sources_cuda=['src/fused_bias_act_kernel.cu']),
make_cuda_ext(
name='upfirdn2d_ext',
module='basicsr.ops.upfirdn2d',
sources=['src/upfirdn2d.cpp'],
sources_cuda=['src/upfirdn2d_kernel.cu']),
]
sys.argv.remove('--cuda_ext')
else:
ext_modules = []
write_version_py()
setup(
name='basicsr',
version=get_version(),
description='Open Source Image and Video Super-Resolution Toolbox',
long_description=readme(),
long_description_content_type='text/markdown',
author='Xintao Wang',
author_email='xintao.wang@outlook.com',
keywords='computer vision, restoration, super resolution',
url='https://github.com/xinntao/BasicSR',
include_package_data=True,
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
license='Apache License 2.0',
setup_requires=['cython', 'numpy'],
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension},
zip_safe=False)