Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9976a34454 | ||
|
|
424a09457b | ||
|
|
52f77e74a8 | ||
|
|
bfa4678bef | ||
|
|
68f9f2445e | ||
|
|
7840a3d16a | ||
|
|
b28958cdf2 | ||
|
|
667e34e7ca | ||
|
|
978def19a6 | ||
|
|
a7153c7fce | ||
|
|
00116244cb | ||
|
|
571b89257a | ||
|
|
bed7df7d99 | ||
|
|
fb3ff055e4 | ||
|
|
9ef97853f9 | ||
|
|
58fea8db69 | ||
|
|
3c6cf5290e | ||
|
|
64ad194dda | ||
|
|
5745599813 | ||
|
|
3ce0c97e89 | ||
|
|
13186ac2c2 | ||
|
|
4356ba0578 | ||
|
|
32a4fa1772 | ||
|
|
18ebf723f2 | ||
|
|
1f83ce5432 | ||
|
|
bef5e3cabd | ||
|
|
064df9956b | ||
|
|
52eab16d11 | ||
|
|
94ae626008 | ||
|
|
9baa0b3d00 | ||
|
|
f932289af1 | ||
|
|
f59a0c66ec | ||
|
|
935993a040 | ||
|
|
74fcfea286 | ||
|
|
da1e1ee805 | ||
|
|
1d8745eb61 | ||
|
|
c94d2de155 | ||
|
|
f4297a70af | ||
|
|
492a829c14 | ||
|
|
0573f32dd0 | ||
|
|
8454fd2c7a | ||
|
|
ad2ff81725 |
34
.github/workflows/no-response.yml
vendored
Normal file
34
.github/workflows/no-response.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
name: No Response
|
||||||
|
|
||||||
|
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
|
||||||
|
|
||||||
|
# **What it does**: Closes issues that don't have enough information to be
|
||||||
|
# actionable.
|
||||||
|
# **Why we have it**: To remove the need for maintainers to remember to check
|
||||||
|
# back on issues periodically to see if contributors have
|
||||||
|
# responded.
|
||||||
|
# **Who does it impact**: Everyone that works on docs or docs-internal.
|
||||||
|
|
||||||
|
on:
|
||||||
|
issue_comment:
|
||||||
|
types: [created]
|
||||||
|
|
||||||
|
schedule:
|
||||||
|
# Schedule for five minutes after the hour every hour
|
||||||
|
- cron: '5 * * * *'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
noResponse:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: lee-dohm/no-response@v0.5.0
|
||||||
|
with:
|
||||||
|
token: ${{ github.token }}
|
||||||
|
closeComment: >
|
||||||
|
This issue has been automatically closed because there has been no response
|
||||||
|
to our request for more information from the original author. With only the
|
||||||
|
information that is currently in the issue, we don't have enough information
|
||||||
|
to take action. Please reach out if you have or find the answers we need so
|
||||||
|
that we can investigate further.
|
||||||
|
If you still have questions, please improve your description and re-open it.
|
||||||
|
Thanks :-)
|
||||||
33
.github/workflows/publish-pip.yml
vendored
Normal file
33
.github/workflows/publish-pip.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
name: PyPI Publish
|
||||||
|
|
||||||
|
on: push
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-n-publish:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: startsWith(github.event.ref, 'refs/tags')
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python 3.8
|
||||||
|
uses: actions/setup-python@v1
|
||||||
|
with:
|
||||||
|
python-version: 3.8
|
||||||
|
- name: Upgrade pip
|
||||||
|
run: pip install pip --upgrade
|
||||||
|
- name: Install PyTorch (cpu)
|
||||||
|
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install basicsr
|
||||||
|
pip install facexlib
|
||||||
|
pip install gfpgan
|
||||||
|
pip install -r requirements.txt
|
||||||
|
- name: Build and install
|
||||||
|
run: rm -rf .eggs && pip install -e .
|
||||||
|
- name: Build for distribution
|
||||||
|
run: python setup.py sdist bdist_wheel
|
||||||
|
- name: Publish distribution to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@master
|
||||||
|
with:
|
||||||
|
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
4
.github/workflows/pylint.yml
vendored
4
.github/workflows/pylint.yml
vendored
@@ -26,5 +26,5 @@ jobs:
|
|||||||
- name: Lint
|
- name: Lint
|
||||||
run: |
|
run: |
|
||||||
flake8 .
|
flake8 .
|
||||||
isort --check-only --diff data/ models/ inference_realesrgan.py
|
isort --check-only --diff realesrgan/ scripts/ inference_realesrgan.py setup.py
|
||||||
yapf -r -d data/ models/ inference_realesrgan.py
|
yapf -r -d realesrgan/ scripts/ inference_realesrgan.py setup.py
|
||||||
|
|||||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -1,4 +1,13 @@
|
|||||||
.vscode
|
# ignored folders
|
||||||
|
datasets/*
|
||||||
|
experiments/*
|
||||||
|
results/*
|
||||||
|
tb_logger/*
|
||||||
|
wandb/*
|
||||||
|
tmp/*
|
||||||
|
realesrgan/weights/*
|
||||||
|
|
||||||
|
version.py
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|||||||
19
.vscode/settings.json
vendored
Normal file
19
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
"files.trimTrailingWhitespace": true,
|
||||||
|
"editor.wordWrap": "on",
|
||||||
|
"editor.rulers": [
|
||||||
|
80,
|
||||||
|
120
|
||||||
|
],
|
||||||
|
"editor.renderWhitespace": "all",
|
||||||
|
"editor.renderControlCharacters": true,
|
||||||
|
"python.formatting.provider": "yapf",
|
||||||
|
"python.formatting.yapfArgs": [
|
||||||
|
"--style",
|
||||||
|
"{BASED_ON_STYLE = pep8, BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true, SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true, COLUMN_LIMIT = 120}"
|
||||||
|
],
|
||||||
|
"python.linting.flake8Enabled": true,
|
||||||
|
"python.linting.flake8Args": [
|
||||||
|
"max-line-length=120"
|
||||||
|
],
|
||||||
|
}
|
||||||
9
FAQ.md
Normal file
9
FAQ.md
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# FAQ
|
||||||
|
|
||||||
|
1. **What is the difference between `--netscale` and `--outscale`?**
|
||||||
|
|
||||||
|
A: TODO.
|
||||||
|
|
||||||
|
1. **How to select models?**
|
||||||
|
|
||||||
|
A: TODO.
|
||||||
29
LICENSE
Normal file
29
LICENSE
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2021, Xintao Wang
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
8
MANIFEST.in
Normal file
8
MANIFEST.in
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
include assets/*
|
||||||
|
include inputs/*
|
||||||
|
include scripts/*.py
|
||||||
|
include inference_realesrgan.py
|
||||||
|
include VERSION
|
||||||
|
include LICENSE
|
||||||
|
include requirements.txt
|
||||||
|
include realesrgan/weights/README.md
|
||||||
47
README.md
47
README.md
@@ -1,17 +1,31 @@
|
|||||||
# Real-ESRGAN
|
# Real-ESRGAN
|
||||||
|
|
||||||
[](https://github.com/xinntao/Real-ESRGAN/releases)
|
[](https://github.com/xinntao/Real-ESRGAN/releases)
|
||||||
|
[](https://pypi.org/project/realesrgan/)
|
||||||
[](https://github.com/xinntao/Real-ESRGAN/issues)
|
[](https://github.com/xinntao/Real-ESRGAN/issues)
|
||||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
|
[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
|
||||||
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
|
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
|
||||||
|
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/publish-pip.yml)
|
||||||
|
|
||||||
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
|
1. [Colab Demo](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
|
||||||
2. [Portable Windows executable file](https://github.com/xinntao/Real-ESRGAN/releases). You can find more information [here](#Portable-executable-files).
|
2. Portable [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.2/realesrgan-ncnn-vulkan-20210801-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.2/realesrgan-ncnn-vulkan-20210801-ubuntu.zip) / [MacOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.2/realesrgan-ncnn-vulkan-20210801-macos.zip) **executable files for Intel/AMD/Nvidia GPU**. You can find more information [here](#Portable-executable-files).
|
||||||
|
|
||||||
Real-ESRGAN aims at developing **Practical Algorithms for General Image Restoration**.<br>
|
Real-ESRGAN aims at developing **Practical Algorithms for General Image Restoration**.<br>
|
||||||
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.
|
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.
|
||||||
|
|
||||||
:triangular_flag_on_post: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
|
:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md).
|
||||||
|
|
||||||
|
:triangular_flag_on_post: **Updates**
|
||||||
|
- :white_check_mark: Integrate [GFPGAN](https://github.com/TencentARC/GFPGAN) to support **face enhancement**.
|
||||||
|
- :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN).
|
||||||
|
- :white_check_mark: Support arbitrary scale with `--outscale` (It actually further resizes outputs with `LANCZOS4`). Add *RealESRGAN_x2plus.pth* model.
|
||||||
|
- :white_check_mark: [The inference code](inference_realesrgan.py) supports: 1) **tile** options; 2) images with **alpha channel**; 3) **gray** images; 4) **16-bit** images.
|
||||||
|
- :white_check_mark: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
|
||||||
|
|
||||||
|
If Real-ESRGAN is helpful for your photos/projects, please help to :star: this repo. Thanks :blush: <br>
|
||||||
|
Other recommended projects:   :arrow_forward: [GFPGAN](https://github.com/TencentARC/GFPGAN)   :arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR)   :arrow_forward: [facexlib](https://github.com/xinntao/facexlib)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
|
### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
|
||||||
|
|
||||||
@@ -38,7 +52,7 @@ Here is a TODO list in the near future:
|
|||||||
|
|
||||||
- [ ] optimize for human faces
|
- [ ] optimize for human faces
|
||||||
- [ ] optimize for texts
|
- [ ] optimize for texts
|
||||||
- [ ] optimize for animation images
|
- [ ] optimize for anime images [in progress]
|
||||||
- [ ] support more scales
|
- [ ] support more scales
|
||||||
- [ ] support controllable restoration strength
|
- [ ] support controllable restoration strength
|
||||||
|
|
||||||
@@ -49,16 +63,24 @@ If you have some images that Real-ESRGAN could not well restored, please also op
|
|||||||
|
|
||||||
### Portable executable files
|
### Portable executable files
|
||||||
|
|
||||||
You can download **Windows executable files** from https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN-ncnn-vulkan.zip
|
You can download [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.2/realesrgan-ncnn-vulkan-20210801-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.2/realesrgan-ncnn-vulkan-20210801-ubuntu.zip) / [MacOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.2/realesrgan-ncnn-vulkan-20210801-macos.zip) **executable files for Intel/AMD/Nvidia GPU**.
|
||||||
|
|
||||||
This executable file is **portable** and includes all the binaries and models required. No CUDA or PyTorch environment is needed.<br>
|
This executable file is **portable** and includes all the binaries and models required. No CUDA or PyTorch environment is needed.<br>
|
||||||
|
|
||||||
You can simply run the following command:
|
You can simply run the following command (a Windows example; more information is in the README.md of each executable file):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
|
./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
|
We have provided three models:
|
||||||
|
|
||||||
|
1. realesrgan-x4plus (default)
|
||||||
|
2. realesrnet-x4plus
|
||||||
|
3. esrgan-x4
|
||||||
|
|
||||||
|
You can use the `-n` argument for other models, for example, `./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png -n realesrnet-x4plus`
|
||||||
|
|
||||||
Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
|
Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
|
||||||
|
|
||||||
This executable file is based on the wonderful [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui).
|
This executable file is based on the wonderful [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui).
|
||||||
@@ -85,7 +107,11 @@ This executable file is based on the wonderful [Tencent/ncnn](https://github.com
|
|||||||
# Install basicsr - https://github.com/xinntao/BasicSR
|
# Install basicsr - https://github.com/xinntao/BasicSR
|
||||||
# We use BasicSR for both training and inference
|
# We use BasicSR for both training and inference
|
||||||
pip install basicsr
|
pip install basicsr
|
||||||
|
# facexlib and gfpgan are for face enhancement
|
||||||
|
pip install facexlib
|
||||||
|
pip install gfpgan
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
python setup.py develop
|
||||||
```
|
```
|
||||||
|
|
||||||
## :zap: Quick Inference
|
## :zap: Quick Inference
|
||||||
@@ -101,11 +127,18 @@ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_
|
|||||||
Inference!
|
Inference!
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input inputs
|
python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input inputs --face_enhance
|
||||||
```
|
```
|
||||||
|
|
||||||
Results are in the `results` folder
|
Results are in the `results` folder
|
||||||
|
|
||||||
|
## :european_castle: Model Zoo
|
||||||
|
|
||||||
|
- [RealESRGAN-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
|
||||||
|
- [RealESRNet-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth)
|
||||||
|
- [RealESRGAN-x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth)
|
||||||
|
- [official ESRGAN-x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth)
|
||||||
|
|
||||||
## :computer: Training
|
## :computer: Training
|
||||||
|
|
||||||
A detailed guide can be found in [Training.md](Training.md).
|
A detailed guide can be found in [Training.md](Training.md).
|
||||||
|
|||||||
17
Training.md
17
Training.md
@@ -1,7 +1,7 @@
|
|||||||
# :computer: How to Train Real-ESRGAN
|
# :computer: How to Train Real-ESRGAN
|
||||||
|
|
||||||
The training codes have been released. <br>
|
The training codes have been released. <br>
|
||||||
Note that the codes have a lot of refactoring. So there may be some bugs/performance drops. Welcome to report issues and I will also retrain the models.
|
Note that the code has been heavily refactored, so there may be some bugs or performance drops. You are welcome to report bugs/issues.
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
@@ -34,14 +34,17 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
|
|
||||||
## Train Real-ESRNet
|
## Train Real-ESRNet
|
||||||
|
|
||||||
1. Download pre-trained model [ESRGAN](https://drive.google.com/file/d/1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT/view?usp=sharing) into `experiments/pretrained_models`.
|
1. Download pre-trained model [ESRGAN](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) into `experiments/pretrained_models`.
|
||||||
|
```bash
|
||||||
|
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models
|
||||||
|
```
|
||||||
1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
|
1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
|
||||||
```yml
|
```yml
|
||||||
train:
|
train:
|
||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K # modify to the root path of your folder
|
dataroot_gt: datasets/DF2K # modify to the root path of your folder
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
|
meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
```
|
```
|
||||||
@@ -73,12 +76,12 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
|
||||||
```
|
```
|
||||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
|
||||||
```
|
```
|
||||||
|
|
||||||
## Train Real-ESRGAN
|
## Train Real-ESRGAN
|
||||||
@@ -88,10 +91,10 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
|
||||||
```
|
```
|
||||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
|
||||||
```
|
```
|
||||||
|
|||||||
BIN
assets/teaser-text.png
Normal file
BIN
assets/teaser-text.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 546 KiB |
@@ -1,67 +1,98 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import cv2
|
import cv2
|
||||||
import glob
|
import glob
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from realesrgan import RealESRGANer
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth')
|
parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
|
||||||
parser.add_argument('--scale', type=int, default=4)
|
parser.add_argument(
|
||||||
parser.add_argument('--input', type=str, default='inputs', help='input image or folder')
|
'--model_path',
|
||||||
|
type=str,
|
||||||
|
default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
|
||||||
|
help='Path to the pre-trained model')
|
||||||
|
parser.add_argument('--output', type=str, default='results', help='Output folder')
|
||||||
|
parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
|
||||||
|
parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
|
||||||
|
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
|
||||||
|
parser.add_argument('--tile', type=int, default=800, help='Tile size, 0 for no tile during testing')
|
||||||
|
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
|
||||||
|
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
|
||||||
|
parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
|
||||||
|
parser.add_argument('--half', action='store_true', help='Use half precision during inference')
|
||||||
|
parser.add_argument(
|
||||||
|
'--alpha_upsampler',
|
||||||
|
type=str,
|
||||||
|
default='realesrgan',
|
||||||
|
help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
|
||||||
|
parser.add_argument(
|
||||||
|
'--ext',
|
||||||
|
type=str,
|
||||||
|
default='auto',
|
||||||
|
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
upsampler = RealESRGANer(
|
||||||
# set up model
|
scale=args.netscale,
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale)
|
model_path=args.model_path,
|
||||||
loadnet = torch.load(args.model_path)
|
tile=args.tile,
|
||||||
model.load_state_dict(loadnet['params_ema'], strict=True)
|
tile_pad=args.tile_pad,
|
||||||
model.eval()
|
pre_pad=args.pre_pad,
|
||||||
model = model.to(device)
|
half=args.half)
|
||||||
|
|
||||||
os.makedirs('results/', exist_ok=True)
|
if args.face_enhance:
|
||||||
for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
|
from gfpgan import GFPGANer
|
||||||
imgname = os.path.splitext(os.path.basename(path))[0]
|
face_enhancer = GFPGANer(
|
||||||
|
model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
|
||||||
|
upscale=args.outscale,
|
||||||
|
arch='clean',
|
||||||
|
channel_multiplier=2,
|
||||||
|
bg_upsampler=upsampler)
|
||||||
|
os.makedirs(args.output, exist_ok=True)
|
||||||
|
|
||||||
|
if os.path.isfile(args.input):
|
||||||
|
paths = [args.input]
|
||||||
|
else:
|
||||||
|
paths = sorted(glob.glob(os.path.join(args.input, '*')))
|
||||||
|
|
||||||
|
for idx, path in enumerate(paths):
|
||||||
|
imgname, extension = os.path.splitext(os.path.basename(path))
|
||||||
print('Testing', idx, imgname)
|
print('Testing', idx, imgname)
|
||||||
# read image
|
|
||||||
img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
|
|
||||||
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
|
||||||
img = img.unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
if args.scale == 2:
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
mod_scale = 2
|
if len(img.shape) == 3 and img.shape[2] == 4:
|
||||||
elif args.scale == 1:
|
img_mode = 'RGBA'
|
||||||
mod_scale = 4
|
|
||||||
else:
|
else:
|
||||||
mod_scale = None
|
img_mode = None
|
||||||
if mod_scale is not None:
|
|
||||||
h_pad, w_pad = 0, 0
|
h, w = img.shape[0:2]
|
||||||
_, _, h, w = img.size()
|
if max(h, w) > 1000 and args.netscale == 4:
|
||||||
if (h % mod_scale != 0):
|
import warnings
|
||||||
h_pad = (mod_scale - h % mod_scale)
|
warnings.warn('The input image is large, try X2 model for better performace.')
|
||||||
if (w % mod_scale != 0):
|
if max(h, w) < 500 and args.netscale == 2:
|
||||||
w_pad = (mod_scale - w % mod_scale)
|
import warnings
|
||||||
img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
|
warnings.warn('The input image is small, try X4 model for better performace.')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# inference
|
if args.face_enhance:
|
||||||
with torch.no_grad():
|
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
output = model(img)
|
else:
|
||||||
# remove extra pad
|
output, _ = upsampler.enhance(img, outscale=args.outscale)
|
||||||
if mod_scale is not None:
|
|
||||||
_, _, h, w = output.size()
|
|
||||||
output = output[:, :, 0:h - h_pad, 0:w - w_pad]
|
|
||||||
# save image
|
|
||||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
|
||||||
output = (output * 255.0).round().astype(np.uint8)
|
|
||||||
cv2.imwrite(f'results/{imgname}_RealESRGAN.png', output)
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print('Error', error)
|
print('Error', error)
|
||||||
|
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
|
||||||
|
else:
|
||||||
|
if args.ext == 'auto':
|
||||||
|
extension = extension[1:]
|
||||||
|
else:
|
||||||
|
extension = args.ext
|
||||||
|
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||||
|
extension = 'png'
|
||||||
|
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
|
||||||
|
cv2.imwrite(save_path, output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
BIN
inputs/tree_alpha_16bit.png
Normal file
BIN
inputs/tree_alpha_16bit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 373 KiB |
BIN
inputs/wolf_gray.jpg
Normal file
BIN
inputs/wolf_gray.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 52 KiB |
189
options/finetune_realesrgan_x4plus.yml
Normal file
189
options/finetune_realesrgan_x4plus.yml
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
# general settings
|
||||||
|
name: finetune_RealESRGANx4plus_400k_B12G4
|
||||||
|
model_type: RealESRGANModel
|
||||||
|
scale: 4
|
||||||
|
num_gpu: 4
|
||||||
|
manual_seed: 0
|
||||||
|
|
||||||
|
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
||||||
|
# USM the ground-truth
|
||||||
|
l1_gt_usm: True
|
||||||
|
percep_gt_usm: True
|
||||||
|
gan_gt_usm: False
|
||||||
|
|
||||||
|
# the first degradation process
|
||||||
|
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||||
|
resize_range: [0.15, 1.5]
|
||||||
|
gaussian_noise_prob: 0.5
|
||||||
|
noise_range: [1, 30]
|
||||||
|
poisson_scale_range: [0.05, 3]
|
||||||
|
gray_noise_prob: 0.4
|
||||||
|
jpeg_range: [30, 95]
|
||||||
|
|
||||||
|
# the second degradation process
|
||||||
|
second_blur_prob: 0.8
|
||||||
|
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||||
|
resize_range2: [0.3, 1.2]
|
||||||
|
gaussian_noise_prob2: 0.5
|
||||||
|
noise_range2: [1, 25]
|
||||||
|
poisson_scale_range2: [0.05, 2.5]
|
||||||
|
gray_noise_prob2: 0.4
|
||||||
|
jpeg_range2: [30, 95]
|
||||||
|
|
||||||
|
gt_size: 256
|
||||||
|
queue_size: 180
|
||||||
|
|
||||||
|
# dataset and data loader settings
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: DF2K+OST
|
||||||
|
type: RealESRGANDataset
|
||||||
|
dataroot_gt: datasets/DF2K
|
||||||
|
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
|
io_backend:
|
||||||
|
type: disk
|
||||||
|
|
||||||
|
blur_kernel_size: 21
|
||||||
|
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||||
|
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||||
|
sinc_prob: 0.1
|
||||||
|
blur_sigma: [0.2, 3]
|
||||||
|
betag_range: [0.5, 4]
|
||||||
|
betap_range: [1, 2]
|
||||||
|
|
||||||
|
blur_kernel_size2: 21
|
||||||
|
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||||
|
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||||
|
sinc_prob2: 0.1
|
||||||
|
blur_sigma2: [0.2, 1.5]
|
||||||
|
betag_range2: [0.5, 4]
|
||||||
|
betap_range2: [1, 2]
|
||||||
|
|
||||||
|
final_sinc_prob: 0.8
|
||||||
|
|
||||||
|
gt_size: 256
|
||||||
|
use_hflip: True
|
||||||
|
use_rot: False
|
||||||
|
|
||||||
|
# data loader
|
||||||
|
use_shuffle: true
|
||||||
|
num_worker_per_gpu: 5
|
||||||
|
batch_size_per_gpu: 12
|
||||||
|
dataset_enlarge_ratio: 1
|
||||||
|
prefetch_mode: ~
|
||||||
|
|
||||||
|
# Uncomment these for validation
|
||||||
|
# val:
|
||||||
|
# name: validation
|
||||||
|
# type: PairedImageDataset
|
||||||
|
# dataroot_gt: path_to_gt
|
||||||
|
# dataroot_lq: path_to_lq
|
||||||
|
# io_backend:
|
||||||
|
# type: disk
|
||||||
|
|
||||||
|
# network structures
|
||||||
|
network_g:
|
||||||
|
type: RRDBNet
|
||||||
|
num_in_ch: 3
|
||||||
|
num_out_ch: 3
|
||||||
|
num_feat: 64
|
||||||
|
num_block: 23
|
||||||
|
num_grow_ch: 32
|
||||||
|
|
||||||
|
|
||||||
|
network_d:
|
||||||
|
type: UNetDiscriminatorSN
|
||||||
|
num_in_ch: 3
|
||||||
|
num_feat: 64
|
||||||
|
skip_connection: True
|
||||||
|
|
||||||
|
# path
|
||||||
|
path:
|
||||||
|
# use the pre-trained Real-ESRNet model
|
||||||
|
pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
|
||||||
|
param_key_g: params_ema
|
||||||
|
strict_load_g: true
|
||||||
|
pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
|
||||||
|
param_key_d: params
|
||||||
|
strict_load_d: true
|
||||||
|
resume_state: ~
|
||||||
|
|
||||||
|
# training settings
|
||||||
|
train:
|
||||||
|
ema_decay: 0.999
|
||||||
|
optim_g:
|
||||||
|
type: Adam
|
||||||
|
lr: !!float 1e-4
|
||||||
|
weight_decay: 0
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
optim_d:
|
||||||
|
type: Adam
|
||||||
|
lr: !!float 1e-4
|
||||||
|
weight_decay: 0
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
type: MultiStepLR
|
||||||
|
milestones: [400000]
|
||||||
|
gamma: 0.5
|
||||||
|
|
||||||
|
total_iter: 400000
|
||||||
|
warmup_iter: -1 # no warm up
|
||||||
|
|
||||||
|
# losses
|
||||||
|
pixel_opt:
|
||||||
|
type: L1Loss
|
||||||
|
loss_weight: 1.0
|
||||||
|
reduction: mean
|
||||||
|
# perceptual loss (content and style losses)
|
||||||
|
perceptual_opt:
|
||||||
|
type: PerceptualLoss
|
||||||
|
layer_weights:
|
||||||
|
# before relu
|
||||||
|
'conv1_2': 0.1
|
||||||
|
'conv2_2': 0.1
|
||||||
|
'conv3_4': 1
|
||||||
|
'conv4_4': 1
|
||||||
|
'conv5_4': 1
|
||||||
|
vgg_type: vgg19
|
||||||
|
use_input_norm: true
|
||||||
|
perceptual_weight: !!float 1.0
|
||||||
|
style_weight: 0
|
||||||
|
range_norm: false
|
||||||
|
criterion: l1
|
||||||
|
# gan loss
|
||||||
|
gan_opt:
|
||||||
|
type: GANLoss
|
||||||
|
gan_type: vanilla
|
||||||
|
real_label_val: 1.0
|
||||||
|
fake_label_val: 0.0
|
||||||
|
loss_weight: !!float 1e-1
|
||||||
|
|
||||||
|
net_d_iters: 1
|
||||||
|
net_d_init_iters: 0
|
||||||
|
|
||||||
|
# Uncomment these for validation
|
||||||
|
# validation settings
|
||||||
|
# val:
|
||||||
|
# val_freq: !!float 5e3
|
||||||
|
# save_img: True
|
||||||
|
|
||||||
|
# metrics:
|
||||||
|
# psnr: # metric name, can be arbitrary
|
||||||
|
# type: calculate_psnr
|
||||||
|
# crop_border: 4
|
||||||
|
# test_y_channel: false
|
||||||
|
|
||||||
|
# logging settings
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: !!float 5e3
|
||||||
|
use_tb_logger: true
|
||||||
|
wandb:
|
||||||
|
project: ~
|
||||||
|
resume_id: ~
|
||||||
|
|
||||||
|
# dist training settings
|
||||||
|
dist_params:
|
||||||
|
backend: nccl
|
||||||
|
port: 29500
|
||||||
187
options/train_realesrgan_x2plus.yml
Normal file
187
options/train_realesrgan_x2plus.yml
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
# general settings
|
||||||
|
name: train_RealESRGANx2plus_400k_B12G4
|
||||||
|
model_type: RealESRGANModel
|
||||||
|
scale: 2
|
||||||
|
num_gpu: 4
|
||||||
|
manual_seed: 0
|
||||||
|
|
||||||
|
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
||||||
|
# USM the ground-truth
|
||||||
|
l1_gt_usm: True
|
||||||
|
percep_gt_usm: True
|
||||||
|
gan_gt_usm: False
|
||||||
|
|
||||||
|
# the first degradation process
|
||||||
|
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||||
|
resize_range: [0.15, 1.5]
|
||||||
|
gaussian_noise_prob: 0.5
|
||||||
|
noise_range: [1, 30]
|
||||||
|
poisson_scale_range: [0.05, 3]
|
||||||
|
gray_noise_prob: 0.4
|
||||||
|
jpeg_range: [30, 95]
|
||||||
|
|
||||||
|
# the second degradation process
|
||||||
|
second_blur_prob: 0.8
|
||||||
|
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||||
|
resize_range2: [0.3, 1.2]
|
||||||
|
gaussian_noise_prob2: 0.5
|
||||||
|
noise_range2: [1, 25]
|
||||||
|
poisson_scale_range2: [0.05, 2.5]
|
||||||
|
gray_noise_prob2: 0.4
|
||||||
|
jpeg_range2: [30, 95]
|
||||||
|
|
||||||
|
gt_size: 256
|
||||||
|
queue_size: 180
|
||||||
|
|
||||||
|
# dataset and data loader settings
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: DF2K+OST
|
||||||
|
type: RealESRGANDataset
|
||||||
|
dataroot_gt: datasets/DF2K
|
||||||
|
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
|
io_backend:
|
||||||
|
type: disk
|
||||||
|
|
||||||
|
blur_kernel_size: 21
|
||||||
|
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||||
|
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||||
|
sinc_prob: 0.1
|
||||||
|
blur_sigma: [0.2, 3]
|
||||||
|
betag_range: [0.5, 4]
|
||||||
|
betap_range: [1, 2]
|
||||||
|
|
||||||
|
blur_kernel_size2: 21
|
||||||
|
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||||
|
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||||
|
sinc_prob2: 0.1
|
||||||
|
blur_sigma2: [0.2, 1.5]
|
||||||
|
betag_range2: [0.5, 4]
|
||||||
|
betap_range2: [1, 2]
|
||||||
|
|
||||||
|
final_sinc_prob: 0.8
|
||||||
|
|
||||||
|
gt_size: 256
|
||||||
|
use_hflip: True
|
||||||
|
use_rot: False
|
||||||
|
|
||||||
|
# data loader
|
||||||
|
use_shuffle: true
|
||||||
|
num_worker_per_gpu: 5
|
||||||
|
batch_size_per_gpu: 12
|
||||||
|
dataset_enlarge_ratio: 1
|
||||||
|
prefetch_mode: ~
|
||||||
|
|
||||||
|
# Uncomment these for validation
|
||||||
|
# val:
|
||||||
|
# name: validation
|
||||||
|
# type: PairedImageDataset
|
||||||
|
# dataroot_gt: path_to_gt
|
||||||
|
# dataroot_lq: path_to_lq
|
||||||
|
# io_backend:
|
||||||
|
# type: disk
|
||||||
|
|
||||||
|
# network structures
|
||||||
|
network_g:
|
||||||
|
type: RRDBNet
|
||||||
|
num_in_ch: 3
|
||||||
|
num_out_ch: 3
|
||||||
|
num_feat: 64
|
||||||
|
num_block: 23
|
||||||
|
num_grow_ch: 32
|
||||||
|
scale: 2
|
||||||
|
|
||||||
|
|
||||||
|
network_d:
|
||||||
|
type: UNetDiscriminatorSN
|
||||||
|
num_in_ch: 3
|
||||||
|
num_feat: 64
|
||||||
|
skip_connection: True
|
||||||
|
|
||||||
|
# path
|
||||||
|
path:
|
||||||
|
# use the pre-trained Real-ESRNet model
|
||||||
|
pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
|
||||||
|
param_key_g: params_ema
|
||||||
|
strict_load_g: true
|
||||||
|
resume_state: ~
|
||||||
|
|
||||||
|
# training settings
|
||||||
|
train:
|
||||||
|
ema_decay: 0.999
|
||||||
|
optim_g:
|
||||||
|
type: Adam
|
||||||
|
lr: !!float 1e-4
|
||||||
|
weight_decay: 0
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
optim_d:
|
||||||
|
type: Adam
|
||||||
|
lr: !!float 1e-4
|
||||||
|
weight_decay: 0
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
type: MultiStepLR
|
||||||
|
milestones: [400000]
|
||||||
|
gamma: 0.5
|
||||||
|
|
||||||
|
total_iter: 400000
|
||||||
|
warmup_iter: -1 # no warm up
|
||||||
|
|
||||||
|
# losses
|
||||||
|
pixel_opt:
|
||||||
|
type: L1Loss
|
||||||
|
loss_weight: 1.0
|
||||||
|
reduction: mean
|
||||||
|
# perceptual loss (content and style losses)
|
||||||
|
perceptual_opt:
|
||||||
|
type: PerceptualLoss
|
||||||
|
layer_weights:
|
||||||
|
# before relu
|
||||||
|
'conv1_2': 0.1
|
||||||
|
'conv2_2': 0.1
|
||||||
|
'conv3_4': 1
|
||||||
|
'conv4_4': 1
|
||||||
|
'conv5_4': 1
|
||||||
|
vgg_type: vgg19
|
||||||
|
use_input_norm: true
|
||||||
|
perceptual_weight: !!float 1.0
|
||||||
|
style_weight: 0
|
||||||
|
range_norm: false
|
||||||
|
criterion: l1
|
||||||
|
# gan loss
|
||||||
|
gan_opt:
|
||||||
|
type: GANLoss
|
||||||
|
gan_type: vanilla
|
||||||
|
real_label_val: 1.0
|
||||||
|
fake_label_val: 0.0
|
||||||
|
loss_weight: !!float 1e-1
|
||||||
|
|
||||||
|
net_d_iters: 1
|
||||||
|
net_d_init_iters: 0
|
||||||
|
|
||||||
|
# Uncomment these for validation
|
||||||
|
# validation settings
|
||||||
|
# val:
|
||||||
|
# val_freq: !!float 5e3
|
||||||
|
# save_img: True
|
||||||
|
|
||||||
|
# metrics:
|
||||||
|
# psnr: # metric name, can be arbitrary
|
||||||
|
# type: calculate_psnr
|
||||||
|
# crop_border: 4
|
||||||
|
# test_y_channel: false
|
||||||
|
|
||||||
|
# logging settings
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: !!float 5e3
|
||||||
|
use_tb_logger: true
|
||||||
|
wandb:
|
||||||
|
project: ~
|
||||||
|
resume_id: ~
|
||||||
|
|
||||||
|
# dist training settings
|
||||||
|
dist_params:
|
||||||
|
backend: nccl
|
||||||
|
port: 29500
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
# general settings
|
# general settings
|
||||||
name: train_RealESRGANx4plus_400k_B12G4_fromRealESRNet
|
name: train_RealESRGANx4plus_400k_B12G4
|
||||||
model_type: RealESRGANModel
|
model_type: RealESRGANModel
|
||||||
scale: 4
|
scale: 4
|
||||||
num_gpu: 4
|
num_gpu: 4
|
||||||
@@ -39,7 +39,7 @@ datasets:
|
|||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K
|
dataroot_gt: datasets/DF2K
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ network_d:
|
|||||||
# path
|
# path
|
||||||
path:
|
path:
|
||||||
# use the pre-trained Real-ESRNet model
|
# use the pre-trained Real-ESRNet model
|
||||||
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth
|
pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
|
||||||
param_key_g: params_ema
|
param_key_g: params_ema
|
||||||
strict_load_g: true
|
strict_load_g: true
|
||||||
resume_state: ~
|
resume_state: ~
|
||||||
|
|||||||
145
options/train_realesrnet_x2plus.yml
Normal file
145
options/train_realesrnet_x2plus.yml
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
# general settings
|
||||||
|
name: train_RealESRNetx2plus_1000k_B12G4
|
||||||
|
model_type: RealESRNetModel
|
||||||
|
scale: 2
|
||||||
|
num_gpu: 4
|
||||||
|
manual_seed: 0
|
||||||
|
|
||||||
|
# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
|
||||||
|
gt_usm: True # USM the ground-truth
|
||||||
|
|
||||||
|
# the first degradation process
|
||||||
|
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||||
|
resize_range: [0.15, 1.5]
|
||||||
|
gaussian_noise_prob: 0.5
|
||||||
|
noise_range: [1, 30]
|
||||||
|
poisson_scale_range: [0.05, 3]
|
||||||
|
gray_noise_prob: 0.4
|
||||||
|
jpeg_range: [30, 95]
|
||||||
|
|
||||||
|
# the second degradation process
|
||||||
|
second_blur_prob: 0.8
|
||||||
|
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||||
|
resize_range2: [0.3, 1.2]
|
||||||
|
gaussian_noise_prob2: 0.5
|
||||||
|
noise_range2: [1, 25]
|
||||||
|
poisson_scale_range2: [0.05, 2.5]
|
||||||
|
gray_noise_prob2: 0.4
|
||||||
|
jpeg_range2: [30, 95]
|
||||||
|
|
||||||
|
gt_size: 256
|
||||||
|
queue_size: 180
|
||||||
|
|
||||||
|
# dataset and data loader settings
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: DF2K+OST
|
||||||
|
type: RealESRGANDataset
|
||||||
|
dataroot_gt: datasets/DF2K
|
||||||
|
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
|
io_backend:
|
||||||
|
type: disk
|
||||||
|
|
||||||
|
blur_kernel_size: 21
|
||||||
|
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||||
|
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||||
|
sinc_prob: 0.1
|
||||||
|
blur_sigma: [0.2, 3]
|
||||||
|
betag_range: [0.5, 4]
|
||||||
|
betap_range: [1, 2]
|
||||||
|
|
||||||
|
blur_kernel_size2: 21
|
||||||
|
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||||
|
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||||
|
sinc_prob2: 0.1
|
||||||
|
blur_sigma2: [0.2, 1.5]
|
||||||
|
betag_range2: [0.5, 4]
|
||||||
|
betap_range2: [1, 2]
|
||||||
|
|
||||||
|
final_sinc_prob: 0.8
|
||||||
|
|
||||||
|
gt_size: 256
|
||||||
|
use_hflip: True
|
||||||
|
use_rot: False
|
||||||
|
|
||||||
|
# data loader
|
||||||
|
use_shuffle: true
|
||||||
|
num_worker_per_gpu: 5
|
||||||
|
batch_size_per_gpu: 12
|
||||||
|
dataset_enlarge_ratio: 1
|
||||||
|
prefetch_mode: ~
|
||||||
|
|
||||||
|
# Uncomment these for validation
|
||||||
|
# val:
|
||||||
|
# name: validation
|
||||||
|
# type: PairedImageDataset
|
||||||
|
# dataroot_gt: path_to_gt
|
||||||
|
# dataroot_lq: path_to_lq
|
||||||
|
# io_backend:
|
||||||
|
# type: disk
|
||||||
|
|
||||||
|
# network structures
|
||||||
|
network_g:
|
||||||
|
type: RRDBNet
|
||||||
|
num_in_ch: 3
|
||||||
|
num_out_ch: 3
|
||||||
|
num_feat: 64
|
||||||
|
num_block: 23
|
||||||
|
num_grow_ch: 32
|
||||||
|
scale: 2
|
||||||
|
|
||||||
|
# path
|
||||||
|
path:
|
||||||
|
pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth
|
||||||
|
param_key_g: params_ema
|
||||||
|
strict_load_g: False
|
||||||
|
resume_state: ~
|
||||||
|
|
||||||
|
# training settings
|
||||||
|
train:
|
||||||
|
ema_decay: 0.999
|
||||||
|
optim_g:
|
||||||
|
type: Adam
|
||||||
|
lr: !!float 2e-4
|
||||||
|
weight_decay: 0
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
type: MultiStepLR
|
||||||
|
milestones: [1000000]
|
||||||
|
gamma: 0.5
|
||||||
|
|
||||||
|
total_iter: 1000000
|
||||||
|
warmup_iter: -1 # no warm up
|
||||||
|
|
||||||
|
# losses
|
||||||
|
pixel_opt:
|
||||||
|
type: L1Loss
|
||||||
|
loss_weight: 1.0
|
||||||
|
reduction: mean
|
||||||
|
|
||||||
|
# Uncomment these for validation
|
||||||
|
# validation settings
|
||||||
|
# val:
|
||||||
|
# val_freq: !!float 5e3
|
||||||
|
# save_img: True
|
||||||
|
|
||||||
|
# metrics:
|
||||||
|
# psnr: # metric name, can be arbitrary
|
||||||
|
# type: calculate_psnr
|
||||||
|
# crop_border: 4
|
||||||
|
# test_y_channel: false
|
||||||
|
|
||||||
|
# logging settings
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: !!float 5e3
|
||||||
|
use_tb_logger: true
|
||||||
|
wandb:
|
||||||
|
project: ~
|
||||||
|
resume_id: ~
|
||||||
|
|
||||||
|
# dist training settings
|
||||||
|
dist_params:
|
||||||
|
backend: nccl
|
||||||
|
port: 29500
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
# general settings
|
# general settings
|
||||||
name: train_RealESRNetx4plus_1000k_B12G4_fromESRGAN
|
name: train_RealESRNetx4plus_1000k_B12G4
|
||||||
model_type: RealESRNetModel
|
model_type: RealESRNetModel
|
||||||
scale: 4
|
scale: 4
|
||||||
num_gpu: 4
|
num_gpu: 4
|
||||||
@@ -36,7 +36,7 @@ datasets:
|
|||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K
|
dataroot_gt: datasets/DF2K
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
|
|
||||||
|
|||||||
6
realesrgan/__init__.py
Normal file
6
realesrgan/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
from .archs import *
|
||||||
|
from .data import *
|
||||||
|
from .models import *
|
||||||
|
from .utils import *
|
||||||
|
from .version import __gitsha__, __version__
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
arch_folder = osp.dirname(osp.abspath(__file__))
|
||||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
||||||
# import all the arch modules
|
# import all the arch modules
|
||||||
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
|
_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
data_folder = osp.dirname(osp.abspath(__file__))
|
data_folder = osp.dirname(osp.abspath(__file__))
|
||||||
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
||||||
# import all the dataset modules
|
# import all the dataset modules
|
||||||
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
|
_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
model_folder = osp.dirname(osp.abspath(__file__))
|
model_folder = osp.dirname(osp.abspath(__file__))
|
||||||
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
||||||
# import all the model modules
|
# import all the model modules
|
||||||
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
|
_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
|
||||||
@@ -18,7 +18,7 @@ class RealESRGANModel(SRGANModel):
|
|||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(RealESRGANModel, self).__init__(opt)
|
super(RealESRGANModel, self).__init__(opt)
|
||||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||||
self.usm_shaper = USMSharp().cuda()
|
self.usm_sharpener = USMSharp().cuda()
|
||||||
self.queue_size = opt['queue_size']
|
self.queue_size = opt['queue_size']
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -58,7 +58,7 @@ class RealESRGANModel(SRGANModel):
|
|||||||
if self.is_train:
|
if self.is_train:
|
||||||
# training data synthesis
|
# training data synthesis
|
||||||
self.gt = data['gt'].to(self.device)
|
self.gt = data['gt'].to(self.device)
|
||||||
self.gt_usm = self.usm_shaper(self.gt)
|
self.gt_usm = self.usm_sharpener(self.gt)
|
||||||
|
|
||||||
self.kernel1 = data['kernel1'].to(self.device)
|
self.kernel1 = data['kernel1'].to(self.device)
|
||||||
self.kernel2 = data['kernel2'].to(self.device)
|
self.kernel2 = data['kernel2'].to(self.device)
|
||||||
@@ -160,6 +160,8 @@ class RealESRGANModel(SRGANModel):
|
|||||||
|
|
||||||
# training pair pool
|
# training pair pool
|
||||||
self._dequeue_and_enqueue()
|
self._dequeue_and_enqueue()
|
||||||
|
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
||||||
|
self.gt_usm = self.usm_sharpener(self.gt)
|
||||||
else:
|
else:
|
||||||
self.lq = data['lq'].to(self.device)
|
self.lq = data['lq'].to(self.device)
|
||||||
if 'gt' in data:
|
if 'gt' in data:
|
||||||
@@ -17,7 +17,7 @@ class RealESRNetModel(SRModel):
|
|||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(RealESRNetModel, self).__init__(opt)
|
super(RealESRNetModel, self).__init__(opt)
|
||||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||||
self.usm_shaper = USMSharp().cuda()
|
self.usm_sharpener = USMSharp().cuda()
|
||||||
self.queue_size = opt['queue_size']
|
self.queue_size = opt['queue_size']
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -59,7 +59,7 @@ class RealESRNetModel(SRModel):
|
|||||||
self.gt = data['gt'].to(self.device)
|
self.gt = data['gt'].to(self.device)
|
||||||
# USM the GT images
|
# USM the GT images
|
||||||
if self.opt['gt_usm'] is True:
|
if self.opt['gt_usm'] is True:
|
||||||
self.gt = self.usm_shaper(self.gt)
|
self.gt = self.usm_sharpener(self.gt)
|
||||||
|
|
||||||
self.kernel1 = data['kernel1'].to(self.device)
|
self.kernel1 = data['kernel1'].to(self.device)
|
||||||
self.kernel2 = data['kernel2'].to(self.device)
|
self.kernel2 = data['kernel2'].to(self.device)
|
||||||
11
realesrgan/train.py
Normal file
11
realesrgan/train.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
import os.path as osp
|
||||||
|
from basicsr.train import train_pipeline
|
||||||
|
|
||||||
|
import realesrgan.archs
|
||||||
|
import realesrgan.data
|
||||||
|
import realesrgan.models
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # `<this file>/../..` is collapsed by abspath into the directory above the
    # one that holds this script; that directory is handed to the pipeline.
    train_pipeline(osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)))
|
||||||
231
realesrgan/utils.py
Normal file
231
realesrgan/utils.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from torch.hub import download_url_to_file, get_dir
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
# Absolute path of the parent of the directory that contains this module;
# load_file_from_url() joins relative download directories onto it.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
class RealESRGANer:
    """Upsample images with a pre-trained RRDBNet checkpoint.

    Args:
        scale (int): Upsampling factor the network was built for.
        model_path (str): Path to the checkpoint. A value starting with
            ``https://`` is downloaded into ``realesrgan/weights`` first.
        tile (int): Tile size in input pixels. ``0`` (default) runs the whole
            image in a single forward pass.
        tile_pad (int): Extra input pixels taken around every tile and cropped
            away again after upsampling. Default: 10.
        pre_pad (int): Pixels of reflect padding added to the right and bottom
            of the input before inference. Default: 10.
        half (bool): Run the model and the input in half precision.
            Default: False.
    """

    def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)

        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
        # Deserialize onto the CPU so that checkpoints saved from a GPU can
        # also be loaded on CPU-only machines; the model is moved to
        # self.device afterwards.
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer the EMA weights when the checkpoint provides them
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def pre_process(self, img):
        """Convert an HWC numpy image to a padded 1xCxHxW tensor in ``self.img``.

        Padding is applied on the right and bottom only, in two steps:
        ``pre_pad`` pixels first, then whatever is needed to make height and
        width divisible by ``mod_scale`` (set for scale 1 and 2 only).
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad
        # NOTE(review): the 2 / 4 divisors presumably mirror the pixel-unshuffle
        # done inside RRDBNet for scale 2 / 1 -- confirm against basicsr.
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if h % self.mod_scale != 0:
                self.mod_pad_h = self.mod_scale - h % self.mod_scale
            if w % self.mod_scale != 0:
                self.mod_pad_w = self.mod_scale - w % self.mod_scale
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        """Run the model on the whole of ``self.img`` and store ``self.output``."""
        self.output = self.model(self.img)

    def tile_process(self):
        """Upsample ``self.img`` tile by tile and assemble ``self.output``.

        Every tile is read with ``tile_pad`` pixels of surrounding context
        (clipped at the image border); the context is cropped from the
        upsampled tile before it is written into the output.

        Modified from: https://github.com/ata4/esrgan-launcher

        Raises:
            Exception: Whatever the model raises for a tile is printed and
                re-raised, so a failed tile can never be silently replaced by
                stale data.
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except Exception as error:
                    print('Error', error)
                    # Continuing would either hit an unbound `output_tile`
                    # (first tile) or paste the previous tile's pixels here.
                    raise
                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        """Crop the upscaled mod padding and pre padding off ``self.output``.

        Returns:
            Tensor: The cropped ``self.output``.
        """
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    def _upsample_to_numpy(self, img):
        """Pad, upsample and un-pad a 3-channel HWC image; return HWC numpy.

        The result is float32, clamped to [0, 1], with the channel order
        reversed relative to ``img``.
        """
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output = self.post_process()
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        # CHW -> HWC, flipping the channel order back
        return np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        """Upsample a gray, 3-channel or 4-channel numpy image.

        Args:
            img (ndarray): HxW, HxWx3 or HxWx4 array. Colour inputs are
                converted with ``COLOR_BGR2RGB``, i.e. OpenCV channel order is
                expected.
            outscale (float | None): Final scale relative to the input size.
                When it differs from the network scale the result is resized
                with Lanczos interpolation. Default: None (network scale).
            alpha_upsampler (str): ``'realesrgan'`` upsamples the alpha channel
                with the network; any other value uses bilinear
                ``cv2.resize``. Default: 'realesrgan'.

        Returns:
            tuple: ``(output, img_mode)`` where ``output`` is uint8, or uint16
            for inputs detected as 16-bit, and ``img_mode`` is ``'L'``,
            ``'RGB'`` or ``'RGBA'``.
        """
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        # Bit depth is inferred from the values: anything above the 8-bit
        # maximum (255) can only come from a 16-bit image.
        if np.max(img) > 255:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                # the network takes 3 channels, so replicate alpha
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        output_img = self._upsample_to_numpy(img)
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                output_alpha = self._upsample_to_numpy(alpha)
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
|
||||||
|
|
||||||
|
|
||||||
|
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path of the file behind ``url``, downloading it when it is not cached yet.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url: Address of the file to fetch.
        model_dir: Folder (joined onto ``ROOT_DIR``) that holds the file. When ``None``, the ``checkpoints``
            sub-folder of ``get_dir()`` is used.
        progress: Forwarded to ``download_url_to_file``.
        file_name: Name to store the file under. When ``None``, the last component of the URL path is used.

    Returns:
        Absolute path of the cached file.
    """
    target_dir = os.path.join(get_dir(), 'checkpoints') if model_dir is None else model_dir
    os.makedirs(os.path.join(ROOT_DIR, target_dir), exist_ok=True)

    # The URL is always parsed, even when an explicit file name overrides the result.
    url_name = os.path.basename(urlparse(url).path)
    local_name = url_name if file_name is None else file_name

    cached_file = os.path.abspath(os.path.join(ROOT_DIR, target_dir, local_name))
    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
||||||
3
realesrgan/weights/README.md
Normal file
3
realesrgan/weights/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Weights
|
||||||
|
|
||||||
|
Put the downloaded weights in this folder.
|
||||||
@@ -1,4 +1,7 @@
|
|||||||
basicsr
|
basicsr>=1.3.3.11
|
||||||
cv2
|
facexlib>=0.2.0.3
|
||||||
|
gfpgan>=0.2.1
|
||||||
numpy
|
numpy
|
||||||
|
opencv-python
|
||||||
|
Pillow
|
||||||
torch>=1.7
|
torch>=1.7
|
||||||
|
|||||||
40
scripts/generate_meta_info.py
Normal file
40
scripts/generate_meta_info.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Write the relative path of every entry in the input folders to one meta-info txt file.

    Args:
        args: Namespace with ``input`` (list of folders to scan), ``root`` (one root per input folder; each
            path is written relative to its root) and ``meta_info`` (path of the txt file to write).
    """
    # The context manager makes sure the file is flushed and closed, even if globbing or writing fails.
    with open(args.meta_info, 'w') as txt_file:
        for folder, root in zip(args.input, args.root):
            for img_path in sorted(glob.glob(os.path.join(folder, '*'))):
                img_name = os.path.relpath(img_path, root)
                print(img_name)
                txt_file.write(f'{img_name}\n')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: collect the folders, check that they pair up, then write the meta info file.
    parser = argparse.ArgumentParser(
        description='Generate meta info (txt file) for only Ground-Truth images. '
        'It can also generate meta info from several folders into one txt file.')
    parser.add_argument(
        '--input',
        nargs='+',
        default=['datasets/DF2K/DF2K_HR', 'datasets/DF2K/DF2K_multiscale'],
        help='Input folder, can be a list')
    parser.add_argument(
        '--root',
        nargs='+',
        default=['datasets/DF2K', 'datasets/DF2K'],
        help='Folder root, should have the length as input folders')
    parser.add_argument(
        '--meta_info',
        type=str,
        default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt',
        help='txt path for meta info')
    args = parser.parse_args()

    # Use parser.error instead of assert: asserts are stripped under ``python -O``, which would let
    # mismatched lists through and make zip() silently drop the extra folders.
    if len(args.input) != len(args.root):
        parser.error('Input folder and folder root should have the same length, but got '
                     f'{len(args.input)} and {len(args.root)}.')

    main(args)
|
||||||
46
scripts/generate_multiscale_DF2K.py
Normal file
46
scripts/generate_multiscale_DF2K.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Save several down-scaled PNG copies of every image found in ``args.input``.

    For each input image, one copy per entry of ``scale_list`` is saved as ``<basename>T<idx>.png``, followed
    by one more copy whose shortest edge is ``shortest_edge`` pixels.

    Args:
        args: Namespace with ``input`` (folder that is globbed for images) and ``output`` (folder the
            resized PNG files are written to).
    """
    # For DF2K, we consider the following three scales,
    # and the smallest image whose shortest edge is 400
    scale_list = [0.75, 0.5, 1 / 3]
    shortest_edge = 400

    path_list = sorted(glob.glob(os.path.join(args.input, '*')))
    for path in path_list:
        print(path)
        basename = os.path.splitext(os.path.basename(path))[0]

        # Close the source file as soon as all of its variants are written, so that a long run over a
        # large dataset does not accumulate open file handles.
        with Image.open(path) as img:
            width, height = img.size
            for idx, scale in enumerate(scale_list):
                print(f'\t{scale:.2f}')
                rlt = img.resize((int(width * scale), int(height * scale)), resample=Image.LANCZOS)
                rlt.save(os.path.join(args.output, f'{basename}T{idx}.png'))

            # save the smallest image whose shortest edge is 400
            if width < height:
                ratio = height / width
                width = shortest_edge
                height = int(width * ratio)
            else:
                ratio = width / height
                height = shortest_edge
                width = int(height * ratio)
            rlt = img.resize((int(width), int(height)), resample=Image.LANCZOS)
            # The index follows the last scaled copy; it is taken from the list length rather than from
            # the loop variable left over by the for-loop above.
            rlt.save(os.path.join(args.output, f'{basename}T{len(scale_list)}.png'))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: read both folders, make sure the output folder exists, then convert.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--input',
        type=str,
        default='datasets/DF2K/DF2K_HR',
        help='Input folder')
    arg_parser.add_argument(
        '--output',
        type=str,
        default='datasets/DF2K/DF2K_multiscale',
        help='Output folder')
    args = arg_parser.parse_args()

    os.makedirs(args.output, exist_ok=True)
    main(args)
|
||||||
17
scripts/pytorch2onnx.py
Normal file
17
scripts/pytorch2onnx.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import torch
|
||||||
|
import torch.onnx
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
|
||||||
|
# An instance of your model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32)
# map_location='cpu' lets the checkpoint load on machines without a GPU; the model is moved to the CPU
# right below anyway.
checkpoint = torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth', map_location='cpu')
model.load_state_dict(checkpoint['params_ema'])
# Switch to evaluation mode since we will only run the forward pass.
model.cpu().eval()

# An example input you would normally provide to your model's forward() method
x = torch.rand(1, 3, 64, 64)

# Export the model through the public torch.onnx.export entry point (not the private _export helper).
with torch.no_grad():
    torch.onnx.export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
|
||||||
@@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
|
|||||||
line_length = 120
|
line_length = 120
|
||||||
multi_line_output = 0
|
multi_line_output = 0
|
||||||
known_standard_library = pkg_resources,setuptools
|
known_standard_library = pkg_resources,setuptools
|
||||||
known_first_party = basicsr # modify it!
|
known_first_party = realesrgan
|
||||||
known_third_party = basicsr,cv2,numpy,torch
|
known_third_party = PIL,basicsr,cv2,numpy,torch
|
||||||
no_lines_before = STDLIB,LOCALFOLDER
|
no_lines_before = STDLIB,LOCALFOLDER
|
||||||
default_section = THIRDPARTY
|
default_section = THIRDPARTY
|
||||||
|
|||||||
113
setup.py
Normal file
113
setup.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Relative path of the generated version module: written by write_version_py(), read by get_version().
version_file = 'realesrgan/version.py'
|
||||||
|
|
||||||
|
|
||||||
|
def readme():
    """Return the whole text of ``README.md`` (looked up in the current directory), decoded as UTF-8."""
    with open('README.md', encoding='utf-8') as readme_file:
        return readme_file.read()
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_hash():
    """Return the output of ``git rev-parse HEAD`` as text, or ``'unknown'`` when git cannot be launched."""

    def _minimal_ext_cmd(cmd):
        # Run ``cmd`` with a minimal environment and return the raw bytes it wrote to stdout.
        env = {key: os.environ[key] for key in ('SYSTEMROOT', 'PATH', 'HOME') if key in os.environ}
        # LANGUAGE is used on win32
        env.update(LANGUAGE='C', LANG='C', LC_ALL='C')
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
        stdout, _ = proc.communicate()
        return stdout

    try:
        raw = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        return raw.strip().decode('ascii')
    except OSError:
        # Raised when the git executable cannot be started at all.
        return 'unknown'
|
||||||
|
|
||||||
|
|
||||||
|
def get_hash():
    """Return a short identifier of the source revision being packaged.

    Returns:
        The first seven characters of ``get_git_hash()`` when a ``.git`` entry exists in the current
        directory; otherwise, when the generated version file exists, the part of its ``__version__``
        after the last ``+``; otherwise ``'unknown'``.

    Raises:
        ImportError: If the version file exists but the version module cannot be imported.
    """
    if os.path.exists('.git'):
        sha = get_git_hash()[:7]
    elif os.path.exists(version_file):
        try:
            # Read the version of *this* package (the module stored at ``version_file``). The previous
            # import pulled ``__version__`` from facexlib, i.e. from an unrelated package.
            from realesrgan.version import __version__
            sha = __version__.split('+')[-1]
        except ImportError as err:
            raise ImportError('Unable to get git version') from err
    else:
        sha = 'unknown'

    return sha
|
||||||
|
|
||||||
|
|
||||||
|
def write_version_py():
    """Generate the version module at ``version_file`` from the ``VERSION`` file and ``get_hash()``."""
    template = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    git_sha = get_hash()
    with open('VERSION', 'r') as f:
        short_version = f.read().strip()

    # Purely numeric components are written bare; every other component is wrapped in double quotes.
    pieces = []
    for part in short_version.split('.'):
        pieces.append(part if part.isdigit() else f'"{part}"')
    version_info = ', '.join(pieces)

    with open(version_file, 'w') as f:
        f.write(template.format(time.asctime(), short_version, git_sha, version_info))
|
||||||
|
|
||||||
|
|
||||||
|
def get_version():
    """Return the ``__version__`` string defined in the file at ``version_file``.

    Raises:
        KeyError: If the version file does not define ``__version__``.
    """
    # Execute the file in an explicit namespace. Reading the result back through ``locals()`` relied on
    # exec() writing into the function's locals dict, which no longer happens on Python 3.13+ (PEP 667).
    namespace = {}
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']
|
||||||
|
|
||||||
|
|
||||||
|
def get_requirements(filename='requirements.txt'):
    """Return the lines of a requirements file, without their line breaks.

    Args:
        filename: Path of the requirements file, resolved against the folder that contains this script.

    Returns:
        One string per line of the file, in file order.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    requirements = []
    with open(os.path.join(script_dir, filename), 'r') as req_file:
        for raw_line in req_file:
            requirements.append(raw_line.rstrip('\n'))
    return requirements
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Regenerate the version file first so that get_version() below reads an up-to-date file.
    write_version_py()
    setup(
        name='realesrgan',
        version=get_version(),
        description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
        url='https://github.com/xinntao/Real-ESRGAN',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            # NOTE(review): this Apache classifier disagrees with license='BSD-3-Clause License' below;
            # confirm against the repository's LICENSE file and make the two agree.
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='BSD-3-Clause License',
        setup_requires=['cython', 'numpy'],
        # Runtime dependencies are the lines of requirements.txt located next to this script.
        install_requires=get_requirements(),
        zip_safe=False)
|
||||||
Reference in New Issue
Block a user