Compare commits
7 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 3e0085aeda | |
| | 42110857ef | |
| | 1d180efaf3 | |
| | 7dd860a881 | |
| | 35ee6f781e | |
| | c9023b3d7a | |
| | fb79d65ff3 | |
9 changes: .github/workflows/no-response.yml (vendored)
@@ -1,12 +1,11 @@
name: No Response

# TODO: it seems not to work
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml

# **What it does**: Closes issues that don't have enough information to be
# actionable.
# **Why we have it**: To remove the need for maintainers to remember to check
# back on issues periodically to see if contributors have
# responded.

# **What it does**: Closes issues that don't have enough information to be actionable.
# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
# to see if contributors have responded.
# **Who does it impact**: Everyone that works on docs or docs-internal.

on:
@@ -39,7 +39,3 @@ Here are some TODOs:
- [ ] support controllable restoration strength

:one: There are also [several issues](https://github.com/xinntao/Real-ESRGAN/issues) that require helpers to improve. If you can help, please let me know :smile:

## Contributors

- [AK391](https://github.com/AK391): Integrate RealESRGAN to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN).
18 changes: README.md
@@ -8,17 +8,21 @@
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/publish-pip.yml)

[English](README.md) **|** [简体中文](README_CN.md)

1. [Colab Demo](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
2. Portable [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-ubuntu.zip) / [MacOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-macos.zip) **executable files for Intel/AMD/Nvidia GPU**. You can find more information [here](#Portable-executable-files). The ncnn implementation is in [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan).

Thanks for everyone's attention and use :-) The anime/illustration model still has many problems, mainly: 1. it cannot handle videos; 2. it struggles with depth-of-field blur; 3. it is not adjustable and can over-restore; 4. it may change the original style. You have provided very helpful feedback. I will gradually organize this feedback and update it in [this document](feedback.md). Hopefully, a new model will be available soon.
Thanks for your interest and use :-) There are still many problems with the anime/illustration model, mainly: 1. It cannot deal with videos; 2. It is not aware of depth/depth-of-field; 3. It is not adjustable; 4. It may change the original style. Thanks for your valuable feedback/suggestions. All the feedback is collected in [feedback.md](feedback.md). Hopefully, a new model will be available soon.

Thanks for everyone's attention and use :-) The anime/illustration model still has many problems, mainly: 1. it cannot handle videos; 2. it struggles with depth-of-field blur; 3. it is not adjustable and can over-restore; 4. it may change the original style. You have provided very helpful feedback. This feedback will gradually be updated in [this document](feedback.md). Hopefully, a new model will be available soon.

Real-ESRGAN aims at developing **Practical Algorithms for General Image Restoration**.<br>
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.

:art: Real-ESRGAN needs your contributions. Any contributions are welcome, such as new features/models/typo fixes/suggestions/maintenance, *etc*. See [CONTRIBUTING.md](CONTRIBUTING.md). All contributors are listed [here](CONTRIBUTING.md#Contributors).
:art: Real-ESRGAN needs your contributions. Any contributions are welcome, such as new features/models/typo fixes/suggestions/maintenance, *etc*. See [CONTRIBUTING.md](CONTRIBUTING.md). All contributors are listed [here](README.md#hugs-acknowledgement).

:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md).
:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md) (Well, it is still empty there =-=||).

:triangular_flag_on_post: **Updates**
- :white_check_mark: Add the ncnn implementation [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan).
@@ -238,3 +242,11 @@ A detailed guide can be found in [Training.md](Training.md).

## :e-mail: Contact

If you have any questions, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`.

## :hugs: Acknowledgement

Thanks to all the contributors.

- [AK391](https://github.com/AK391): Integrate RealESRGAN to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN).
- [Asiimoviet](https://github.com/Asiimoviet): Translated the README.md into Chinese (中文).
- [2ji3150](https://github.com/2ji3150): Thanks for the [detailed and valuable feedback/suggestions](https://github.com/xinntao/Real-ESRGAN/issues/131).
248 changes: README_CN.md (new file)
@@ -0,0 +1,248 @@
# Real-ESRGAN

[](https://github.com/xinntao/Real-ESRGAN/releases)
[](https://pypi.org/project/realesrgan/)
[](https://github.com/xinntao/Real-ESRGAN/issues)
[](https://github.com/xinntao/Real-ESRGAN/issues)
[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/publish-pip.yml)

[English](README.md) **|** [简体中文](README_CN.md)
1. [Colab Demo](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) for Real-ESRGAN <a href="https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
2. Portable executable files **for Intel/AMD/Nvidia GPUs**: [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-ubuntu.zip) / [macOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-macos.zip); see [here](#便携版(绿色版)可执行文件) for details. The ncnn implementation is in [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan).

Thanks for everyone's attention and use :-) The anime/illustration model still has many problems, mainly: 1. it cannot handle videos; 2. it struggles with depth-of-field blur; 3. it is not adjustable and can over-restore; 4. it may change the original style. You have provided very helpful feedback. This feedback will gradually be updated in [this document](feedback.md). Hopefully, a new model will be available soon.

Real-ESRGAN aims to develop **practical algorithms for general image restoration**.<br>
We extend the powerful ESRGAN with purely synthetic training data so that it can be applied to real-world image restoration (hence the name Real-ESRGAN).

:art: Real-ESRGAN needs and welcomes your contributions, such as new features, models, bug fixes, suggestions, and maintenance. See [CONTRIBUTING.md](CONTRIBUTING.md) for details; all contributors are listed [here](README_CN.md#hugs-感谢).

:question: Frequently asked questions are answered in [FAQ.md](FAQ.md). (Well, it is still empty there =-=||)

:triangular_flag_on_post: **Updates**
- :white_check_mark: Added the ncnn implementation: [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan).
- :white_check_mark: Added [*RealESRGAN_x4plus_anime_6B.pth*](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth), optimized for anime images with a smaller model size. See [**anime_model.md**](docs/anime_model.md) for details and comparisons with [waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan).
- :white_check_mark: Support fine-tuning on your own dataset: [details](Training.md#Finetune-Real-ESRGAN-on-your-own-dataset).
- :white_check_mark: Support **face enhancement** with [GFPGAN](https://github.com/TencentARC/GFPGAN).
- :white_check_mark: Added to [Huggingface Spaces](https://huggingface.co/spaces) (an online platform for machine-learning apps) via [Gradio](https://github.com/gradio-app/gradio): [Gradio web demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN). Thanks to [@AK391](https://github.com/AK391).
- :white_check_mark: Support arbitrary output scales with `--outscale` (the output image is further resized with `LANCZOS4`). Added the *RealESRGAN_x2plus.pth* model.
- :white_check_mark: The [inference script](inference_realesrgan.py) supports: 1) **tile** processing; 2) images with **alpha channels**; 3) **grayscale** images; 4) **16-bit** images.
- :white_check_mark: The training code has been released; see [Training.md](Training.md).
---

If Real-ESRGAN helps you, please give this project a star :star: or recommend it to your friends. Thanks! :blush: <br/>
Other recommended projects:<br/>
:arrow_forward: [GFPGAN](https://github.com/TencentARC/GFPGAN): A practical algorithm for face restoration <br>
:arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR): An open-source image and video restoration toolbox<br>
:arrow_forward: [facexlib](https://github.com/xinntao/facexlib): A collection of face-related utilities<br>
:arrow_forward: [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer for convenient viewing and comparison <br>

---

### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data

> [[Paper](https://arxiv.org/abs/2107.10833)]   [Project page]   [[YouTube video](https://www.youtube.com/watch?v=fxHWoDSSvSc)]   [[Bilibili video](https://www.bilibili.com/video/BV1H34y1m7sS/)]   [[Poster](https://xinntao.github.io/projects/RealESRGAN_src/RealESRGAN_poster.pdf)]   [[PPT](https://docs.google.com/presentation/d/1QtW6Iy8rm8rGLsJ0Ldti6kP-7Qyzy6XL/edit?usp=sharing&ouid=109799856763657548160&rtpof=true&sd=true)]<br>
> [Xintao Wang](https://xinntao.github.io/), Liangbin Xie, [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en) <br>
> Tencent ARC Lab; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences

<p align="center">
  <img src="assets/teaser.jpg">
</p>

---
We provide a trained model (*RealESRGAN_x4plus.pth*) for 4x super-resolution.<br>
**Real-ESRGAN can still fail in some cases, since real-world degradations are quite complex.**<br>
Moreover, it does not work well on **faces and text** yet, but we will keep improving it.<br>

Real-ESRGAN will be long-term supported; I will keep maintaining and updating it in my spare time.

Here are some planned features:

- [ ] Optimize for faces
- [ ] Optimize for text
- [x] Optimize for anime images
- [ ] Support more super-resolution scales
- [ ] Controllable restoration strength

If you have good ideas or requests, feel free to open an issue or discussion.<br/>
If Real-ESRGAN does not work well on some of your photos, you can also post them in an issue or discussion. I will keep an eye on them (though I may not necessarily solve them :stuck_out_tongue:). If necessary, I will open a dedicated page to record the images that remain to be solved.

---
### 便携版(绿色版)可执行文件 (Portable executable files)

You can download portable executable files **for Intel/AMD/Nvidia GPUs**: [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-ubuntu.zip) / [macOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/realesrgan-ncnn-vulkan-20210901-macos.zip).

These executables are portable: you can run them directly (even from a USB drive), since they already bundle all the required files and models. They do not require CUDA or a PyTorch runtime.<br>

You can run the following command (a Windows example; see the README.md of the corresponding release for more information):

```bash
./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png
```
We provide three models:

1. realesrgan-x4plus (default)
2. realesrnet-x4plus
3. realesrgan-x4plus-anime (optimized for anime/illustration images, with a smaller model size)

You can select another model with the `-n` argument, e.g. `./realesrgan-ncnn-vulkan.exe -i anime.jpg -o anime_out.png -n realesrgan-x4plus-anime`
### 可执行文件的用法 (Usage of the executable files)

1. More details can be found in [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan#computer-usages).
2. Note that the executables do not support all the features of the Python script `inference_realesrgan.py`, such as the `outscale` option.

```console
Usage: realesrgan-ncnn-vulkan.exe -i infile -o outfile [options]...

  -h                   show this help
  -v                   verbose output
  -i input-path        input image path (jpg/png/webp) or directory
  -o output-path       output image path (jpg/png/webp) or directory
  -s scale             upscale ratio (4, default=4)
  -t tile-size         tile size (>=32/0=auto, default=0) can be 0,0,0 for multi-gpu
  -m model-path        folder path to the pre-trained models (default=models)
  -n model-name        model name (default=realesrgan-x4plus, can be realesrgan-x4plus | realesrgan-x4plus-anime | realesrnet-x4plus)
  -g gpu-id            gpu device to use (default=0) can be 0,1,2 for multi-gpu
  -j load:proc:save    thread count for load/proc/save (default=1:2:2) can be 1:2,2,2:2 for multi-gpu
  -x                   enable tta mode
  -f format            output image format (jpg/png/webp, default=ext/png)
```
Because these executables split the image into tiles, process them separately, and then stitch the results back together, the output may show slight seams (and may differ a little from the PyTorch output).

These executables are based on [Tencent/ncnn](https://github.com/Tencent/ncnn) and [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan) by [nihui](https://github.com/nihui). Thanks!

---
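The tile-based processing described above can be sketched with simple arithmetic. The helper below is illustrative only (not part of the project) and computes how many tiles an image is split into for a given tile size:

```python
import math

def tile_grid(height, width, tile_size):
    """Return the (rows, cols) tile grid used to cover an image.

    Each tile is processed independently and the upscaled tiles are
    stitched back together, which is why faint seams can appear.
    """
    if tile_size <= 0:  # 0 means "auto" in the executable; treat it as one tile here
        return (1, 1)
    rows = math.ceil(height / tile_size)
    cols = math.ceil(width / tile_size)
    return (rows, cols)

# A 1080x1920 image with 400-pixel tiles is processed as a 3x5 grid.
print(tile_grid(1080, 1920, 400))  # → (3, 5)
```

A real implementation also adds overlap (padding) between tiles to reduce the seams; the grid count above ignores that detail.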
## :wrench: Dependencies and Installation

- Python >= 3.7 ([Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html) is recommended)
- [PyTorch >= 1.7](https://pytorch.org/)
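The Python >= 3.7 requirement above can be checked before installing; this is a small illustrative helper, not part of the project:

```python
import sys

def meets_python_requirement(version_info=sys.version_info, minimum=(3, 7)):
    """Return True when the interpreter satisfies the minimum (major, minor) version."""
    return tuple(version_info[:2]) >= minimum

print(meets_python_requirement((3, 8, 10)))  # → True
print(meets_python_requirement((3, 6, 0)))   # → False
```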
#### Installation

1. Clone the repo

    ```bash
    git clone https://github.com/xinntao/Real-ESRGAN.git
    cd Real-ESRGAN
    ```

2. Install dependencies

    ```bash
    # Install basicsr - https://github.com/xinntao/BasicSR
    # We use BasicSR for both training and inference
    pip install basicsr
    # facexlib and gfpgan are for face enhancement
    pip install facexlib
    pip install gfpgan
    pip install -r requirements.txt
    python setup.py develop
    ```
## :zap: Quick Inference

### General images

Download the pre-trained model: [RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)

```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
```

Inference!

```bash
python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input inputs --face_enhance
```

Results are in the `results` folder.
### Anime images

<p align="center">
  <img src="https://raw.githubusercontent.com/xinntao/public-figures/master/Real-ESRGAN/cmp_realesrgan_anime_1.png">
</p>

Pre-trained model: [RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)<br>
More information and comparisons with [waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan) are in [**anime_model.md**](docs/anime_model.md).

```bash
# download the model
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models
# inference
python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth --input inputs
```

Results are in the `results` folder.
### Usage of the Python script

1. Although you use an X4 model, you can **output images with any scale** via the `outscale` argument; the program further resizes the model output accordingly.
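A sketch of the size arithmetic behind `--outscale`, assuming the behavior described above (the x4 network upsamples first, then the script resizes the result to the requested scale); this helper is illustrative, not part of the project:

```python
def output_size(h, w, outscale, netscale=4):
    """Final output size for a given --outscale.

    The network first produces a (h*netscale, w*netscale) image; the script
    then resizes it (e.g. with LANCZOS interpolation) to the requested scale.
    """
    net_h, net_w = h * netscale, w * netscale            # raw network output
    out_h, out_w = int(h * outscale), int(w * outscale)  # final resize target
    return (net_h, net_w), (out_h, out_w)

# A 100x200 input with --outscale 3.5: the x4 output (400, 800) is resized down.
print(output_size(100, 200, 3.5))  # → ((400, 800), (350, 700))
```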
```console
Usage: python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input infile --output outfile [options]...

A common command: python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input infile --netscale 4 --outscale 3.5 --half --face_enhance

  -h                   show this help
  --input              Input image or folder. Default: inputs
  --output             Output folder. Default: results
  --model_path         Path to the pre-trained model. Default: experiments/pretrained_models/RealESRGAN_x4plus.pth
  --netscale           Upsample scale factor of the network. Default: 4
  --outscale           The final upsampling scale of the image. Default: 4
  --suffix             Suffix of the restored image. Default: out
  --tile               Tile size, 0 for no tile during testing. Default: 0
  --face_enhance       Whether to use GFPGAN to enhance face. Default: False
  --half               Whether to use half precision during inference. Default: False
  --ext                Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto
```
## :european_castle: Model Zoo

- [RealESRGAN_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth): X4 model for general images
- [RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth): Optimized for anime images; 6 RRDB blocks (slightly smaller network)
- [RealESRGAN_x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth): X2 model for general images
- [RealESRNet_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth): X4 model with MSE loss (over-smooth effects)
- [official ESRGAN_x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth): official ESRGAN model (X4)

The following are **discriminator** models, which are often used for fine-tuning.

- [RealESRGAN_x4plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth)
- [RealESRGAN_x2plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x2plus_netD.pth)
- [RealESRGAN_x4plus_anime_6B_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B_netD.pth)
## :computer: Training and Fine-tuning on Your Own Dataset

A detailed guide can be found in [Training.md](Training.md).
## BibTeX Citation

    @Article{wang2021realesrgan,
        title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
        author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
        journal={arXiv:2107.10833},
        year={2021}
    }
## :e-mail: Contact

If you have any questions, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`.

## :hugs: 感谢 (Acknowledgement)

Thanks to all the contributors!

- [AK391](https://github.com/AK391): Added RealESRGAN to [Huggingface Spaces](https://huggingface.co/spaces) (an online platform for machine-learning apps) via [Gradio](https://github.com/gradio-app/gradio): [Gradio web demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN).
- [Asiimoviet](https://github.com/Asiimoviet): Translated the README.md into Chinese (中文).
- [2ji3150](https://github.com/2ji3150): Thanks for the [detailed and valuable feedback/suggestions](https://github.com/xinntao/Real-ESRGAN/issues/131).
@@ -7,3 +7,5 @@
1. Not adjustable: Waifu2X can be tuned to one's taste, but Real-ESRGAN-anime cannot, so some results look over-restored.
1. It changes the original style: different anime illustrations have their own styles, while the current Real-ESRGAN-anime tends to restore everything into one style (influenced by the training dataset). Style is an important element of anime, so it should be preserved as much as possible.
1. The model is too large: the current model is slow and could be faster. We have related work exploring this and hope to apply the results to the Real-ESRGAN series soon.

Thanks for the [detailed and valuable feedback/suggestions](https://github.com/xinntao/Real-ESRGAN/issues/131) by [2ji3150](https://github.com/2ji3150).
@@ -8,6 +8,8 @@ from realesrgan import RealESRGANer


def main():
    """Inference demo for Real-ESRGAN.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument(
@@ -53,7 +55,7 @@ def main():
        pre_pad=args.pre_pad,
        half=args.half)

    if args.face_enhance:
    if args.face_enhance:  # Use GFPGAN for face enhancement
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
@@ -78,6 +80,7 @@ def main():
        else:
            img_mode = None

        # give warnings for too large/small images
        h, w = img.shape[0:2]
        if max(h, w) > 1000 and args.netscale == 4:
            import warnings
@@ -91,7 +94,7 @@ def main():
                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
            else:
                output, _ = upsampler.enhance(img, outscale=args.outscale)
        except Exception as error:
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        else:
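The diff above narrows the handler from `except Exception` to `except RuntimeError` (the type PyTorch raises for CUDA out-of-memory), so unrelated bugs are no longer swallowed. A minimal, hypothetical sketch of the pattern:

```python
def run_with_oom_hint(fn):
    """Run fn(); on RuntimeError, return a hint string instead of crashing.

    Other exception types (e.g. ValueError) propagate normally, which makes
    real bugs visible instead of being swallowed by a broad except clause.
    """
    try:
        return fn()
    except RuntimeError as error:
        # Mirrors the script: report the error plus a tiling hint for CUDA OOM.
        return f'Error: {error}. If you encounter CUDA out of memory, try a smaller --tile.'

def boom():
    raise RuntimeError('CUDA out of memory')

print(run_with_oom_hint(lambda: 42))  # → 42
print(run_with_oom_hint(boom))        # the hint string
```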
@@ -90,7 +90,6 @@ network_g:
  num_block: 23
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3

@@ -169,7 +168,7 @@ train:
  # save_img: True

  # metrics:
  #   psnr: # metric name, can be arbitrary
  #   psnr: # metric name
  #     type: calculate_psnr
  #     crop_border: 4
  #     test_y_channel: false

@@ -52,7 +52,6 @@ network_g:
  num_block: 23
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3

@@ -131,7 +130,7 @@ train:
  # save_img: True

  # metrics:
  #   psnr: # metric name, can be arbitrary
  #   psnr: # metric name
  #     type: calculate_psnr
  #     crop_border: 4
  #     test_y_channel: false

@@ -91,7 +91,6 @@ network_g:
  num_grow_ch: 32
  scale: 2

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3

@@ -167,7 +166,7 @@ train:
  # save_img: True

  # metrics:
  #   psnr: # metric name, can be arbitrary
  #   psnr: # metric name
  #     type: calculate_psnr
  #     crop_border: 4
  #     test_y_channel: false

@@ -90,7 +90,6 @@ network_g:
  num_block: 23
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3

@@ -166,7 +165,7 @@ train:
  # save_img: True

  # metrics:
  #   psnr: # metric name, can be arbitrary
  #   psnr: # metric name
  #     type: calculate_psnr
  #     crop_border: 4
  #     test_y_channel: false

@@ -125,7 +125,7 @@ train:
  # save_img: True

  # metrics:
  #   psnr: # metric name, can be arbitrary
  #   psnr: # metric name
  #     type: calculate_psnr
  #     crop_border: 4
  #     test_y_channel: false

@@ -124,7 +124,7 @@ train:
  # save_img: True

  # metrics:
  #   psnr: # metric name, can be arbitrary
  #   psnr: # metric name
  #     type: calculate_psnr
  #     crop_border: 4
  #     test_y_channel: false
@@ -3,4 +3,4 @@ from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import __gitsha__, __version__
from .version import __version__
@@ -6,15 +6,23 @@ from torch.nn.utils import spectral_norm

@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN)"""
    """Defines a U-Net discriminator with spectral normalization (SN)

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm

        # the first convolution
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)

        # downsample
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
@@ -22,14 +30,13 @@ class UNetDiscriminatorSN(nn.Module):
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))

        # extra
        # extra convolutions
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))

        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # downsample
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
@@ -52,7 +59,7 @@ class UNetDiscriminatorSN(nn.Module):
        if self.skip_connection:
            x6 = x6 + x0

        # extra
        # extra convolutions
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)
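The stride-2 downsampling convolutions in `conv1`..`conv3` above halve the spatial size at each step; a quick sketch of the shape arithmetic, assuming the kernel=4, stride=2, padding=1 configuration shown in the diff:

```python
def conv2d_out_size(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a 2D convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1

# Each stride-2 (k=4, s=2, p=1) conv halves the spatial size exactly:
size = 256
for name in ('conv1', 'conv2', 'conv3'):
    size = conv2d_out_size(size)
    print(name, size)  # 128, then 64, then 32
```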
@@ -15,18 +15,31 @@ from torch.utils import data as data

@DATASET_REGISTRY.register()
class RealESRGANDataset(data.Dataset):
    """
    Dataset used for Real-ESRGAN model.
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUs for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
        Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
@@ -35,18 +48,20 @@ class RealESRGANDataset(data.Dataset):
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip() for line in fin]
                paths = [line.strip().split(' ')[0] for line in fin]
            self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']
        self.betap_range = opt['betap_range']
        self.sinc_prob = opt['sinc_prob']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
@@ -61,6 +76,7 @@ class RealESRGANDataset(data.Dataset):
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1
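The hard-coded `kernel_range` flagged in the TODO above is just the odd sizes from 7 to 21, which can be checked with the stdlib alone:

```python
# Same comprehension as in the dataset: odd kernel sizes from 7 to 21.
kernel_range = [2 * v + 1 for v in range(3, 11)]
print(kernel_range)  # → [7, 9, 11, 13, 15, 17, 19, 21]

# Odd sizes keep the kernel centered on a single pixel (like the 21x21
# pulse tensor, whose center is at index [10, 10]).
assert all(k % 2 == 1 for k in kernel_range)
```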
@@ -76,7 +92,7 @@ class RealESRGANDataset(data.Dataset):
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except Exception as e:
            except (IOError, OSError) as e:
                logger = get_root_logger()
                logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read
@@ -89,10 +105,11 @@ class RealESRGANDataset(data.Dataset):
            retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- augmentation for training: flip, rotation -------------------- #
        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400: 400 is hard-coded. You may change it accordingly
        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
@@ -154,7 +171,7 @@ class RealESRGANDataset(data.Dataset):
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- sinc kernel ------------------------------------- #
        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
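The crop-or-pad step above normalizes every ground-truth image toward a 400x400 patch: dimensions below 400 are padded up, dimensions at or above 400 are cropped down. A hypothetical sketch of just the size logic (the real code also chooses a random crop position):

```python
def crop_pad_amounts(h, w, target=400):
    """Per-side padding needed to reach target; dims >= target are cropped to it."""
    pad_h = max(0, target - h)
    pad_w = max(0, target - w)
    out_h = target if h >= target else h + pad_h
    out_w = target if w >= target else w + pad_w
    return (pad_h, pad_w), (out_h, out_w)

# A 250x600 image is padded by 150 rows and cropped to 400 columns.
print(crop_pad_amounts(250, 600))  # → ((150, 0), (400, 400))
```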
@@ -11,8 +11,7 @@ from torchvision.transforms.functional import normalize
 class RealESRGANPairedDataset(data.Dataset):
     """Paired image dataset for image restoration.

-    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
-    GT image pairs.
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

     There are three modes:
     1. 'lmdb': Use lmdb files.
@@ -28,8 +27,8 @@ class RealESRGANPairedDataset(data.Dataset):
         dataroot_lq (str): Data root path for lq.
         meta_info (str): Path for meta information file.
         io_backend (dict): IO backend type and other kwarg.
-        filename_tmpl (str): Template for each filename. Note that the
-            template excludes the file extension. Default: '{}'.
+        filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+            Default: '{}'.
         gt_size (int): Cropped patched size for gt patches.
         use_hflip (bool): Use horizontal flips.
         use_rot (bool): Use rotation (use vertical flip and transposing h
@@ -42,23 +41,23 @@ class RealESRGANPairedDataset(data.Dataset):
     def __init__(self, opt):
         super(RealESRGANPairedDataset, self).__init__()
         self.opt = opt
         # file client (io backend)
         self.file_client = None
         self.io_backend_opt = opt['io_backend']
+        # mean and std for normalizing the input images
         self.mean = opt['mean'] if 'mean' in opt else None
         self.std = opt['std'] if 'std' in opt else None

         self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
-        if 'filename_tmpl' in opt:
-            self.filename_tmpl = opt['filename_tmpl']
-        else:
-            self.filename_tmpl = '{}'
+        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'

         # file client (lmdb io backend)
         if self.io_backend_opt['type'] == 'lmdb':
             self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
             self.io_backend_opt['client_keys'] = ['lq', 'gt']
             self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
         elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
             # disk backend with meta_info
             # Each line in the meta_info describes the relative path to an image
             with open(self.opt['meta_info']) as fin:
                 paths = [line.strip() for line in fin]
             self.paths = []
@@ -68,6 +67,9 @@ class RealESRGANPairedDataset(data.Dataset):
                 lq_path = os.path.join(self.lq_folder, lq_path)
                 self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
         else:
             # disk backend
             # it will scan the whole folder to get meta info
             # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
             self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

     def __getitem__(self, index):
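The disk backend above can now resolve GT/LQ pairs from a meta_info file whose lines hold comma-separated relative paths (see `tests/data/meta_info_pair.txt`). A sketch of that parsing step as a hypothetical standalone helper:

```python
import os

def paths_from_meta_info(meta_info_lines, gt_folder, lq_folder):
    """Build gt/lq path pairs from lines such as
    'gt/baboon.png, lq/baboon.png' (relative paths, comma-separated)."""
    paths = []
    for line in meta_info_lines:
        gt_name, lq_name = [v.strip() for v in line.split(',')]
        paths.append({
            'gt_path': os.path.join(gt_folder, gt_name),
            'lq_path': os.path.join(lq_folder, lq_name),
        })
    return paths

pairs = paths_from_meta_info(['gt/baboon.png, lq/baboon.png'], 'tests/data', 'tests/data')
print(pairs[0]['gt_path'])
```

Using a meta_info file avoids the full-folder scan that the fallback disk backend warns about.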
@@ -13,35 +13,45 @@ from torch.nn import functional as F

 @MODEL_REGISTRY.register()
 class RealESRGANModel(SRGANModel):
-    """RealESRGAN Model"""
+    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It mainly performs:
+    1. randomly synthesize LQ images in GPU tensors
+    2. optimize the networks with GAN training.
+    """

     def __init__(self, opt):
         super(RealESRGANModel, self).__init__(opt)
-        self.jpeger = DiffJPEG(differentiable=False).cuda()
-        self.usm_sharpener = USMSharp().cuda()
+        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
         self.queue_size = opt.get('queue_size', 180)

     @torch.no_grad()
     def _dequeue_and_enqueue(self):
-        # training pair pool
+        """It is the training pair pool for increasing the diversity in a batch.
+
+        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
+        to increase the degradation diversity in a batch.
+        """
         # initialize
         b, c, h, w = self.lq.size()
         if not hasattr(self, 'queue_lr'):
-            assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
+            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
             self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
             _, c, h, w = self.gt.size()
             self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
             self.queue_ptr = 0
-        if self.queue_ptr == self.queue_size:  # full
+        if self.queue_ptr == self.queue_size:  # the pool is full
             # do dequeue and enqueue
             # shuffle
             idx = torch.randperm(self.queue_size)
             self.queue_lr = self.queue_lr[idx]
             self.queue_gt = self.queue_gt[idx]
-            # get
+            # get first b samples
             lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
             gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
-            # update
+            # update the queue
             self.queue_lr[0:b, :, :, :] = self.lq.clone()
             self.queue_gt[0:b, :, :, :] = self.gt.clone()

@@ -55,6 +65,8 @@ class RealESRGANModel(SRGANModel):

     @torch.no_grad()
     def feed_data(self, data):
+        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+        """
         if self.is_train and self.opt.get('high_order_degradation', True):
             # training data synthesis
             self.gt = data['gt'].to(self.device)
@@ -79,7 +91,7 @@ class RealESRGANModel(SRGANModel):
                 scale = 1
             mode = random.choice(['area', 'bilinear', 'bicubic'])
             out = F.interpolate(out, scale_factor=scale, mode=mode)
-            # noise
+            # add noise
             gray_noise_prob = self.opt['gray_noise_prob']
             if np.random.uniform() < self.opt['gaussian_noise_prob']:
                 out = random_add_gaussian_noise_pt(
@@ -93,7 +105,7 @@ class RealESRGANModel(SRGANModel):
                     rounds=False)
             # JPEG compression
             jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
-            out = torch.clamp(out, 0, 1)
+            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
             out = self.jpeger(out, quality=jpeg_p)

             # ----------------------- The second degradation process ----------------------- #
@@ -111,7 +123,7 @@ class RealESRGANModel(SRGANModel):
                 mode = random.choice(['area', 'bilinear', 'bicubic'])
                 out = F.interpolate(
                     out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
-            # noise
+            # add noise
             gray_noise_prob = self.opt['gray_noise_prob2']
             if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                 out = random_add_gaussian_noise_pt(
@@ -162,7 +174,9 @@ class RealESRGANModel(SRGANModel):
             self._dequeue_and_enqueue()
             # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
             self.gt_usm = self.usm_sharpener(self.gt)
+            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
         else:
             # for paired training or validation
             self.lq = data['lq'].to(self.device)
             if 'gt' in data:
                 self.gt = data['gt'].to(self.device)
@@ -175,6 +189,7 @@ class RealESRGANModel(SRGANModel):
         self.is_train = True

     def optimize_parameters(self, current_iter):
+        # usm sharpening
         l1_gt = self.gt_usm
         percep_gt = self.gt_usm
         gan_gt = self.gt_usm
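The `_dequeue_and_enqueue` pool documented above shuffles the stored pairs once the pool is full, hands back the first `b` samples, and writes the current batch into the freed slots. A minimal sketch of the same bookkeeping (a hypothetical class, with numpy arrays standing in for the GPU tensors and only one queue instead of the paired lq/gt queues):

```python
import numpy as np

class PairPool:
    """Training pair pool: increases degradation diversity across batches."""

    def __init__(self, queue_size):
        self.queue_size = queue_size
        self.queue = None
        self.ptr = 0

    def dequeue_and_enqueue(self, batch):
        b = batch.shape[0]
        assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
        if self.queue is None:  # initialize lazily, like the hasattr check above
            self.queue = np.zeros((self.queue_size,) + batch.shape[1:], dtype=batch.dtype)
        if self.ptr == self.queue_size:  # the pool is full
            idx = np.random.permutation(self.queue_size)  # shuffle
            self.queue = self.queue[idx]
            out = self.queue[:b].copy()  # get first b samples
            self.queue[:b] = batch       # update the queue
            return out
        # pool not yet full: store the batch and return it unchanged
        self.queue[self.ptr:self.ptr + b] = batch
        self.ptr += b
        return batch

pool = PairPool(queue_size=4)
pool.dequeue_and_enqueue(np.ones((2, 1)))
```

Because a returned sample may come from an earlier batch, it can carry a different synthetic degradation than anything in the current batch, which is exactly the diversity the docstring describes.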
@@ -12,35 +12,46 @@ from torch.nn import functional as F

 @MODEL_REGISTRY.register()
 class RealESRNetModel(SRModel):
-    """RealESRNet Model"""
+    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It is trained without GAN losses.
+    It mainly performs:
+    1. randomly synthesize LQ images in GPU tensors
+    2. optimize the networks with pixel (L1) losses.
+    """

     def __init__(self, opt):
         super(RealESRNetModel, self).__init__(opt)
-        self.jpeger = DiffJPEG(differentiable=False).cuda()
-        self.usm_sharpener = USMSharp().cuda()
+        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
         self.queue_size = opt.get('queue_size', 180)

     @torch.no_grad()
     def _dequeue_and_enqueue(self):
-        # training pair pool
+        """It is the training pair pool for increasing the diversity in a batch.
+
+        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
+        to increase the degradation diversity in a batch.
+        """
         # initialize
         b, c, h, w = self.lq.size()
         if not hasattr(self, 'queue_lr'):
-            assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
+            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
             self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
             _, c, h, w = self.gt.size()
             self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
             self.queue_ptr = 0
-        if self.queue_ptr == self.queue_size:  # full
+        if self.queue_ptr == self.queue_size:  # the pool is full
             # do dequeue and enqueue
             # shuffle
             idx = torch.randperm(self.queue_size)
             self.queue_lr = self.queue_lr[idx]
             self.queue_gt = self.queue_gt[idx]
-            # get
+            # get first b samples
             lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
             gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
-            # update
+            # update the queue
             self.queue_lr[0:b, :, :, :] = self.lq.clone()
             self.queue_gt[0:b, :, :, :] = self.gt.clone()

@@ -54,10 +65,12 @@ class RealESRNetModel(SRModel):

     @torch.no_grad()
     def feed_data(self, data):
+        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+        """
         if self.is_train and self.opt.get('high_order_degradation', True):
             # training data synthesis
             self.gt = data['gt'].to(self.device)
-            # USM the GT images
+            # USM sharpen the GT images
             if self.opt['gt_usm'] is True:
                 self.gt = self.usm_sharpener(self.gt)

@@ -80,7 +93,7 @@ class RealESRNetModel(SRModel):
                 scale = 1
             mode = random.choice(['area', 'bilinear', 'bicubic'])
             out = F.interpolate(out, scale_factor=scale, mode=mode)
-            # noise
+            # add noise
             gray_noise_prob = self.opt['gray_noise_prob']
             if np.random.uniform() < self.opt['gaussian_noise_prob']:
                 out = random_add_gaussian_noise_pt(
@@ -94,7 +107,7 @@ class RealESRNetModel(SRModel):
                     rounds=False)
             # JPEG compression
             jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
-            out = torch.clamp(out, 0, 1)
+            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
             out = self.jpeger(out, quality=jpeg_p)

             # ----------------------- The second degradation process ----------------------- #
@@ -112,7 +125,7 @@ class RealESRNetModel(SRModel):
                 mode = random.choice(['area', 'bilinear', 'bicubic'])
                 out = F.interpolate(
                     out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
-            # noise
+            # add noise
             gray_noise_prob = self.opt['gray_noise_prob2']
             if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                 out = random_add_gaussian_noise_pt(
@@ -160,7 +173,9 @@ class RealESRNetModel(SRModel):

             # training pair pool
             self._dequeue_and_enqueue()
+            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
         else:
             # for paired training or validation
             self.lq = data['lq'].to(self.device)
             if 'gt' in data:
                 self.gt = data['gt'].to(self.device)
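Both feed_data pipelines above start each degradation stage by sampling up/down/keep from `resize_prob` and then drawing a scale factor from `resize_range`. A sketch of that sampling step as a hypothetical standalone function:

```python
import numpy as np

def sample_resize_scale(resize_prob, resize_range, rng=np.random):
    """resize_prob: probabilities for [up, down, keep]; resize_range: [min, max]."""
    updown = rng.choice(['up', 'down', 'keep'], p=resize_prob)
    if updown == 'up':
        return rng.uniform(1, resize_range[1])   # enlarge
    if updown == 'down':
        return rng.uniform(resize_range[0], 1)   # shrink
    return 1                                     # keep the original size

# first-stage settings from the test config: resize_prob [0.2, 0.7, 0.1], resize_range [0.15, 1.5]
scale = sample_resize_scale([0.2, 0.7, 0.1], [0.15, 1.5])
```

The sampled scale then feeds `F.interpolate(out, scale_factor=scale, mode=mode)` with a randomly chosen interpolation mode.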
@@ -4,14 +4,26 @@ import numpy as np
 import os
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from torch.hub import download_url_to_file, get_dir
+from basicsr.utils.download_util import load_file_from_url
 from torch.nn import functional as F
-from urllib.parse import urlparse

 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


 class RealESRGANer():
+    """A helper class for upsampling images with RealESRGAN.
+
+    Args:
+        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+        model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
+        model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
+        tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
+            input images into tiles, and then process each of them. Finally, they will be merged into one image.
+            0 denotes for do not use tile. Default: 0.
+        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+        half (bool): Whether to use half precision during inference. Default: False.
+    """

     def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
         self.scale = scale
@@ -26,10 +38,12 @@ class RealESRGANer():
         if model is None:
             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)

+        # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
         if model_path.startswith('https://'):
             model_path = load_file_from_url(
-                url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
+                url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
         loadnet = torch.load(model_path)
         # prefer to use params_ema
         if 'params_ema' in loadnet:
             keyname = 'params_ema'
         else:
@@ -41,6 +55,8 @@ class RealESRGANer():
             self.model = self.model.half()

     def pre_process(self, img):
+        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+        """
         img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
         self.img = img.unsqueeze(0).to(self.device)
         if self.half:
@@ -49,7 +65,7 @@ class RealESRGANer():
         # pre_pad
         if self.pre_pad != 0:
             self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
-        # mod pad
+        # mod pad for divisible borders
         if self.scale == 2:
             self.mod_scale = 2
         elif self.scale == 1:
@@ -64,10 +80,14 @@ class RealESRGANer():
             self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

     def process(self):
+        # model inference
         self.output = self.model(self.img)

     def tile_process(self):
-        """Modified from: https://github.com/ata4/esrgan-launcher
+        """It will first crop input images to tiles, and then process each tile.
+        Finally, all the processed tiles are merged into one image.
+
+        Modified from: https://github.com/ata4/esrgan-launcher
         """
         batch, channel, height, width = self.img.shape
         output_height = height * self.scale
@@ -107,7 +127,7 @@ class RealESRGANer():
                 try:
                     with torch.no_grad():
                         output_tile = self.model(input_tile)
-                except Exception as error:
+                except RuntimeError as error:
                     print('Error', error)
                     print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

@@ -188,7 +208,7 @@ class RealESRGANer():
                 output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                 output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                 output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
-            else:
+            else:  # use the cv2 resize for alpha channel
                 h, w = alpha.shape[0:2]
                 output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

@@ -210,23 +230,3 @@ class RealESRGANer():
             ), interpolation=cv2.INTER_LANCZOS4)

         return output, img_mode
-
-
-def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
-    """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
-    """
-    if model_dir is None:
-        hub_dir = get_dir()
-        model_dir = os.path.join(hub_dir, 'checkpoints')
-
-    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
-
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-    if file_name is not None:
-        filename = file_name
-    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
-    if not os.path.exists(cached_file):
-        print(f'Downloading: "{url}" to {cached_file}\n')
-        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
-    return cached_file
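The removed `load_file_from_url` helper resolved a cache path from the URL basename and downloaded only on a cache miss (the same behavior now delegated to basicsr's version). A stripped-down sketch of just the path resolution, stdlib only, with the actual `download_url_to_file` call omitted:

```python
import os
from urllib.parse import urlparse

def resolve_cached_path(url, model_dir, file_name=None):
    """Return the local cache path a model URL maps to, mirroring the
    filename logic of the deleted load_file_from_url helper."""
    filename = os.path.basename(urlparse(url).path)
    if file_name is not None:  # explicit override wins
        filename = file_name
    return os.path.abspath(os.path.join(model_dir, filename))

path = resolve_cached_path(
    'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
    'realesrgan/weights')
print(os.path.basename(path))  # RealESRGAN_x4plus.pth
```

With the `model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights')` change above, downloads now land inside the package directory regardless of the current working directory.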
@@ -14,34 +14,24 @@ def main(args):

     opt (dict): Configuration dict. It contains:
         n_thread (int): Thread number.
-        compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
-            A higher value means a smaller size and longer compression time.
-            Use 0 for faster CPU decompression. Default: 3, same in cv2.
-
+        compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
+            and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
         input_folder (str): Path to the input folder.
         save_folder (str): Path to save folder.
         crop_size (int): Crop size.
         step (int): Step for overlapped sliding window.
-        thresh_size (int): Threshold size. Patches whose size is lower
-            than thresh_size will be dropped.
+        thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.

     Usage:
         For each folder, run this script.
-        Typically, there are four folders to be processed for DIV2K dataset.
-            DIV2K_train_HR
-            DIV2K_train_LR_bicubic/X2
-            DIV2K_train_LR_bicubic/X3
-            DIV2K_train_LR_bicubic/X4
-        After process, each sub_folder should have the same number of
-        subimages.
+        Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
+        After process, each sub_folder should have the same number of subimages.
         Remember to modify opt configurations according to your settings.
     """

     opt = {}
     opt['n_thread'] = args.n_thread
     opt['compression_level'] = args.compression_level
     # HR images
     opt['input_folder'] = args.input
     opt['save_folder'] = args.output
     opt['crop_size'] = args.crop_size
@@ -68,6 +58,7 @@ def extract_subimages(opt):
         print(f'Folder {save_folder} already exists. Exit.')
         sys.exit(1)

+    # scan all images
     img_list = list(scandir(input_folder, full_path=True))

     pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
@@ -88,8 +79,7 @@ def worker(path, opt):
     opt (dict): Configuration dict. It contains:
         crop_size (int): Crop size.
         step (int): Step for overlapped sliding window.
-        thresh_size (int): Threshold size. Patches whose size is lower
-            than thresh_size will be dropped.
+        thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
         save_folder (str): Path to save folder.
         compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
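`extract_subimages` walks an overlapped sliding window (`crop_size`, `step`) over each image and uses `thresh_size` to decide whether the ragged border strip deserves one extra window snapped flush to the edge. A sketch of that per-axis coordinate generation as commonly implemented in BasicSR-style extractors (hypothetical helper; assumes `length >= crop_size`):

```python
import numpy as np

def window_starts(length, crop_size, step, thresh_size):
    """Start offsets of an overlapped sliding window along one axis.
    If the leftover border exceeds thresh_size, append one extra window
    flush with the border so that strip is not dropped."""
    starts = np.arange(0, length - crop_size + 1, step)
    if length - (starts[-1] + crop_size) > thresh_size:
        starts = np.append(starts, length - crop_size)
    return starts

# a 500-px axis with 480-px crops and step 240: one regular window plus a border window
print(window_starts(500, 480, 240, 0))
```

The worker then crops `img[x:x + crop_size, y:y + crop_size]` for every (x, y) start pair and writes each patch with the configured PNG compression level.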
@@ -11,15 +11,17 @@ def main(args):
     for img_path in img_paths:
         status = True
         if args.check:
+            # read the image once for check, as some images may have errors
             try:
                 img = cv2.imread(img_path)
-            except Exception as error:
+            except (IOError, OSError) as error:
                 print(f'Read {img_path} error: {error}')
                 status = False
             if img is None:
                 status = False
                 print(f'Img is None: {img_path}')
         if status:
+            # get the relative path
             img_name = os.path.relpath(img_path, root)
             print(img_name)
             txt_file.write(f'{img_name}\n')
@@ -5,6 +5,7 @@ import os

 def main(args):
     txt_file = open(args.meta_info, 'w')
+    # scan images
     img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
     img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))

@@ -12,6 +13,7 @@ def main(args):
         f'{len(img_paths_gt)} and {len(img_paths_lq)}.')

     for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
+        # get the relative paths
         img_name_gt = os.path.relpath(img_path_gt, args.root[0])
         img_name_lq = os.path.relpath(img_path_lq, args.root[1])
         print(f'{img_name_gt}, {img_name_lq}')
@@ -19,7 +21,7 @@ def main(args):


 if __name__ == '__main__':
-    """Generate meta info (txt file) for paired images.
+    """This script is used to generate meta info (txt file) for paired images.
     """
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -5,7 +5,6 @@ from PIL import Image


 def main(args):
-
     # For DF2K, we consider the following three scales,
     # and the smallest image whose shortest edge is 400
     scale_list = [0.75, 0.5, 1 / 3]
@@ -37,6 +36,9 @@ def main(args):


 if __name__ == '__main__':
+    """Generate multi-scale versions for GT images with LANCZOS resampling.
+    It is now used for DF2K dataset (DIV2K + Flickr 2K)
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
     parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')
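Per the comments above, the script shrinks each GT image by the three listed scales and also produces a smallest version whose shortest edge is 400 px. A quick illustrative check of the resulting sizes for a 2040×1356 DIV2K-sized image (hypothetical helper; the actual script uses PIL's LANCZOS resize, whose integer rounding may differ slightly):

```python
scale_list = [0.75, 0.5, 1 / 3]
shortest_edge = 400

def multiscale_sizes(width, height):
    sizes = [(round(width * s), round(height * s)) for s in scale_list]
    # one extra version whose shortest edge is exactly 400 px
    ratio = shortest_edge / min(width, height)
    sizes.append((round(width * ratio), round(height * ratio)))
    return sizes

print(multiscale_sizes(2040, 1356))
```

Each original image therefore contributes four extra training-resolution variants to the multiscale GT pool.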
@@ -1,17 +1,36 @@
 import argparse
 import torch
 import torch.onnx
 from basicsr.archs.rrdbnet_arch import RRDBNet

-# An instance of your model
+
+def main(args):
+    # An instance of the model
     model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
-model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
+    if args.params:
+        keyname = 'params'
+    else:
+        keyname = 'params_ema'
+    model.load_state_dict(torch.load(args.input)[keyname])
     # set the train mode to false since we will only run the forward pass.
     model.train(False)
     model.cpu().eval()

-# An example input you would normally provide to your model's forward() method
+    # An example input
     x = torch.rand(1, 3, 64, 64)

     # Export the model
+    with torch.no_grad():
-        torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
+        torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
     print(torch_out.shape)


 if __name__ == '__main__':
     """Convert pytorch model to onnx models"""
     parser = argparse.ArgumentParser()
     parser.add_argument(
         '--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
     parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
     parser.add_argument('--params', action='store_true', help='Use params instead of params_ema')
     args = parser.parse_args()

     main(args)
@@ -17,7 +17,7 @@ line_length = 120
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = realesrgan
-known_third_party = PIL,basicsr,cv2,numpy,torch,torchvision,tqdm
+known_third_party = PIL,basicsr,cv2,numpy,pytest,torch,torchvision,tqdm,yaml
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

@@ -25,3 +25,9 @@ default_section = THIRDPARTY
 skip = .git,./docs/build
 count =
 quiet-level = 3
+
+[aliases]
+test=pytest
+
+[tool:pytest]
+addopts=tests/
BIN  tests/data/gt.lmdb/data.mdb (new file, binary file not shown)
BIN  tests/data/gt.lmdb/lock.mdb (new file, binary file not shown)

tests/data/gt.lmdb/meta_info.txt (new file)
@@ -0,0 +1,2 @@
baboon.png (480,500,3) 1
comic.png (360,240,3) 1

BIN  tests/data/gt/baboon.png (new file, 532 KiB)
BIN  tests/data/gt/comic.png (new file, 195 KiB)
BIN  tests/data/lq.lmdb/data.mdb (new file, binary file not shown)
BIN  tests/data/lq.lmdb/lock.mdb (new file, binary file not shown)

tests/data/lq.lmdb/meta_info.txt (new file)
@@ -0,0 +1,2 @@
baboon.png (120,125,3) 1
comic.png (80,60,3) 1

BIN  tests/data/lq/baboon.png (new file, 35 KiB)
BIN  tests/data/lq/comic.png (new file, 14 KiB)

tests/data/meta_info_gt.txt (new file)
@@ -0,0 +1,2 @@
baboon.png
comic.png

tests/data/meta_info_pair.txt (new file)
@@ -0,0 +1,2 @@
gt/baboon.png, lq/baboon.png
gt/comic.png, lq/comic.png
tests/data/test_realesrgan_dataset.yml (new file)
@@ -0,0 +1,28 @@
name: Demo
type: RealESRGANDataset
dataroot_gt: tests/data/gt
meta_info: tests/data/meta_info_gt.txt
io_backend:
  type: disk

blur_kernel_size: 21
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob: 1
blur_sigma: [0.2, 3]
betag_range: [0.5, 4]
betap_range: [1, 2]

blur_kernel_size2: 21
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob2: 1
blur_sigma2: [0.2, 1.5]
betag_range2: [0.5, 4]
betap_range2: [1, 2]

final_sinc_prob: 1

gt_size: 128
use_hflip: True
use_rot: False
tests/data/test_realesrgan_model.yml (new file)
@@ -0,0 +1,115 @@
scale: 4
num_gpu: 1
manual_seed: 0
is_train: True
dist: False

# ----------------- options for synthesizing training data ----------------- #
# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False

# the first degradation process
resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 1
noise_range: [1, 30]
poisson_scale_range: [0.05, 3]
gray_noise_prob: 1
jpeg_range: [30, 95]

# the second degradation process
second_blur_prob: 1
resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 1
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 1
jpeg_range2: [30, 95]

gt_size: 32
queue_size: 1

# network structures
network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 4
  num_block: 1
  num_grow_ch: 2

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3
  num_feat: 2
  skip_connection: True

# path
path:
  pretrain_network_g: ~
  param_key_g: params_ema
  strict_load_g: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: MultiStepLR
    milestones: [400000]
    gamma: 0.5

  total_iter: 400000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1.0
    style_weight: 0
    range_norm: false
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1e-1

  net_d_iters: 1
  net_d_init_iters: 0


# validation settings
val:
  val_freq: !!float 5e3
  save_img: False
tests/data/test_realesrgan_paired_dataset.yml (new file)
@@ -0,0 +1,13 @@
name: Demo
type: RealESRGANPairedDataset
scale: 4
dataroot_gt: tests/data
dataroot_lq: tests/data
meta_info: tests/data/meta_info_pair.txt
io_backend:
  type: disk

phase: train
gt_size: 128
use_hflip: True
use_rot: False
75
tests/data/test_realesrnet_model.yml
Normal file
75
tests/data/test_realesrnet_model.yml
Normal file
@@ -0,0 +1,75 @@
    scale: 4
    num_gpu: 1
    manual_seed: 0
    is_train: True
    dist: False

    # ----------------- options for synthesizing training data ----------------- #
    gt_usm: True  # USM the ground-truth

    # the first degradation process
    resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
    resize_range: [0.15, 1.5]
    gaussian_noise_prob: 1
    noise_range: [1, 30]
    poisson_scale_range: [0.05, 3]
    gray_noise_prob: 1
    jpeg_range: [30, 95]

    # the second degradation process
    second_blur_prob: 1
    resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
    resize_range2: [0.3, 1.2]
    gaussian_noise_prob2: 1
    noise_range2: [1, 25]
    poisson_scale_range2: [0.05, 2.5]
    gray_noise_prob2: 1
    jpeg_range2: [30, 95]

    gt_size: 32
    queue_size: 1

    # network structures
    network_g:
      type: RRDBNet
      num_in_ch: 3
      num_out_ch: 3
      num_feat: 4
      num_block: 1
      num_grow_ch: 2

    # path
    path:
      pretrain_network_g: ~
      param_key_g: params_ema
      strict_load_g: true
      resume_state: ~

    # training settings
    train:
      ema_decay: 0.999
      optim_g:
        type: Adam
        lr: !!float 2e-4
        weight_decay: 0
        betas: [0.9, 0.99]

      scheduler:
        type: MultiStepLR
        milestones: [1000000]
        gamma: 0.5

      total_iter: 1000000
      warmup_iter: -1  # no warm up

      # losses
      pixel_opt:
        type: L1Loss
        loss_weight: 1.0
        reduction: mean

    # validation settings
    val:
      val_freq: !!float 5e3
      save_img: False
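A side note on the `!!float` tags in the config above: in PyYAML's YAML 1.1 schema, a plain scalar like `2e-4` (no decimal point before the exponent) does not match the float resolver and loads as a string, so the explicit tag is what keeps learning rates and frequencies numeric. A minimal sketch of the behavior (the inline snippet is illustrative, not the full config):

```python
import yaml

# Illustrative fragment mirroring the config above; `plain_lr` is a
# hypothetical key added only to show the untagged case.
snippet = """
train:
  optim_g:
    lr: !!float 2e-4
  plain_lr: 2e-4
val:
  val_freq: !!float 5e3
"""
opt = yaml.safe_load(snippet)
assert isinstance(opt['train']['optim_g']['lr'], float)  # !!float forces a float
assert opt['train']['optim_g']['lr'] == 2e-4
assert isinstance(opt['train']['plain_lr'], str)         # untagged "2e-4" stays a string
assert opt['val']['val_freq'] == 5000.0
```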
tests/test_dataset.py (new file, 151 lines)
@@ -0,0 +1,151 @@
    import pytest
    import yaml

    from realesrgan.data.realesrgan_dataset import RealESRGANDataset
    from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset


    def test_realesrgan_dataset():

        with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f:
            opt = yaml.load(f, Loader=yaml.FullLoader)

        dataset = RealESRGANDataset(opt)
        assert dataset.io_backend_opt['type'] == 'disk'  # io backend
        assert len(dataset) == 2  # whether the correct meta info is read
        assert dataset.kernel_list == [
            'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
        ]  # correct initialization of the degradation configurations
        assert dataset.betag_range2 == [0.5, 4]

        # test __getitem__
        result = dataset.__getitem__(0)
        # check returned keys
        expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
        assert set(expected_keys).issubset(set(result.keys()))
        # check shape and contents
        assert result['gt'].shape == (3, 400, 400)
        assert result['kernel1'].shape == (21, 21)
        assert result['kernel2'].shape == (21, 21)
        assert result['sinc_kernel'].shape == (21, 21)
        assert result['gt_path'] == 'tests/data/gt/baboon.png'

        # ------------------ test lmdb backend -------------------- #
        opt['dataroot_gt'] = 'tests/data/gt.lmdb'
        opt['io_backend']['type'] = 'lmdb'

        dataset = RealESRGANDataset(opt)
        assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
        assert len(dataset.paths) == 2  # whether the correct meta info is read
        assert dataset.kernel_list == [
            'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
        ]  # correct initialization of the degradation configurations
        assert dataset.betag_range2 == [0.5, 4]

        # test __getitem__
        result = dataset.__getitem__(1)
        # check returned keys
        expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
        assert set(expected_keys).issubset(set(result.keys()))
        # check shape and contents
        assert result['gt'].shape == (3, 400, 400)
        assert result['kernel1'].shape == (21, 21)
        assert result['kernel2'].shape == (21, 21)
        assert result['sinc_kernel'].shape == (21, 21)
        assert result['gt_path'] == 'comic'

        # ------------------ test with sinc_prob = 0 -------------------- #
        opt['dataroot_gt'] = 'tests/data/gt.lmdb'
        opt['io_backend']['type'] = 'lmdb'
        opt['sinc_prob'] = 0
        opt['sinc_prob2'] = 0
        opt['final_sinc_prob'] = 0
        dataset = RealESRGANDataset(opt)
        result = dataset.__getitem__(0)
        # check returned keys
        expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
        assert set(expected_keys).issubset(set(result.keys()))
        # check shape and contents
        assert result['gt'].shape == (3, 400, 400)
        assert result['kernel1'].shape == (21, 21)
        assert result['kernel2'].shape == (21, 21)
        assert result['sinc_kernel'].shape == (21, 21)
        assert result['gt_path'] == 'baboon'

        # ---------- the lmdb backend requires paths ending with lmdb ---------- #
        with pytest.raises(ValueError):
            opt['dataroot_gt'] = 'tests/data/gt'
            opt['io_backend']['type'] = 'lmdb'
            dataset = RealESRGANDataset(opt)


    def test_realesrgan_paired_dataset():

        with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r') as f:
            opt = yaml.load(f, Loader=yaml.FullLoader)

        dataset = RealESRGANPairedDataset(opt)
        assert dataset.io_backend_opt['type'] == 'disk'  # io backend
        assert len(dataset) == 2  # whether the correct meta info is read

        # test __getitem__
        result = dataset.__getitem__(0)
        # check returned keys
        expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
        assert set(expected_keys).issubset(set(result.keys()))
        # check shape and contents
        assert result['gt'].shape == (3, 128, 128)
        assert result['lq'].shape == (3, 32, 32)
        assert result['gt_path'] == 'tests/data/gt/baboon.png'
        assert result['lq_path'] == 'tests/data/lq/baboon.png'

        # ------------------ test lmdb backend -------------------- #
        opt['dataroot_gt'] = 'tests/data/gt.lmdb'
        opt['dataroot_lq'] = 'tests/data/lq.lmdb'
        opt['io_backend']['type'] = 'lmdb'

        dataset = RealESRGANPairedDataset(opt)
        assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
        assert len(dataset) == 2  # whether the correct meta info is read

        # test __getitem__
        result = dataset.__getitem__(1)
        # check returned keys
        expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
        assert set(expected_keys).issubset(set(result.keys()))
        # check shape and contents
        assert result['gt'].shape == (3, 128, 128)
        assert result['lq'].shape == (3, 32, 32)
        assert result['gt_path'] == 'comic'
        assert result['lq_path'] == 'comic'

        # ------------------ test paired_paths_from_folder -------------------- #
        opt['dataroot_gt'] = 'tests/data/gt'
        opt['dataroot_lq'] = 'tests/data/lq'
        opt['io_backend'] = dict(type='disk')
        opt['meta_info'] = None

        dataset = RealESRGANPairedDataset(opt)
        assert dataset.io_backend_opt['type'] == 'disk'  # io backend
        assert len(dataset) == 2  # whether the correct meta info is read

        # test __getitem__
        result = dataset.__getitem__(0)
        # check returned keys
        expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
        assert set(expected_keys).issubset(set(result.keys()))
        # check shape and contents
        assert result['gt'].shape == (3, 128, 128)
        assert result['lq'].shape == (3, 32, 32)

        # ------------------ test normalization -------------------- #
        dataset.mean = [0.5, 0.5, 0.5]
        dataset.std = [0.5, 0.5, 0.5]
        # test __getitem__
        result = dataset.__getitem__(0)
        # check returned keys
        expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
        assert set(expected_keys).issubset(set(result.keys()))
        # check shape and contents
        assert result['gt'].shape == (3, 128, 128)
        assert result['lq'].shape == (3, 32, 32)
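The `(3, 400, 400)` and `(3, 128, 128)` shapes asserted in these dataset tests are channel-first (C, H, W) tensors; BasicSR's loaders convert images from OpenCV's channel-last (H, W, C) layout. A rough NumPy sketch of that axis reordering (an illustration only, not the project's own conversion helper):

```python
import numpy as np

# An HWC image as OpenCV would load it ...
img_hwc = np.random.random((400, 400, 3)).astype(np.float32)
# ... transposed into the CHW layout the dataset asserts on.
img_chw = np.transpose(img_hwc, (2, 0, 1))
assert img_chw.shape == (3, 400, 400)
# The transpose only reorders axes; channel planes are untouched.
assert np.array_equal(img_chw[1], img_hwc[:, :, 1])
```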
tests/test_discriminator_arch.py (new file, 19 lines)
@@ -0,0 +1,19 @@
    import torch

    from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN


    def test_unetdiscriminatorsn():
        """Test arch: UNetDiscriminatorSN."""

        # model init and forward (cpu)
        net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
        output = net(img)
        assert output.shape == (1, 1, 32, 32)

        # model init and forward (gpu)
        if torch.cuda.is_available():
            net.cuda()
            output = net(img.cuda())
            assert output.shape == (1, 1, 32, 32)
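For context on the "SN" suffix tested above: spectral normalization rescales a weight matrix by its largest singular value so the discriminator stays approximately 1-Lipschitz. A standalone NumPy sketch of the idea via power iteration (an illustration of the technique, not the `torch.nn.utils.spectral_norm` machinery the arch presumably wraps):

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((16, 8))   # a stand-in weight matrix

# Power iteration: u and v converge to the top singular vectors of w.
u = rng.standard_normal(16)
for _ in range(100):
    v = w.T @ u
    v /= np.linalg.norm(v)
    u = w @ v
    u /= np.linalg.norm(u)
sigma = u @ w @ v                  # estimate of the largest singular value

w_sn = w / sigma                   # normalized weight has spectral norm ~ 1
assert abs(np.linalg.norm(w_sn, 2) - 1.0) < 1e-4
```

In practice the real implementation keeps `u` as a persistent buffer and runs a single iteration per forward pass, which is cheap and accurate enough during training.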
tests/test_model.py (new file, 126 lines)
@@ -0,0 +1,126 @@
    import torch
    import yaml
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.data.paired_image_dataset import PairedImageDataset
    from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss

    from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
    from realesrgan.models.realesrgan_model import RealESRGANModel
    from realesrgan.models.realesrnet_model import RealESRNetModel


    def test_realesrnet_model():
        with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
            opt = yaml.load(f, Loader=yaml.FullLoader)

        # build model
        model = RealESRNetModel(opt)
        # test attributes
        assert model.__class__.__name__ == 'RealESRNetModel'
        assert isinstance(model.net_g, RRDBNet)
        assert isinstance(model.cri_pix, L1Loss)
        assert isinstance(model.optimizers[0], torch.optim.Adam)

        # prepare data
        gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
        kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
        kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
        sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
        data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
        model.feed_data(data)
        # check dequeue
        model.feed_data(data)
        # check data shape
        assert model.lq.shape == (1, 3, 8, 8)
        assert model.gt.shape == (1, 3, 32, 32)

        # change probability to test if-else branches
        model.opt['gaussian_noise_prob'] = 0
        model.opt['gray_noise_prob'] = 0
        model.opt['second_blur_prob'] = 0
        model.opt['gaussian_noise_prob2'] = 0
        model.opt['gray_noise_prob2'] = 0
        model.feed_data(data)
        # check data shape
        assert model.lq.shape == (1, 3, 8, 8)
        assert model.gt.shape == (1, 3, 32, 32)

        # ----------------- test nondist_validation -------------------- #
        # construct dataloader
        dataset_opt = dict(
            name='Demo',
            dataroot_gt='tests/data/gt',
            dataroot_lq='tests/data/lq',
            io_backend=dict(type='disk'),
            scale=4,
            phase='val')
        dataset = PairedImageDataset(dataset_opt)
        dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
        assert model.is_train is True
        model.nondist_validation(dataloader, 1, None, False)
        assert model.is_train is True


    def test_realesrgan_model():
        with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
            opt = yaml.load(f, Loader=yaml.FullLoader)

        # build model
        model = RealESRGANModel(opt)
        # test attributes
        assert model.__class__.__name__ == 'RealESRGANModel'
        assert isinstance(model.net_g, RRDBNet)  # generator
        assert isinstance(model.net_d, UNetDiscriminatorSN)  # discriminator
        assert isinstance(model.cri_pix, L1Loss)
        assert isinstance(model.cri_perceptual, PerceptualLoss)
        assert isinstance(model.cri_gan, GANLoss)
        assert isinstance(model.optimizers[0], torch.optim.Adam)
        assert isinstance(model.optimizers[1], torch.optim.Adam)

        # prepare data
        gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
        kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
        kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
        sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
        data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
        model.feed_data(data)
        # check dequeue
        model.feed_data(data)
        # check data shape
        assert model.lq.shape == (1, 3, 8, 8)
        assert model.gt.shape == (1, 3, 32, 32)

        # change probability to test if-else branches
        model.opt['gaussian_noise_prob'] = 0
        model.opt['gray_noise_prob'] = 0
        model.opt['second_blur_prob'] = 0
        model.opt['gaussian_noise_prob2'] = 0
        model.opt['gray_noise_prob2'] = 0
        model.feed_data(data)
        # check data shape
        assert model.lq.shape == (1, 3, 8, 8)
        assert model.gt.shape == (1, 3, 32, 32)

        # ----------------- test nondist_validation -------------------- #
        # construct dataloader
        dataset_opt = dict(
            name='Demo',
            dataroot_gt='tests/data/gt',
            dataroot_lq='tests/data/lq',
            io_backend=dict(type='disk'),
            scale=4,
            phase='val')
        dataset = PairedImageDataset(dataset_opt)
        dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
        assert model.is_train is True
        model.nondist_validation(dataloader, 1, None, False)
        assert model.is_train is True

        # ----------------- test optimize_parameters -------------------- #
        model.feed_data(data)
        model.optimize_parameters(1)
        assert model.output.shape == (1, 3, 32, 32)
        assert isinstance(model.log_dict, dict)
        # check returned keys
        expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
        assert set(expected_keys).issubset(set(model.log_dict.keys()))
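The doubled `feed_data(data)` calls above ("check dequeue") exercise the training-pair pool: synthesized (lq, gt) pairs enter a fixed-size queue (`queue_size: 1` in the test configs) and are drawn back out to increase batch diversity across iterations. A toy sketch of such a pool (a hypothetical simplification, not the model's actual dequeue/enqueue code):

```python
import random

class PairPool:
    """Toy pool: keep up to queue_size pairs, emit a randomly drawn one."""

    def __init__(self, queue_size=1):
        self.queue_size = queue_size
        self.pairs = []

    def feed(self, lq, gt):
        self.pairs.append((lq, gt))
        if len(self.pairs) > self.queue_size:
            self.pairs.pop(0)              # dequeue the oldest pair
        return random.choice(self.pairs)   # train on a randomly drawn pair

pool = PairPool(queue_size=1)
assert pool.feed('lq0', 'gt0') == ('lq0', 'gt0')
# with queue_size 1, the second feed evicts the first pair
assert pool.feed('lq1', 'gt1') == ('lq1', 'gt1')
```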
tests/test_utils.py (new file, 87 lines)
@@ -0,0 +1,87 @@
    import numpy as np
    from basicsr.archs.rrdbnet_arch import RRDBNet

    from realesrgan.utils import RealESRGANer


    def test_realesrganer():
        # initialize with default model
        restorer = RealESRGANer(
            scale=4,
            model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
            model=None,
            tile=10,
            tile_pad=10,
            pre_pad=2,
            half=False)
        assert isinstance(restorer.model, RRDBNet)
        assert restorer.half is False
        # initialize with user-defined model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        restorer = RealESRGANer(
            scale=4,
            model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth',
            model=model,
            tile=10,
            tile_pad=10,
            pre_pad=2,
            half=True)
        # test attribute
        assert isinstance(restorer.model, RRDBNet)
        assert restorer.half is True

        # ------------------ test pre_process ---------------- #
        img = np.random.random((12, 12, 3)).astype(np.float32)
        restorer.pre_process(img)
        assert restorer.img.shape == (1, 3, 14, 14)
        # with modcrop
        restorer.scale = 1
        restorer.pre_process(img)
        assert restorer.img.shape == (1, 3, 16, 16)

        # ------------------ test process ---------------- #
        restorer.process()
        assert restorer.output.shape == (1, 3, 64, 64)

        # ------------------ test post_process ---------------- #
        restorer.mod_scale = 4
        output = restorer.post_process()
        assert output.shape == (1, 3, 60, 60)

        # ------------------ test tile_process ---------------- #
        restorer.scale = 4
        img = np.random.random((12, 12, 3)).astype(np.float32)
        restorer.pre_process(img)
        restorer.tile_process()
        assert restorer.output.shape == (1, 3, 64, 64)

        # ------------------ test enhance ---------------- #
        img = np.random.random((12, 12, 3)).astype(np.float32)
        result = restorer.enhance(img, outscale=2)
        assert result[0].shape == (24, 24, 3)
        assert result[1] == 'RGB'

        # ------------------ test enhance with 16-bit image ---------------- #
        img = np.random.random((4, 4, 3)).astype(np.uint16) + 512
        result = restorer.enhance(img, outscale=2)
        assert result[0].shape == (8, 8, 3)
        assert result[1] == 'RGB'

        # ------------------ test enhance with gray image ---------------- #
        img = np.random.random((4, 4)).astype(np.float32)
        result = restorer.enhance(img, outscale=2)
        assert result[0].shape == (8, 8)
        assert result[1] == 'L'

        # ------------------ test enhance with RGBA ---------------- #
        img = np.random.random((4, 4, 4)).astype(np.float32)
        result = restorer.enhance(img, outscale=2)
        assert result[0].shape == (8, 8, 4)
        assert result[1] == 'RGBA'

        # ------------------ test enhance with RGBA, alpha_upsampler ---------------- #
        restorer.tile_size = 0
        img = np.random.random((4, 4, 4)).astype(np.float32)
        result = restorer.enhance(img, outscale=2, alpha_upsampler=None)
        assert result[0].shape == (8, 8, 4)
        assert result[1] == 'RGBA'
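The shapes asserted around `pre_process` above follow simple padding arithmetic: one border is reflect-padded by `pre_pad`, and for small scales the result is additionally padded up to a multiple of a modulo factor. A pure-Python sketch of that per-dimension bookkeeping (a mirror of the asserted shapes under those assumptions, not the actual `RealESRGANer` code):

```python
def padded_size(size, pre_pad, mod_scale=None):
    """Per-dimension size after the padding steps implied by the asserts above."""
    size += pre_pad  # reflect padding on one border
    if mod_scale is not None and size % mod_scale != 0:
        size += mod_scale - size % mod_scale  # round up to a multiple of mod_scale
    return size

assert padded_size(12, pre_pad=2) == 14               # matches (1, 3, 14, 14)
assert padded_size(12, pre_pad=2, mod_scale=4) == 16  # matches (1, 3, 16, 16) with modcrop
```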