Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64ad194dda | ||
|
|
5745599813 | ||
|
|
3ce0c97e89 | ||
|
|
13186ac2c2 | ||
|
|
4356ba0578 | ||
|
|
32a4fa1772 | ||
|
|
18ebf723f2 | ||
|
|
1f83ce5432 | ||
|
|
bef5e3cabd | ||
|
|
064df9956b | ||
|
|
52eab16d11 | ||
|
|
94ae626008 | ||
|
|
9baa0b3d00 | ||
|
|
f932289af1 | ||
|
|
f59a0c66ec | ||
|
|
935993a040 | ||
|
|
74fcfea286 | ||
|
|
da1e1ee805 | ||
|
|
1d8745eb61 | ||
|
|
c94d2de155 | ||
|
|
f4297a70af | ||
|
|
492a829c14 | ||
|
|
0573f32dd0 | ||
|
|
8454fd2c7a | ||
|
|
ad2ff81725 |
30
.github/workflows/publish-pip.yml
vendored
Normal file
30
.github/workflows/publish-pip.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
name: PyPI Publish
|
||||||
|
|
||||||
|
on: push
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-n-publish:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: startsWith(github.event.ref, 'refs/tags')
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python 3.8
|
||||||
|
uses: actions/setup-python@v1
|
||||||
|
with:
|
||||||
|
python-version: 3.8
|
||||||
|
- name: Upgrade pip
|
||||||
|
run: pip install pip --upgrade
|
||||||
|
- name: Install PyTorch (cpu)
|
||||||
|
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -r requirements.txt
|
||||||
|
- name: Build and install
|
||||||
|
run: rm -rf .eggs && pip install -e .
|
||||||
|
- name: Build for distribution
|
||||||
|
# remove bdist_wheel for pip installation with compiling cuda extensions
|
||||||
|
run: python setup.py sdist
|
||||||
|
- name: Publish distribution to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@master
|
||||||
|
with:
|
||||||
|
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
4
.github/workflows/pylint.yml
vendored
4
.github/workflows/pylint.yml
vendored
@@ -26,5 +26,5 @@ jobs:
|
|||||||
- name: Lint
|
- name: Lint
|
||||||
run: |
|
run: |
|
||||||
flake8 .
|
flake8 .
|
||||||
isort --check-only --diff data/ models/ inference_realesrgan.py
|
isort --check-only --diff realesrgan/ scripts/ inference_realesrgan.py setup.py
|
||||||
yapf -r -d data/ models/ inference_realesrgan.py
|
yapf -r -d realesrgan/ scripts/ inference_realesrgan.py setup.py
|
||||||
|
|||||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -1,3 +1,12 @@
|
|||||||
|
# ignored folders
|
||||||
|
datasets/*
|
||||||
|
experiments/*
|
||||||
|
results/*
|
||||||
|
tb_logger/*
|
||||||
|
wandb/*
|
||||||
|
tmp/*
|
||||||
|
|
||||||
|
version.py
|
||||||
.vscode
|
.vscode
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
|
|||||||
29
LICENSE
Normal file
29
LICENSE
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2021, Xintao Wang
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
8
MANIFEST.in
Normal file
8
MANIFEST.in
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
include assets/*
|
||||||
|
include inputs/*
|
||||||
|
include scripts/*.py
|
||||||
|
include inference_realesrgan.py
|
||||||
|
include VERSION
|
||||||
|
include LICENSE
|
||||||
|
include requirements.txt
|
||||||
|
include realesrgan/weights/README.md
|
||||||
32
README.md
32
README.md
@@ -1,17 +1,23 @@
|
|||||||
# Real-ESRGAN
|
# Real-ESRGAN
|
||||||
|
|
||||||
[](https://github.com/xinntao/Real-ESRGAN/releases)
|
[](https://github.com/xinntao/Real-ESRGAN/releases)
|
||||||
|
[](https://pypi.org/project/realesrgan/)
|
||||||
[](https://github.com/xinntao/Real-ESRGAN/issues)
|
[](https://github.com/xinntao/Real-ESRGAN/issues)
|
||||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
|
[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
|
||||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
|
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
|
||||||
|
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/publish-pip.yml)
|
||||||
|
|
||||||
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
|
1. [Colab Demo](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
|
||||||
2. [Portable Windows executable file](https://github.com/xinntao/Real-ESRGAN/releases). You can find more information [here](#Portable-executable-files).
|
2. [Portable Windows/Linux/MacOS executable files for Intel/AMD/Nvidia GPU](https://github.com/xinntao/Real-ESRGAN/releases). You can find more information [here](#Portable-executable-files).
|
||||||
|
|
||||||
Real-ESRGAN aims at developing **Practical Algorithms for General Image Restoration**.<br>
|
Real-ESRGAN aims at developing **Practical Algorithms for General Image Restoration**.<br>
|
||||||
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.
|
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.
|
||||||
|
|
||||||
:triangular_flag_on_post: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
|
:triangular_flag_on_post: **Updates**
|
||||||
|
|
||||||
|
- :white_check_mark: Support arbitrary scale with `--outscale` (It actually further resizes outputs with `LANCZOS4`). Add *RealESRGAN_x2plus.pth* model.
|
||||||
|
- :white_check_mark: [The inference code](inference_realesrgan.py) supports: 1) **tile** options; 2) images with **alpha channel**; 3) **gray** images; 4) **16-bit** images.
|
||||||
|
- :white_check_mark: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
|
||||||
|
|
||||||
### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
|
### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
|
||||||
|
|
||||||
@@ -49,16 +55,24 @@ If you have some images that Real-ESRGAN could not well restored, please also op
|
|||||||
|
|
||||||
### Portable executable files
|
### Portable executable files
|
||||||
|
|
||||||
You can download **Windows executable files** from https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN-ncnn-vulkan.zip
|
You can download **Windows/Linux/MacOS executable files for Intel/AMD/Nvidia GPU** from https://github.com/xinntao/Real-ESRGAN/releases
|
||||||
|
|
||||||
This executable file is **portable** and includes all the binaries and models required. No CUDA or PyTorch environment is needed.<br>
|
This executable file is **portable** and includes all the binaries and models required. No CUDA or PyTorch environment is needed.<br>
|
||||||
|
|
||||||
You can simply run the following command:
|
You can simply run the following command (the Windows example, more information is in the README.md of each executable files):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
|
./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
|
We have provided three models:
|
||||||
|
|
||||||
|
1. realesrgan-x4plus (default)
|
||||||
|
2. realesrnet-x4plus
|
||||||
|
3. esrgan-x4
|
||||||
|
|
||||||
|
You can use the `-n` argument for other models, for example, `./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png -n realesrnet-x4plus`
|
||||||
|
|
||||||
Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
|
Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
|
||||||
|
|
||||||
This executable file is based on the wonderful [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui).
|
This executable file is based on the wonderful [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui).
|
||||||
@@ -86,6 +100,7 @@ This executable file is based on the wonderful [Tencent/ncnn](https://github.com
|
|||||||
# We use BasicSR for both training and inference
|
# We use BasicSR for both training and inference
|
||||||
pip install basicsr
|
pip install basicsr
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
python setup.py develop
|
||||||
```
|
```
|
||||||
|
|
||||||
## :zap: Quick Inference
|
## :zap: Quick Inference
|
||||||
@@ -106,6 +121,13 @@ python inference_realesrgan.py --model_path experiments/pretrained_models/RealES
|
|||||||
|
|
||||||
Results are in the `results` folder
|
Results are in the `results` folder
|
||||||
|
|
||||||
|
## :european_castle: Model Zoo
|
||||||
|
|
||||||
|
- [RealESRGAN-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
|
||||||
|
- [RealESRNet-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth)
|
||||||
|
- [RealESRGAN-x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.0/RealESRGAN_x2plus.pth)
|
||||||
|
- [official ESRGAN-x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth)
|
||||||
|
|
||||||
## :computer: Training
|
## :computer: Training
|
||||||
|
|
||||||
A detailed guide can be found in [Training.md](Training.md).
|
A detailed guide can be found in [Training.md](Training.md).
|
||||||
|
|||||||
15
Training.md
15
Training.md
@@ -34,14 +34,17 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
|
|
||||||
## Train Real-ESRNet
|
## Train Real-ESRNet
|
||||||
|
|
||||||
1. Download pre-trained model [ESRGAN](https://drive.google.com/file/d/1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT/view?usp=sharing) into `experiments/pretrained_models`.
|
1. Download pre-trained model [ESRGAN](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) into `experiments/pretrained_models`.
|
||||||
|
```bash
|
||||||
|
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models
|
||||||
|
```
|
||||||
1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
|
1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
|
||||||
```yml
|
```yml
|
||||||
train:
|
train:
|
||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K # modify to the root path of your folder
|
dataroot_gt: datasets/DF2K # modify to the root path of your folder
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
|
meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
```
|
```
|
||||||
@@ -73,12 +76,12 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
|
||||||
```
|
```
|
||||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
|
||||||
```
|
```
|
||||||
|
|
||||||
## Train Real-ESRGAN
|
## Train Real-ESRGAN
|
||||||
@@ -88,10 +91,10 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
|
||||||
```
|
```
|
||||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,67 +1,76 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import cv2
|
import cv2
|
||||||
import glob
|
import glob
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from realesrgan import RealESRGANer
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth')
|
parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
|
||||||
parser.add_argument('--scale', type=int, default=4)
|
parser.add_argument(
|
||||||
parser.add_argument('--input', type=str, default='inputs', help='input image or folder')
|
'--model_path',
|
||||||
|
type=str,
|
||||||
|
default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
|
||||||
|
help='Path to the pre-trained model')
|
||||||
|
parser.add_argument('--output', type=str, default='results', help='Output folder')
|
||||||
|
parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
|
||||||
|
parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
|
||||||
|
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
|
||||||
|
parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
|
||||||
|
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
|
||||||
|
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
|
||||||
|
parser.add_argument('--half', action='store_true', help='Use half precision during inference')
|
||||||
|
parser.add_argument(
|
||||||
|
'--alpha_upsampler',
|
||||||
|
type=str,
|
||||||
|
default='realesrgan',
|
||||||
|
help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
|
||||||
|
parser.add_argument(
|
||||||
|
'--ext',
|
||||||
|
type=str,
|
||||||
|
default='auto',
|
||||||
|
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
upsampler = RealESRGANer(
|
||||||
# set up model
|
scale=args.netscale,
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale)
|
model_path=args.model_path,
|
||||||
loadnet = torch.load(args.model_path)
|
tile=args.tile,
|
||||||
model.load_state_dict(loadnet['params_ema'], strict=True)
|
tile_pad=args.tile_pad,
|
||||||
model.eval()
|
pre_pad=args.pre_pad,
|
||||||
model = model.to(device)
|
half=args.half)
|
||||||
|
os.makedirs(args.output, exist_ok=True)
|
||||||
|
if os.path.isfile(args.input):
|
||||||
|
paths = [args.input]
|
||||||
|
else:
|
||||||
|
paths = sorted(glob.glob(os.path.join(args.input, '*')))
|
||||||
|
|
||||||
os.makedirs('results/', exist_ok=True)
|
for idx, path in enumerate(paths):
|
||||||
for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
|
imgname, extension = os.path.splitext(os.path.basename(path))
|
||||||
imgname = os.path.splitext(os.path.basename(path))[0]
|
|
||||||
print('Testing', idx, imgname)
|
print('Testing', idx, imgname)
|
||||||
# read image
|
|
||||||
img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
|
|
||||||
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
|
||||||
img = img.unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
if args.scale == 2:
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
mod_scale = 2
|
h, w = img.shape[0:2]
|
||||||
elif args.scale == 1:
|
if max(h, w) > 1000 and args.netscale == 4:
|
||||||
mod_scale = 4
|
print('WARNING: The input image is large, try X2 model for better performace.')
|
||||||
else:
|
if max(h, w) < 500 and args.netscale == 2:
|
||||||
mod_scale = None
|
print('WARNING: The input image is small, try X4 model for better performace.')
|
||||||
if mod_scale is not None:
|
|
||||||
h_pad, w_pad = 0, 0
|
|
||||||
_, _, h, w = img.size()
|
|
||||||
if (h % mod_scale != 0):
|
|
||||||
h_pad = (mod_scale - h % mod_scale)
|
|
||||||
if (w % mod_scale != 0):
|
|
||||||
w_pad = (mod_scale - w % mod_scale)
|
|
||||||
img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# inference
|
output, img_mode = upsampler.enhance(img, outscale=args.outscale)
|
||||||
with torch.no_grad():
|
|
||||||
output = model(img)
|
|
||||||
# remove extra pad
|
|
||||||
if mod_scale is not None:
|
|
||||||
_, _, h, w = output.size()
|
|
||||||
output = output[:, :, 0:h - h_pad, 0:w - w_pad]
|
|
||||||
# save image
|
|
||||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
|
||||||
output = (output * 255.0).round().astype(np.uint8)
|
|
||||||
cv2.imwrite(f'results/{imgname}_RealESRGAN.png', output)
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print('Error', error)
|
print('Error', error)
|
||||||
|
else:
|
||||||
|
if args.ext == 'auto':
|
||||||
|
extension = extension[1:]
|
||||||
|
else:
|
||||||
|
extension = args.ext
|
||||||
|
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||||
|
extension = 'png'
|
||||||
|
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
|
||||||
|
cv2.imwrite(save_path, output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
BIN
inputs/tree_alpha_16bit.png
Normal file
BIN
inputs/tree_alpha_16bit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 373 KiB |
BIN
inputs/wolf_gray.jpg
Normal file
BIN
inputs/wolf_gray.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 52 KiB |
@@ -39,7 +39,7 @@ datasets:
|
|||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K
|
dataroot_gt: datasets/DF2K
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ network_d:
|
|||||||
# path
|
# path
|
||||||
path:
|
path:
|
||||||
# use the pre-trained Real-ESRNet model
|
# use the pre-trained Real-ESRNet model
|
||||||
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth
|
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/models/net_g_1000000.pth
|
||||||
param_key_g: params_ema
|
param_key_g: params_ema
|
||||||
strict_load_g: true
|
strict_load_g: true
|
||||||
resume_state: ~
|
resume_state: ~
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ datasets:
|
|||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K
|
dataroot_gt: datasets/DF2K
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
|
|
||||||
|
|||||||
6
realesrgan/__init__.py
Normal file
6
realesrgan/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
from .archs import *
|
||||||
|
from .data import *
|
||||||
|
from .models import *
|
||||||
|
from .utils import *
|
||||||
|
from .version import __gitsha__, __version__
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
arch_folder = osp.dirname(osp.abspath(__file__))
|
||||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
||||||
# import all the arch modules
|
# import all the arch modules
|
||||||
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
|
_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
data_folder = osp.dirname(osp.abspath(__file__))
|
data_folder = osp.dirname(osp.abspath(__file__))
|
||||||
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
||||||
# import all the dataset modules
|
# import all the dataset modules
|
||||||
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
|
_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
model_folder = osp.dirname(osp.abspath(__file__))
|
model_folder = osp.dirname(osp.abspath(__file__))
|
||||||
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
||||||
# import all the model modules
|
# import all the model modules
|
||||||
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
|
_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
|
||||||
@@ -18,7 +18,7 @@ class RealESRGANModel(SRGANModel):
|
|||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(RealESRGANModel, self).__init__(opt)
|
super(RealESRGANModel, self).__init__(opt)
|
||||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||||
self.usm_shaper = USMSharp().cuda()
|
self.usm_sharpener = USMSharp().cuda()
|
||||||
self.queue_size = opt['queue_size']
|
self.queue_size = opt['queue_size']
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -58,7 +58,7 @@ class RealESRGANModel(SRGANModel):
|
|||||||
if self.is_train:
|
if self.is_train:
|
||||||
# training data synthesis
|
# training data synthesis
|
||||||
self.gt = data['gt'].to(self.device)
|
self.gt = data['gt'].to(self.device)
|
||||||
self.gt_usm = self.usm_shaper(self.gt)
|
self.gt_usm = self.usm_sharpener(self.gt)
|
||||||
|
|
||||||
self.kernel1 = data['kernel1'].to(self.device)
|
self.kernel1 = data['kernel1'].to(self.device)
|
||||||
self.kernel2 = data['kernel2'].to(self.device)
|
self.kernel2 = data['kernel2'].to(self.device)
|
||||||
@@ -160,6 +160,8 @@ class RealESRGANModel(SRGANModel):
|
|||||||
|
|
||||||
# training pair pool
|
# training pair pool
|
||||||
self._dequeue_and_enqueue()
|
self._dequeue_and_enqueue()
|
||||||
|
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
||||||
|
self.gt_usm = self.usm_sharpener(self.gt)
|
||||||
else:
|
else:
|
||||||
self.lq = data['lq'].to(self.device)
|
self.lq = data['lq'].to(self.device)
|
||||||
if 'gt' in data:
|
if 'gt' in data:
|
||||||
@@ -17,7 +17,7 @@ class RealESRNetModel(SRModel):
|
|||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(RealESRNetModel, self).__init__(opt)
|
super(RealESRNetModel, self).__init__(opt)
|
||||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||||
self.usm_shaper = USMSharp().cuda()
|
self.usm_sharpener = USMSharp().cuda()
|
||||||
self.queue_size = opt['queue_size']
|
self.queue_size = opt['queue_size']
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -59,7 +59,7 @@ class RealESRNetModel(SRModel):
|
|||||||
self.gt = data['gt'].to(self.device)
|
self.gt = data['gt'].to(self.device)
|
||||||
# USM the GT images
|
# USM the GT images
|
||||||
if self.opt['gt_usm'] is True:
|
if self.opt['gt_usm'] is True:
|
||||||
self.gt = self.usm_shaper(self.gt)
|
self.gt = self.usm_sharpener(self.gt)
|
||||||
|
|
||||||
self.kernel1 = data['kernel1'].to(self.device)
|
self.kernel1 = data['kernel1'].to(self.device)
|
||||||
self.kernel2 = data['kernel2'].to(self.device)
|
self.kernel2 = data['kernel2'].to(self.device)
|
||||||
11
realesrgan/train.py
Normal file
11
realesrgan/train.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
import os.path as osp
|
||||||
|
from basicsr.train import train_pipeline
|
||||||
|
|
||||||
|
import realesrgan.archs
|
||||||
|
import realesrgan.data
|
||||||
|
import realesrgan.models
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
||||||
|
train_pipeline(root_path)
|
||||||
231
realesrgan/utils.py
Normal file
231
realesrgan/utils.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from torch.hub import download_url_to_file, get_dir
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
class RealESRGANer():
|
||||||
|
|
||||||
|
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
|
||||||
|
self.scale = scale
|
||||||
|
self.tile_size = tile
|
||||||
|
self.tile_pad = tile_pad
|
||||||
|
self.pre_pad = pre_pad
|
||||||
|
self.mod_scale = None
|
||||||
|
self.half = half
|
||||||
|
|
||||||
|
# initialize model
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
||||||
|
|
||||||
|
if model_path.startswith('https://'):
|
||||||
|
model_path = load_file_from_url(
|
||||||
|
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
|
||||||
|
loadnet = torch.load(model_path)
|
||||||
|
if 'params_ema' in loadnet:
|
||||||
|
keyname = 'params_ema'
|
||||||
|
else:
|
||||||
|
keyname = 'params'
|
||||||
|
model.load_state_dict(loadnet[keyname], strict=True)
|
||||||
|
model.eval()
|
||||||
|
self.model = model.to(self.device)
|
||||||
|
if self.half:
|
||||||
|
self.model = self.model.half()
|
||||||
|
|
||||||
|
def pre_process(self, img):
|
||||||
|
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||||
|
self.img = img.unsqueeze(0).to(self.device)
|
||||||
|
if self.half:
|
||||||
|
self.img = self.img.half()
|
||||||
|
|
||||||
|
# pre_pad
|
||||||
|
if self.pre_pad != 0:
|
||||||
|
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
||||||
|
# mod pad
|
||||||
|
if self.scale == 2:
|
||||||
|
self.mod_scale = 2
|
||||||
|
elif self.scale == 1:
|
||||||
|
self.mod_scale = 4
|
||||||
|
if self.mod_scale is not None:
|
||||||
|
self.mod_pad_h, self.mod_pad_w = 0, 0
|
||||||
|
_, _, h, w = self.img.size()
|
||||||
|
if (h % self.mod_scale != 0):
|
||||||
|
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
|
||||||
|
if (w % self.mod_scale != 0):
|
||||||
|
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
|
||||||
|
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
self.output = self.model(self.img)
|
||||||
|
|
||||||
|
def tile_process(self):
|
||||||
|
"""Modified from: https://github.com/ata4/esrgan-launcher
|
||||||
|
"""
|
||||||
|
batch, channel, height, width = self.img.shape
|
||||||
|
output_height = height * self.scale
|
||||||
|
output_width = width * self.scale
|
||||||
|
output_shape = (batch, channel, output_height, output_width)
|
||||||
|
|
||||||
|
# start with black image
|
||||||
|
self.output = self.img.new_zeros(output_shape)
|
||||||
|
tiles_x = math.ceil(width / self.tile_size)
|
||||||
|
tiles_y = math.ceil(height / self.tile_size)
|
||||||
|
|
||||||
|
# loop over all tiles
|
||||||
|
for y in range(tiles_y):
|
||||||
|
for x in range(tiles_x):
|
||||||
|
# extract tile from input image
|
||||||
|
ofs_x = x * self.tile_size
|
||||||
|
ofs_y = y * self.tile_size
|
||||||
|
# input tile area on total image
|
||||||
|
input_start_x = ofs_x
|
||||||
|
input_end_x = min(ofs_x + self.tile_size, width)
|
||||||
|
input_start_y = ofs_y
|
||||||
|
input_end_y = min(ofs_y + self.tile_size, height)
|
||||||
|
|
||||||
|
# input tile area on total image with padding
|
||||||
|
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
||||||
|
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
||||||
|
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
||||||
|
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
||||||
|
|
||||||
|
# input tile dimensions
|
||||||
|
input_tile_width = input_end_x - input_start_x
|
||||||
|
input_tile_height = input_end_y - input_start_y
|
||||||
|
tile_idx = y * tiles_x + x + 1
|
||||||
|
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
|
||||||
|
|
||||||
|
# upscale tile
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
output_tile = self.model(input_tile)
|
||||||
|
except Exception as error:
|
||||||
|
print('Error', error)
|
||||||
|
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
|
||||||
|
|
||||||
|
# output tile area on total image
|
||||||
|
output_start_x = input_start_x * self.scale
|
||||||
|
output_end_x = input_end_x * self.scale
|
||||||
|
output_start_y = input_start_y * self.scale
|
||||||
|
output_end_y = input_end_y * self.scale
|
||||||
|
|
||||||
|
# output tile area without padding
|
||||||
|
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
||||||
|
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
||||||
|
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
||||||
|
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
||||||
|
|
||||||
|
# put tile into output image
|
||||||
|
self.output[:, :, output_start_y:output_end_y,
|
||||||
|
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
|
||||||
|
output_start_x_tile:output_end_x_tile]
|
||||||
|
|
||||||
|
def post_process(self):
|
||||||
|
# remove extra pad
|
||||||
|
if self.mod_scale is not None:
|
||||||
|
_, _, h, w = self.output.size()
|
||||||
|
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
|
||||||
|
# remove prepad
|
||||||
|
if self.pre_pad != 0:
|
||||||
|
_, _, h, w = self.output.size()
|
||||||
|
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
|
||||||
|
h_input, w_input = img.shape[0:2]
|
||||||
|
# img: numpy
|
||||||
|
img = img.astype(np.float32)
|
||||||
|
if np.max(img) > 255: # 16-bit image
|
||||||
|
max_range = 65535
|
||||||
|
print('\tInput is a 16-bit image')
|
||||||
|
else:
|
||||||
|
max_range = 255
|
||||||
|
img = img / max_range
|
||||||
|
if len(img.shape) == 2: # gray image
|
||||||
|
img_mode = 'L'
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
||||||
|
elif img.shape[2] == 4: # RGBA image with alpha channel
|
||||||
|
img_mode = 'RGBA'
|
||||||
|
alpha = img[:, :, 3]
|
||||||
|
img = img[:, :, 0:3]
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
if alpha_upsampler == 'realesrgan':
|
||||||
|
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
||||||
|
else:
|
||||||
|
img_mode = 'RGB'
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# ------------------- process image (without the alpha channel) ------------------- #
|
||||||
|
self.pre_process(img)
|
||||||
|
if self.tile_size > 0:
|
||||||
|
self.tile_process()
|
||||||
|
else:
|
||||||
|
self.process()
|
||||||
|
output_img = self.post_process()
|
||||||
|
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
||||||
|
if img_mode == 'L':
|
||||||
|
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
# ------------------- process the alpha channel if necessary ------------------- #
|
||||||
|
if img_mode == 'RGBA':
|
||||||
|
if alpha_upsampler == 'realesrgan':
|
||||||
|
self.pre_process(alpha)
|
||||||
|
if self.tile_size > 0:
|
||||||
|
self.tile_process()
|
||||||
|
else:
|
||||||
|
self.process()
|
||||||
|
output_alpha = self.post_process()
|
||||||
|
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
||||||
|
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
||||||
|
else:
|
||||||
|
h, w = alpha.shape[0:2]
|
||||||
|
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
# merge the alpha channel
|
||||||
|
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
||||||
|
output_img[:, :, 3] = output_alpha
|
||||||
|
|
||||||
|
# ------------------------------ return ------------------------------ #
|
||||||
|
if max_range == 65535: # 16-bit image
|
||||||
|
output = (output_img * 65535.0).round().astype(np.uint16)
|
||||||
|
else:
|
||||||
|
output = (output_img * 255.0).round().astype(np.uint8)
|
||||||
|
|
||||||
|
if outscale is not None and outscale != float(self.scale):
|
||||||
|
output = cv2.resize(
|
||||||
|
output, (
|
||||||
|
int(w_input * outscale),
|
||||||
|
int(h_input * outscale),
|
||||||
|
), interpolation=cv2.INTER_LANCZOS4)
|
||||||
|
|
||||||
|
return output, img_mode
|
||||||
|
|
||||||
|
|
||||||
|
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
||||||
|
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||||
|
"""
|
||||||
|
if model_dir is None:
|
||||||
|
hub_dir = get_dir()
|
||||||
|
model_dir = os.path.join(hub_dir, 'checkpoints')
|
||||||
|
|
||||||
|
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
|
||||||
|
|
||||||
|
parts = urlparse(url)
|
||||||
|
filename = os.path.basename(parts.path)
|
||||||
|
if file_name is not None:
|
||||||
|
filename = file_name
|
||||||
|
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
|
||||||
|
if not os.path.exists(cached_file):
|
||||||
|
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||||
|
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
||||||
|
return cached_file
|
||||||
3
realesrgan/weights/README.md
Normal file
3
realesrgan/weights/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Weights
|
||||||
|
|
||||||
|
Put the downloaded weights to this folder.
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
basicsr
|
basicsr
|
||||||
cv2
|
|
||||||
numpy
|
numpy
|
||||||
|
opencv-python
|
||||||
torch>=1.7
|
torch>=1.7
|
||||||
|
|||||||
17
scripts/pytorch2onnx.py
Normal file
17
scripts/pytorch2onnx.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import torch
|
||||||
|
import torch.onnx
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
|
||||||
|
# An instance of your model
|
||||||
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32)
|
||||||
|
model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
|
||||||
|
# set the train mode to false since we will only run the forward pass.
|
||||||
|
model.train(False)
|
||||||
|
model.cpu().eval()
|
||||||
|
|
||||||
|
# An example input you would normally provide to your model's forward() method
|
||||||
|
x = torch.rand(1, 3, 64, 64)
|
||||||
|
|
||||||
|
# Export the model
|
||||||
|
with torch.no_grad():
|
||||||
|
torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
|
||||||
@@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
|
|||||||
line_length = 120
|
line_length = 120
|
||||||
multi_line_output = 0
|
multi_line_output = 0
|
||||||
known_standard_library = pkg_resources,setuptools
|
known_standard_library = pkg_resources,setuptools
|
||||||
known_first_party = basicsr # modify it!
|
known_first_party = realesrgan
|
||||||
known_third_party = basicsr,cv2,numpy,torch
|
known_third_party = basicsr,cv2,numpy,torch
|
||||||
no_lines_before = STDLIB,LOCALFOLDER
|
no_lines_before = STDLIB,LOCALFOLDER
|
||||||
default_section = THIRDPARTY
|
default_section = THIRDPARTY
|
||||||
|
|||||||
113
setup.py
Normal file
113
setup.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
version_file = 'realesrgan/version.py'
|
||||||
|
|
||||||
|
|
||||||
|
def readme():
|
||||||
|
with open('README.md', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_hash():
|
||||||
|
|
||||||
|
def _minimal_ext_cmd(cmd):
|
||||||
|
# construct minimal environment
|
||||||
|
env = {}
|
||||||
|
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
|
||||||
|
v = os.environ.get(k)
|
||||||
|
if v is not None:
|
||||||
|
env[k] = v
|
||||||
|
# LANGUAGE is used on win32
|
||||||
|
env['LANGUAGE'] = 'C'
|
||||||
|
env['LANG'] = 'C'
|
||||||
|
env['LC_ALL'] = 'C'
|
||||||
|
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
|
||||||
|
return out
|
||||||
|
|
||||||
|
try:
|
||||||
|
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
|
||||||
|
sha = out.strip().decode('ascii')
|
||||||
|
except OSError:
|
||||||
|
sha = 'unknown'
|
||||||
|
|
||||||
|
return sha
|
||||||
|
|
||||||
|
|
||||||
|
def get_hash():
|
||||||
|
if os.path.exists('.git'):
|
||||||
|
sha = get_git_hash()[:7]
|
||||||
|
elif os.path.exists(version_file):
|
||||||
|
try:
|
||||||
|
from facexlib.version import __version__
|
||||||
|
sha = __version__.split('+')[-1]
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('Unable to get git version')
|
||||||
|
else:
|
||||||
|
sha = 'unknown'
|
||||||
|
|
||||||
|
return sha
|
||||||
|
|
||||||
|
|
||||||
|
def write_version_py():
|
||||||
|
content = """# GENERATED VERSION FILE
|
||||||
|
# TIME: {}
|
||||||
|
__version__ = '{}'
|
||||||
|
__gitsha__ = '{}'
|
||||||
|
version_info = ({})
|
||||||
|
"""
|
||||||
|
sha = get_hash()
|
||||||
|
with open('VERSION', 'r') as f:
|
||||||
|
SHORT_VERSION = f.read().strip()
|
||||||
|
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
|
||||||
|
|
||||||
|
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
|
||||||
|
with open(version_file, 'w') as f:
|
||||||
|
f.write(version_file_str)
|
||||||
|
|
||||||
|
|
||||||
|
def get_version():
|
||||||
|
with open(version_file, 'r') as f:
|
||||||
|
exec(compile(f.read(), version_file, 'exec'))
|
||||||
|
return locals()['__version__']
|
||||||
|
|
||||||
|
|
||||||
|
def get_requirements(filename='requirements.txt'):
|
||||||
|
here = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
with open(os.path.join(here, filename), 'r') as f:
|
||||||
|
requires = [line.replace('\n', '') for line in f.readlines()]
|
||||||
|
return requires
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
write_version_py()
|
||||||
|
setup(
|
||||||
|
name='realesrgan',
|
||||||
|
version=get_version(),
|
||||||
|
description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
|
||||||
|
long_description=readme(),
|
||||||
|
long_description_content_type='text/markdown',
|
||||||
|
author='Xintao Wang',
|
||||||
|
author_email='xintao.wang@outlook.com',
|
||||||
|
keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
|
||||||
|
url='https://github.com/xinntao/Real-ESRGAN',
|
||||||
|
include_package_data=True,
|
||||||
|
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
|
||||||
|
classifiers=[
|
||||||
|
'Development Status :: 4 - Beta',
|
||||||
|
'License :: OSI Approved :: Apache Software License',
|
||||||
|
'Operating System :: OS Independent',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Programming Language :: Python :: 3.7',
|
||||||
|
'Programming Language :: Python :: 3.8',
|
||||||
|
],
|
||||||
|
license='BSD-3-Clause License',
|
||||||
|
setup_requires=['cython', 'numpy'],
|
||||||
|
install_requires=get_requirements(),
|
||||||
|
zip_safe=False)
|
||||||
Reference in New Issue
Block a user