22 Commits

Author SHA1 Message Date
Xintao
13186ac2c2 Merge pull request #18 from xinntao/pypi
Support PyPI
2021-08-08 16:29:11 +08:00
Xintao
4356ba0578 update readme 2021-08-08 16:26:09 +08:00
Xintao
32a4fa1772 add publish-pip action 2021-08-08 16:20:01 +08:00
Xintao
18ebf723f2 adaption for pypi 2021-08-08 16:12:56 +08:00
Xintao
1f83ce5432 update .gitignore 2021-08-08 15:29:33 +08:00
Xintao
bef5e3cabd update for running 2021-08-08 15:27:55 +08:00
Xintao
064df9956b add setup.py 2021-08-08 15:13:53 +08:00
Xintao
52eab16d11 update .gitignore 2021-08-08 15:01:08 +08:00
Xintao
94ae626008 regroup 2021-08-08 14:50:32 +08:00
Xintao
9baa0b3d00 regroup files 2021-08-08 14:46:43 +08:00
Xintao
f932289af1 support half inference 2021-08-01 12:10:35 +08:00
Xintao
f59a0c66ec Add Linux/MacOS executable files 2021-08-01 02:35:36 +08:00
Xintao
935993a040 fix colab link error 2021-07-31 12:34:09 +08:00
Xintao
74fcfea286 fix bug: extension 2021-07-28 15:21:47 +08:00
Xintao
da1e1ee805 update Readme.txt 2021-07-28 03:01:47 +08:00
Xintao
1d8745eb61 update Readme 2021-07-28 02:59:20 +08:00
Xintao
c94d2de155 support more inference features: tile, alpha channel, gray image, 16-bit 2021-07-28 02:56:22 +08:00
Xintao
f4297a70af Merge branch 'master' of github.com:xinntao/Real-ESRGAN 2021-07-27 11:07:05 +08:00
Xintao
492a829c14 fix bug: gt_sum; fix typo 2021-07-27 11:06:43 +08:00
Xintao
0573f32dd0 Create LICENSE 2021-07-26 00:33:08 +08:00
Xintao
8454fd2c7a add pytorch2onnx 2021-07-25 16:16:57 +08:00
Xintao
ad2ff81725 add RealESRNet model, fix bug in exe file 2021-07-25 16:16:25 +08:00
29 changed files with 549 additions and 85 deletions

30
.github/workflows/publish-pip.yml vendored Normal file

@@ -0,0 +1,30 @@
name: PyPI Publish
on: push
jobs:
build-n-publish:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install PyTorch (cpu)
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install dependencies
run: pip install -r requirements.txt
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Build for distribution
# remove bdist_wheel for pip installation with compiling cuda extensions
run: python setup.py sdist
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
password: ${{ secrets.PYPI_API_TOKEN }}

.github/workflows/pylint.yml

@@ -26,5 +26,5 @@ jobs:
- name: Lint
run: |
flake8 .
isort --check-only --diff data/ models/ inference_realesrgan.py
yapf -r -d data/ models/ inference_realesrgan.py
isort --check-only --diff realesrgan/ scripts/ inference_realesrgan.py setup.py
yapf -r -d realesrgan/ scripts/ inference_realesrgan.py setup.py

9
.gitignore vendored

@@ -1,3 +1,12 @@
# ignored folders
datasets/*
experiments/*
results/*
tb_logger/*
wandb/*
tmp/*
version.py
.vscode
# Byte-compiled / optimized / DLL files

29
LICENSE Normal file

@@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2021, Xintao Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

8
MANIFEST.in Normal file

@@ -0,0 +1,8 @@
include assets/*
include inputs/*
include scripts/*.py
include inference_realesrgan.py
include VERSION
include LICENSE
include requirements.txt
include realesrgan/weights/README.md

README.md

@@ -5,13 +5,16 @@
[![LICENSE](https://img.shields.io/github/license/xinntao/Real-ESRGAN.svg)](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
[![python lint](https://github.com/xinntao/Real-ESRGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
2. [Portable Windows executable file](https://github.com/xinntao/Real-ESRGAN/releases). You can find more information [here](#Portable-executable-files).
1. [Colab Demo](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
2. [Portable Windows/Linux/MacOS executable files for Intel/AMD/Nvidia GPU](https://github.com/xinntao/Real-ESRGAN/releases). You can find more information [here](#Portable-executable-files).
Real-ESRGAN aims at developing **Practical Algorithms for General Image Restoration**.<br>
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.
:triangular_flag_on_post: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
:triangular_flag_on_post: **Updates**
- :white_check_mark: [The inference code](inference_realesrgan.py) supports: 1) **tile** options; 2) images with **alpha channel**; 3) **gray** images; 4) **16-bit** images.
- :white_check_mark: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
@@ -49,16 +52,24 @@ If you have some images that Real-ESRGAN could not well restored, please also op
### Portable executable files
You can download **Windows executable files** from https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN-ncnn-vulkan.zip
You can download **Windows/Linux/MacOS executable files for Intel/AMD/Nvidia GPU** from https://github.com/xinntao/Real-ESRGAN/releases
This executable file is **portable** and includes all the binaries and models required. No CUDA or PyTorch environment is needed.<br>
You can simply run the following command:
You can simply run the following command (this is the Windows example; more information is in the README.md of each executable file)
```bash
./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
```
We have provided three models:
1. realesrgan-x4plus (default)
2. realesrnet-x4plus
3. esrgan-x4
You can use the `-n` argument for other models, for example, `./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png -n realesrnet-x4plus`
Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, processes them separately, and finally stitches them back together.
This executable file is based on the wonderful [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui).
@@ -86,6 +97,7 @@ This executable file is based on the wonderful [Tencent/ncnn](https://github.com
# We use BasicSR for both training and inference
pip install basicsr
pip install -r requirements.txt
python setup.py develop
```
## :zap: Quick Inference
@@ -106,6 +118,12 @@ python inference_realesrgan.py --model_path experiments/pretrained_models/RealES
Results are in the `results` folder
## :european_castle: Model Zoo
- [RealESRGAN-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
- [RealESRNet-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth)
- [official ESRGAN-x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth)
## :computer: Training
A detailed guide can be found in [Training.md](Training.md).

Training.md

@@ -34,14 +34,17 @@ DF2K_HR_sub/000001_s003.png
## Train Real-ESRNet
1. Download pre-trained model [ESRGAN](https://drive.google.com/file/d/1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT/view?usp=sharing) into `experiments/pretrained_models`.
1. Download pre-trained model [ESRGAN](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) into `experiments/pretrained_models`.
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models
```
1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
```yml
train:
name: DF2K+OST
type: RealESRGANDataset
dataroot_gt: datasets/DF2K # modify to the root path of your folder
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
io_backend:
type: disk
```
@@ -73,12 +76,12 @@ DF2K_HR_sub/000001_s003.png
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
```
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
```
## Train Real-ESRGAN
@@ -88,10 +91,10 @@ DF2K_HR_sub/000001_s003.png
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
```
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
```

1
VERSION Normal file

@@ -0,0 +1 @@
0.2.0

inference_realesrgan.py

@@ -1,67 +1,65 @@
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from torch.nn import functional as F
from realesrgan import RealESRGANer
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth')
parser.add_argument('--scale', type=int, default=4)
parser.add_argument('--input', type=str, default='inputs', help='input image or folder')
parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument(
'--model_path',
type=str,
default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
help='Path to the pre-trained model')
parser.add_argument('--output', type=str, default='results', help='Output folder')
parser.add_argument('--scale', type=int, default=4, help='Upsample scale factor')
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
parser.add_argument('--half', action='store_true', help='Use half precision during inference')
parser.add_argument(
'--alpha_upsampler',
type=str,
default='realesrgan',
help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
parser.add_argument(
'--ext',
type=str,
default='auto',
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# set up model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale)
loadnet = torch.load(args.model_path)
model.load_state_dict(loadnet['params_ema'], strict=True)
model.eval()
model = model.to(device)
upsampler = RealESRGANer(
scale=args.scale,
model_path=args.model_path,
tile=args.tile,
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=args.half)
os.makedirs(args.output, exist_ok=True)
if os.path.isfile(args.input):
paths = [args.input]
else:
paths = sorted(glob.glob(os.path.join(args.input, '*')))
os.makedirs('results/', exist_ok=True)
for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
imgname = os.path.splitext(os.path.basename(path))[0]
for idx, path in enumerate(paths):
imgname, extension = os.path.splitext(os.path.basename(path))
print('Testing', idx, imgname)
# read image
img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img = img.unsqueeze(0).to(device)
if args.scale == 2:
mod_scale = 2
elif args.scale == 1:
mod_scale = 4
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
output, img_mode = upsampler.enhance(img)
if args.ext == 'auto':
extension = extension[1:]
else:
mod_scale = None
if mod_scale is not None:
h_pad, w_pad = 0, 0
_, _, h, w = img.size()
if (h % mod_scale != 0):
h_pad = (mod_scale - h % mod_scale)
if (w % mod_scale != 0):
w_pad = (mod_scale - w % mod_scale)
img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
try:
# inference
with torch.no_grad():
output = model(img)
# remove extra pad
if mod_scale is not None:
_, _, h, w = output.size()
output = output[:, :, 0:h - h_pad, 0:w - w_pad]
# save image
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)
cv2.imwrite(f'results/{imgname}_RealESRGAN.png', output)
except Exception as error:
print('Error', error)
extension = args.ext
if img_mode == 'RGBA': # RGBA images should be saved in png format
extension = 'png'
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
cv2.imwrite(save_path, output)
if __name__ == '__main__':
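
The refactored script above delegates model loading, padding, tiling and alpha handling to `RealESRGANer`. A minimal sketch of using it directly from Python, assuming the x4plus weights have already been downloaded (the input and output paths are hypothetical):

```python
import cv2
from realesrgan import RealESRGANer

# Minimal sketch; assumes RealESRGAN_x4plus.pth has been downloaded to this path
# and that 'inputs/00003.png' (hypothetical) exists.
upsampler = RealESRGANer(
    scale=4,
    model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
    tile=0,        # e.g. 400 to enable tiled inference for large images
    tile_pad=10,
    pre_pad=0,
    half=False)    # True for half-precision inference on CUDA

img = cv2.imread('inputs/00003.png', cv2.IMREAD_UNCHANGED)
output, img_mode = upsampler.enhance(img)
cv2.imwrite('results/00003_out.png', output)
```

Passing a positive `tile` value switches `enhance` to the tiled path, which keeps GPU memory bounded for large inputs.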

BIN  inputs/tree_alpha_16bit.png (new binary file, 373 KiB)

BIN  inputs/wolf_gray.jpg (new binary file, 52 KiB)

options/train_realesrgan_x4plus.yml

@@ -39,7 +39,7 @@ datasets:
name: DF2K+OST
type: RealESRGANDataset
dataroot_gt: datasets/DF2K
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
io_backend:
type: disk
@@ -100,7 +100,7 @@ network_d:
# path
path:
# use the pre-trained Real-ESRNet model
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/models/net_g_1000000.pth
param_key_g: params_ema
strict_load_g: true
resume_state: ~

options/train_realesrnet_x4plus.yml

@@ -36,7 +36,7 @@ datasets:
name: DF2K+OST
type: RealESRGANDataset
dataroot_gt: datasets/DF2K
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
io_backend:
type: disk

6
realesrgan/__init__.py Normal file

@@ -0,0 +1,6 @@
# flake8: noqa
from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import __gitsha__, __version__

realesrgan/archs/__init__.py

@@ -7,4 +7,4 @@ from os import path as osp
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]

realesrgan/data/__init__.py

@@ -7,4 +7,4 @@ from os import path as osp
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]

realesrgan/models/__init__.py

@@ -7,4 +7,4 @@ from os import path as osp
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
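
Because these __init__ files import every `*_arch.py`, `*_dataset.py` and `*_model.py` under the `realesrgan` package, a new component only has to live in the right folder and register itself with BasicSR's registry to become selectable from a YAML option file. A hedged sketch with a hypothetical model file (it assumes BasicSR's `MODEL_REGISTRY` decorator and `SRModel` base class):

```python
# realesrgan/models/myvariant_model.py  (hypothetical file name)
from basicsr.models.sr_model import SRModel
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class MyVariantModel(SRModel):
    """Picked up automatically by realesrgan/models/__init__.py, then
    selectable via `model_type: MyVariantModel` in an option file."""

    def __init__(self, opt):
        super(MyVariantModel, self).__init__(opt)
```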

realesrgan/models/realesrgan_model.py

@@ -18,7 +18,7 @@ class RealESRGANModel(SRGANModel):
def __init__(self, opt):
super(RealESRGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()
self.usm_shaper = USMSharp().cuda()
self.usm_sharpener = USMSharp().cuda()
self.queue_size = opt['queue_size']
@torch.no_grad()
@@ -58,7 +58,7 @@ class RealESRGANModel(SRGANModel):
if self.is_train:
# training data synthesis
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_shaper(self.gt)
self.gt_usm = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)
@@ -160,6 +160,8 @@ class RealESRGANModel(SRGANModel):
# training pair pool
self._dequeue_and_enqueue()
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
self.gt_usm = self.usm_sharpener(self.gt)
else:
self.lq = data['lq'].to(self.device)
if 'gt' in data:

realesrgan/models/realesrnet_model.py

@@ -17,7 +17,7 @@ class RealESRNetModel(SRModel):
def __init__(self, opt):
super(RealESRNetModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()
self.usm_shaper = USMSharp().cuda()
self.usm_sharpener = USMSharp().cuda()
self.queue_size = opt['queue_size']
@torch.no_grad()
@@ -59,7 +59,7 @@ class RealESRNetModel(SRModel):
self.gt = data['gt'].to(self.device)
# USM the GT images
if self.opt['gt_usm'] is True:
self.gt = self.usm_shaper(self.gt)
self.gt = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)

11
realesrgan/train.py Normal file

@@ -0,0 +1,11 @@
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline
import realesrgan.archs
import realesrgan.data
import realesrgan.models
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)

226
realesrgan/utils.py Normal file

@@ -0,0 +1,226 @@
import cv2
import math
import numpy as np
import os
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from torch.hub import download_url_to_file, get_dir
from torch.nn import functional as F
from urllib.parse import urlparse
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class RealESRGANer():
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
if model_path.startswith('https://'):
model_path = load_file_from_url(
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
keyname = 'params_ema'
else:
keyname = 'params'
model.load_state_dict(loadnet[keyname], strict=True)
model.eval()
self.model = model.to(self.device)
if self.half:
self.model = self.model.half()
def pre_process(self, img):
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
if self.half:
self.img = self.img.half()
# pre_pad
if self.pre_pad != 0:
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
# mod pad
if self.scale == 2:
self.mod_scale = 2
elif self.scale == 1:
self.mod_scale = 4
if self.mod_scale is not None:
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.img.size()
if (h % self.mod_scale != 0):
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
if (w % self.mod_scale != 0):
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
def process(self):
try:
# inference
with torch.no_grad():
self.output = self.model(self.img)
except Exception as error:
print('Error', error)
def tile_process(self):
"""Modified from: https://github.com/ata4/esrgan-launcher
"""
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
# start with black image
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * self.tile_size
ofs_y = y * self.tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + self.tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
# upscale tile
try:
with torch.no_grad():
output_tile = self.model(input_tile)
except Exception as error:
print('Error', error)
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
# output tile area on total image
output_start_x = input_start_x * self.scale
output_end_x = input_end_x * self.scale
output_start_y = input_start_y * self.scale
output_end_y = input_end_y * self.scale
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
# put tile into output image
self.output[:, :, output_start_y:output_end_y,
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile]
def post_process(self):
# remove extra pad
if self.mod_scale is not None:
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
# remove prepad
if self.pre_pad != 0:
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
return self.output
def enhance(self, img, tile=False, alpha_upsampler='realesrgan'):
# img: numpy
img = img.astype(np.float32)
if np.max(img) > 255: # 16-bit image
max_range = 65535
print('\tInput is a 16-bit image')
else:
max_range = 255
img = img / max_range
if len(img.shape) == 2: # gray image
img_mode = 'L'
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
elif img.shape[2] == 4: # RGBA image with alpha channel
img_mode = 'RGBA'
alpha = img[:, :, 3]
img = img[:, :, 0:3]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if alpha_upsampler == 'realesrgan':
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
else:
img_mode = 'RGB'
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------- process image (without the alpha channel) ------------------- #
self.pre_process(img)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_img = self.post_process()
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
if img_mode == 'L':
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
# ------------------- process the alpha channel if necessary ------------------- #
if img_mode == 'RGBA':
if alpha_upsampler == 'realesrgan':
self.pre_process(alpha)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_alpha = self.post_process()
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
else:
h, w = alpha.shape[0:2]
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
# merge the alpha channel
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
output_img[:, :, 3] = output_alpha
# ------------------------------ return ------------------------------ #
if max_range == 65535: # 16-bit image
output = (output_img * 65535.0).round().astype(np.uint16)
else:
output = (output_img * 255.0).round().astype(np.uint8)
return output, img_mode
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
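
To make the index bookkeeping in `tile_process` above concrete, here is a small standalone sketch of the same coordinate arithmetic for a single tile (pure Python; the image and tile sizes are illustrative):

```python
import math

scale, tile_size, tile_pad = 4, 400, 10
height, width = 1080, 1920                # illustrative input size

tiles_x = math.ceil(width / tile_size)    # 5
tiles_y = math.ceil(height / tile_size)   # 3

# take the tile at grid position (x=1, y=1), as in the nested loops above
x, y = 1, 1
input_start_x = x * tile_size                                   # 400
input_end_x = min(input_start_x + tile_size, width)             # 800

# padded region actually fed to the network
input_start_x_pad = max(input_start_x - tile_pad, 0)            # 390
input_end_x_pad = min(input_end_x + tile_pad, width)            # 810

# where the unpadded result lands in the output, and which slice of the
# network output to copy from
output_start_x = input_start_x * scale                          # 1600
output_end_x = input_end_x * scale                              # 3200
output_start_x_tile = (input_start_x - input_start_x_pad) * scale              # 40
output_end_x_tile = output_start_x_tile + (input_end_x - input_start_x) * scale  # 1640
```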

realesrgan/weights/README.md

@@ -0,0 +1,3 @@
# Weights
Put the downloaded weights into this folder.

requirements.txt

@@ -1,4 +1,4 @@
basicsr
cv2
numpy
opencv-python
torch>=1.7

17
scripts/pytorch2onnx.py Normal file

@@ -0,0 +1,17 @@
import torch
import torch.onnx
from basicsr.archs.rrdbnet_arch import RRDBNet
# An instance of your model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32)
model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
# set the train mode to false since we will only run the forward pass.
model.train(False)
model.cpu().eval()
# An example input you would normally provide to your model's forward() method
x = torch.rand(1, 3, 64, 64)
# Export the model
with torch.no_grad():
torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
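
A quick way to sanity-check the exported graph is to run it with ONNX Runtime; this sketch assumes `onnxruntime` is installed (it is not listed in requirements.txt):

```python
import numpy as np
import onnxruntime as ort

# Run the exported graph on a random input and check the x4-upscaled output shape.
sess = ort.InferenceSession('realesrgan-x4.onnx', providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name

x = np.random.rand(1, 3, 64, 64).astype(np.float32)
onnx_out = sess.run(None, {input_name: x})[0]
print(onnx_out.shape)  # expected: (1, 3, 256, 256)
```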

setup.cfg

@@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = basicsr # modify it!
known_first_party = realesrgan
known_third_party = basicsr,cv2,numpy,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

113
setup.py Normal file

@@ -0,0 +1,113 @@
#!/usr/bin/env python
from setuptools import find_packages, setup
import os
import subprocess
import time
version_file = 'realesrgan/version.py'
def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
return content
def get_git_hash():
def _minimal_ext_cmd(cmd):
# construct minimal environment
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
return out
try:
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
sha = out.strip().decode('ascii')
except OSError:
sha = 'unknown'
return sha
def get_hash():
if os.path.exists('.git'):
sha = get_git_hash()[:7]
elif os.path.exists(version_file):
try:
from realesrgan.version import __version__
sha = __version__.split('+')[-1]
except ImportError:
raise ImportError('Unable to get git version')
else:
sha = 'unknown'
return sha
def write_version_py():
content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
sha = get_hash()
with open('VERSION', 'r') as f:
SHORT_VERSION = f.read().strip()
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
with open(version_file, 'w') as f:
f.write(version_file_str)
def get_version():
with open(version_file, 'r') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
def get_requirements(filename='requirements.txt'):
here = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(here, filename), 'r') as f:
requires = [line.replace('\n', '') for line in f.readlines()]
return requires
if __name__ == '__main__':
write_version_py()
setup(
name='realesrgan',
version=get_version(),
description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
long_description=readme(),
long_description_content_type='text/markdown',
author='Xintao Wang',
author_email='xintao.wang@outlook.com',
keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
url='https://github.com/xinntao/Real-ESRGAN',
include_package_data=True,
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
license='BSD-3-Clause License',
setup_requires=['cython', 'numpy'],
install_requires=get_requirements(),
zip_safe=False)
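
For reference, `write_version_py()` renders its template with the contents of `VERSION` and the current git hash, so the generated `realesrgan/version.py` looks roughly like this (timestamp and short hash are illustrative):

```python
# GENERATED VERSION FILE
# TIME: Sun Aug  8 16:12:56 2021
__version__ = '0.2.0'
__gitsha__ = '13186ac'
version_info = (0, 2, 0)
```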

train.py (deleted)

@@ -1,10 +0,0 @@
import os.path as osp
from basicsr.train import train_pipeline
import archs # noqa: F401
import data # noqa: F401
import models # noqa: F401
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir))
train_pipeline(root_path)