30
.github/workflows/publish-pip.yml
vendored
Normal file
30
.github/workflows/publish-pip.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
name: PyPI Publish
|
||||||
|
|
||||||
|
on: push
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-n-publish:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: startsWith(github.event.ref, 'refs/tags')
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python 3.8
|
||||||
|
uses: actions/setup-python@v1
|
||||||
|
with:
|
||||||
|
python-version: 3.8
|
||||||
|
- name: Upgrade pip
|
||||||
|
run: pip install pip --upgrade
|
||||||
|
- name: Install PyTorch (cpu)
|
||||||
|
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -r requirements.txt
|
||||||
|
- name: Build and install
|
||||||
|
run: rm -rf .eggs && pip install -e .
|
||||||
|
- name: Build for distribution
|
||||||
|
# remove bdist_wheel for pip installation with compiling cuda extensions
|
||||||
|
run: python setup.py sdist
|
||||||
|
- name: Publish distribution to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@master
|
||||||
|
with:
|
||||||
|
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
4
.github/workflows/pylint.yml
vendored
4
.github/workflows/pylint.yml
vendored
@@ -26,5 +26,5 @@ jobs:
|
|||||||
- name: Lint
|
- name: Lint
|
||||||
run: |
|
run: |
|
||||||
flake8 .
|
flake8 .
|
||||||
isort --check-only --diff data/ models/ inference_realesrgan.py
|
isort --check-only --diff realesrgan/ scripts/ inference_realesrgan.py setup.py
|
||||||
yapf -r -d data/ models/ inference_realesrgan.py
|
yapf -r -d realesrgan/ scripts/ inference_realesrgan.py setup.py
|
||||||
|
|||||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -1,3 +1,12 @@
|
|||||||
|
# ignored folders
|
||||||
|
datasets/*
|
||||||
|
experiments/*
|
||||||
|
results/*
|
||||||
|
tb_logger/*
|
||||||
|
wandb/*
|
||||||
|
tmp/*
|
||||||
|
|
||||||
|
version.py
|
||||||
.vscode
|
.vscode
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
|
|||||||
8
MANIFEST.in
Normal file
8
MANIFEST.in
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
include assets/*
|
||||||
|
include inputs/*
|
||||||
|
include scripts/*.py
|
||||||
|
include inference_realesrgan.py
|
||||||
|
include VERSION
|
||||||
|
include LICENSE
|
||||||
|
include requirements.txt
|
||||||
|
include realesrgan/weights/README.md
|
||||||
@@ -97,6 +97,7 @@ This executable file is based on the wonderful [Tencent/ncnn](https://github.com
|
|||||||
# We use BasicSR for both training and inference
|
# We use BasicSR for both training and inference
|
||||||
pip install basicsr
|
pip install basicsr
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
python setup.py develop
|
||||||
```
|
```
|
||||||
|
|
||||||
## :zap: Quick Inference
|
## :zap: Quick Inference
|
||||||
|
|||||||
10
Training.md
10
Training.md
@@ -44,7 +44,7 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K # modify to the root path of your folder
|
dataroot_gt: datasets/DF2K # modify to the root path of your folder
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
|
meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
```
|
```
|
||||||
@@ -76,12 +76,12 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
|
||||||
```
|
```
|
||||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
|
||||||
```
|
```
|
||||||
|
|
||||||
## Train Real-ESRGAN
|
## Train Real-ESRGAN
|
||||||
@@ -91,10 +91,10 @@ DF2K_HR_sub/000001_s003.png
|
|||||||
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
|
||||||
```
|
```
|
||||||
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
|
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import cv2
|
import cv2
|
||||||
import glob
|
import glob
|
||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from realesrgan import RealESRGANer
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -53,61 +50,8 @@ def main():
|
|||||||
imgname, extension = os.path.splitext(os.path.basename(path))
|
imgname, extension = os.path.splitext(os.path.basename(path))
|
||||||
print('Testing', idx, imgname)
|
print('Testing', idx, imgname)
|
||||||
|
|
||||||
# ------------------------------ read image ------------------------------ #
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
|
output, img_mode = upsampler.enhance(img)
|
||||||
if np.max(img) > 255: # 16-bit image
|
|
||||||
max_range = 65535
|
|
||||||
print('\tInput is a 16-bit image')
|
|
||||||
else:
|
|
||||||
max_range = 255
|
|
||||||
img = img / max_range
|
|
||||||
if len(img.shape) == 2: # gray image
|
|
||||||
img_mode = 'L'
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
|
||||||
elif img.shape[2] == 4: # RGBA image with alpha channel
|
|
||||||
img_mode = 'RGBA'
|
|
||||||
alpha = img[:, :, 3]
|
|
||||||
img = img[:, :, 0:3]
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
||||||
if args.alpha_upsampler == 'realesrgan':
|
|
||||||
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
|
||||||
else:
|
|
||||||
img_mode = 'RGB'
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
||||||
|
|
||||||
# ------------------- process image (without the alpha channel) ------------------- #
|
|
||||||
upsampler.pre_process(img)
|
|
||||||
if args.tile:
|
|
||||||
upsampler.tile_process()
|
|
||||||
else:
|
|
||||||
upsampler.process()
|
|
||||||
output_img = upsampler.post_process()
|
|
||||||
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
|
||||||
if img_mode == 'L':
|
|
||||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
# ------------------- process the alpha channel if necessary ------------------- #
|
|
||||||
if img_mode == 'RGBA':
|
|
||||||
if args.alpha_upsampler == 'realesrgan':
|
|
||||||
upsampler.pre_process(alpha)
|
|
||||||
if args.tile:
|
|
||||||
upsampler.tile_process()
|
|
||||||
else:
|
|
||||||
upsampler.process()
|
|
||||||
output_alpha = upsampler.post_process()
|
|
||||||
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
|
||||||
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
|
||||||
else:
|
|
||||||
h, w = alpha.shape[0:2]
|
|
||||||
output_alpha = cv2.resize(alpha, (w * args.scale, h * args.scale), interpolation=cv2.INTER_LINEAR)
|
|
||||||
|
|
||||||
# merge the alpha channel
|
|
||||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
|
||||||
output_img[:, :, 3] = output_alpha
|
|
||||||
|
|
||||||
# ------------------------------ save image ------------------------------ #
|
|
||||||
if args.ext == 'auto':
|
if args.ext == 'auto':
|
||||||
extension = extension[1:]
|
extension = extension[1:]
|
||||||
else:
|
else:
|
||||||
@@ -115,141 +59,8 @@ def main():
|
|||||||
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||||
extension = 'png'
|
extension = 'png'
|
||||||
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
|
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
|
||||||
if max_range == 65535: # 16-bit image
|
|
||||||
output = (output_img * 65535.0).round().astype(np.uint16)
|
|
||||||
else:
|
|
||||||
output = (output_img * 255.0).round().astype(np.uint8)
|
|
||||||
cv2.imwrite(save_path, output)
|
cv2.imwrite(save_path, output)
|
||||||
|
|
||||||
|
|
||||||
class RealESRGANer():
|
|
||||||
|
|
||||||
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
|
|
||||||
self.scale = scale
|
|
||||||
self.tile_size = tile
|
|
||||||
self.tile_pad = tile_pad
|
|
||||||
self.pre_pad = pre_pad
|
|
||||||
self.mod_scale = None
|
|
||||||
self.half = half
|
|
||||||
|
|
||||||
# initialize model
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
|
||||||
loadnet = torch.load(model_path)
|
|
||||||
if 'params_ema' in loadnet:
|
|
||||||
keyname = 'params_ema'
|
|
||||||
else:
|
|
||||||
keyname = 'params'
|
|
||||||
model.load_state_dict(loadnet[keyname], strict=True)
|
|
||||||
model.eval()
|
|
||||||
self.model = model.to(self.device)
|
|
||||||
if self.half:
|
|
||||||
self.model = self.model.half()
|
|
||||||
|
|
||||||
def pre_process(self, img):
|
|
||||||
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
|
||||||
self.img = img.unsqueeze(0).to(self.device)
|
|
||||||
if self.half:
|
|
||||||
self.img = self.img.half()
|
|
||||||
|
|
||||||
# pre_pad
|
|
||||||
if self.pre_pad != 0:
|
|
||||||
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
|
||||||
# mod pad
|
|
||||||
if self.scale == 2:
|
|
||||||
self.mod_scale = 2
|
|
||||||
elif self.scale == 1:
|
|
||||||
self.mod_scale = 4
|
|
||||||
if self.mod_scale is not None:
|
|
||||||
self.mod_pad_h, self.mod_pad_w = 0, 0
|
|
||||||
_, _, h, w = self.img.size()
|
|
||||||
if (h % self.mod_scale != 0):
|
|
||||||
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
|
|
||||||
if (w % self.mod_scale != 0):
|
|
||||||
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
|
|
||||||
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
|
||||||
|
|
||||||
def process(self):
|
|
||||||
try:
|
|
||||||
# inference
|
|
||||||
with torch.no_grad():
|
|
||||||
self.output = self.model(self.img)
|
|
||||||
except Exception as error:
|
|
||||||
print('Error', error)
|
|
||||||
|
|
||||||
def tile_process(self):
|
|
||||||
"""Modified from: https://github.com/ata4/esrgan-launcher
|
|
||||||
"""
|
|
||||||
batch, channel, height, width = self.img.shape
|
|
||||||
output_height = height * self.scale
|
|
||||||
output_width = width * self.scale
|
|
||||||
output_shape = (batch, channel, output_height, output_width)
|
|
||||||
|
|
||||||
# start with black image
|
|
||||||
self.output = self.img.new_zeros(output_shape)
|
|
||||||
tiles_x = math.ceil(width / self.tile_size)
|
|
||||||
tiles_y = math.ceil(height / self.tile_size)
|
|
||||||
|
|
||||||
# loop over all tiles
|
|
||||||
for y in range(tiles_y):
|
|
||||||
for x in range(tiles_x):
|
|
||||||
# extract tile from input image
|
|
||||||
ofs_x = x * self.tile_size
|
|
||||||
ofs_y = y * self.tile_size
|
|
||||||
# input tile area on total image
|
|
||||||
input_start_x = ofs_x
|
|
||||||
input_end_x = min(ofs_x + self.tile_size, width)
|
|
||||||
input_start_y = ofs_y
|
|
||||||
input_end_y = min(ofs_y + self.tile_size, height)
|
|
||||||
|
|
||||||
# input tile area on total image with padding
|
|
||||||
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
|
||||||
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
|
||||||
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
|
||||||
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
|
||||||
|
|
||||||
# input tile dimensions
|
|
||||||
input_tile_width = input_end_x - input_start_x
|
|
||||||
input_tile_height = input_end_y - input_start_y
|
|
||||||
tile_idx = y * tiles_x + x + 1
|
|
||||||
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
|
|
||||||
|
|
||||||
# upscale tile
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
output_tile = self.model(input_tile)
|
|
||||||
except Exception as error:
|
|
||||||
print('Error', error)
|
|
||||||
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
|
|
||||||
|
|
||||||
# output tile area on total image
|
|
||||||
output_start_x = input_start_x * self.scale
|
|
||||||
output_end_x = input_end_x * self.scale
|
|
||||||
output_start_y = input_start_y * self.scale
|
|
||||||
output_end_y = input_end_y * self.scale
|
|
||||||
|
|
||||||
# output tile area without padding
|
|
||||||
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
|
||||||
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
|
||||||
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
|
||||||
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
|
||||||
|
|
||||||
# put tile into output image
|
|
||||||
self.output[:, :, output_start_y:output_end_y,
|
|
||||||
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
|
|
||||||
output_start_x_tile:output_end_x_tile]
|
|
||||||
|
|
||||||
def post_process(self):
|
|
||||||
# remove extra pad
|
|
||||||
if self.mod_scale is not None:
|
|
||||||
_, _, h, w = self.output.size()
|
|
||||||
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
|
|
||||||
# remove prepad
|
|
||||||
if self.pre_pad != 0:
|
|
||||||
_, _, h, w = self.output.size()
|
|
||||||
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
|
|
||||||
return self.output
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ datasets:
|
|||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K
|
dataroot_gt: datasets/DF2K
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ network_d:
|
|||||||
# path
|
# path
|
||||||
path:
|
path:
|
||||||
# use the pre-trained Real-ESRNet model
|
# use the pre-trained Real-ESRNet model
|
||||||
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth
|
pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/models/net_g_1000000.pth
|
||||||
param_key_g: params_ema
|
param_key_g: params_ema
|
||||||
strict_load_g: true
|
strict_load_g: true
|
||||||
resume_state: ~
|
resume_state: ~
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ datasets:
|
|||||||
name: DF2K+OST
|
name: DF2K+OST
|
||||||
type: RealESRGANDataset
|
type: RealESRGANDataset
|
||||||
dataroot_gt: datasets/DF2K
|
dataroot_gt: datasets/DF2K
|
||||||
meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
||||||
io_backend:
|
io_backend:
|
||||||
type: disk
|
type: disk
|
||||||
|
|
||||||
|
|||||||
6
realesrgan/__init__.py
Normal file
6
realesrgan/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
from .archs import *
|
||||||
|
from .data import *
|
||||||
|
from .models import *
|
||||||
|
from .utils import *
|
||||||
|
from .version import __gitsha__, __version__
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
arch_folder = osp.dirname(osp.abspath(__file__))
|
||||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
||||||
# import all the arch modules
|
# import all the arch modules
|
||||||
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
|
_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
data_folder = osp.dirname(osp.abspath(__file__))
|
data_folder = osp.dirname(osp.abspath(__file__))
|
||||||
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
||||||
# import all the dataset modules
|
# import all the dataset modules
|
||||||
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
|
_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
|
||||||
@@ -7,4 +7,4 @@ from os import path as osp
|
|||||||
model_folder = osp.dirname(osp.abspath(__file__))
|
model_folder = osp.dirname(osp.abspath(__file__))
|
||||||
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
||||||
# import all the model modules
|
# import all the model modules
|
||||||
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
|
_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
|
||||||
11
realesrgan/train.py
Normal file
11
realesrgan/train.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
import os.path as osp
|
||||||
|
from basicsr.train import train_pipeline
|
||||||
|
|
||||||
|
import realesrgan.archs
|
||||||
|
import realesrgan.data
|
||||||
|
import realesrgan.models
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
||||||
|
train_pipeline(root_path)
|
||||||
226
realesrgan/utils.py
Normal file
226
realesrgan/utils.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from torch.hub import download_url_to_file, get_dir
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
class RealESRGANer():
|
||||||
|
|
||||||
|
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
|
||||||
|
self.scale = scale
|
||||||
|
self.tile_size = tile
|
||||||
|
self.tile_pad = tile_pad
|
||||||
|
self.pre_pad = pre_pad
|
||||||
|
self.mod_scale = None
|
||||||
|
self.half = half
|
||||||
|
|
||||||
|
# initialize model
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
||||||
|
|
||||||
|
if model_path.startswith('https://'):
|
||||||
|
model_path = load_file_from_url(
|
||||||
|
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
|
||||||
|
loadnet = torch.load(model_path)
|
||||||
|
if 'params_ema' in loadnet:
|
||||||
|
keyname = 'params_ema'
|
||||||
|
else:
|
||||||
|
keyname = 'params'
|
||||||
|
model.load_state_dict(loadnet[keyname], strict=True)
|
||||||
|
model.eval()
|
||||||
|
self.model = model.to(self.device)
|
||||||
|
if self.half:
|
||||||
|
self.model = self.model.half()
|
||||||
|
|
||||||
|
def pre_process(self, img):
|
||||||
|
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||||
|
self.img = img.unsqueeze(0).to(self.device)
|
||||||
|
if self.half:
|
||||||
|
self.img = self.img.half()
|
||||||
|
|
||||||
|
# pre_pad
|
||||||
|
if self.pre_pad != 0:
|
||||||
|
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
||||||
|
# mod pad
|
||||||
|
if self.scale == 2:
|
||||||
|
self.mod_scale = 2
|
||||||
|
elif self.scale == 1:
|
||||||
|
self.mod_scale = 4
|
||||||
|
if self.mod_scale is not None:
|
||||||
|
self.mod_pad_h, self.mod_pad_w = 0, 0
|
||||||
|
_, _, h, w = self.img.size()
|
||||||
|
if (h % self.mod_scale != 0):
|
||||||
|
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
|
||||||
|
if (w % self.mod_scale != 0):
|
||||||
|
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
|
||||||
|
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
try:
|
||||||
|
# inference
|
||||||
|
with torch.no_grad():
|
||||||
|
self.output = self.model(self.img)
|
||||||
|
except Exception as error:
|
||||||
|
print('Error', error)
|
||||||
|
|
||||||
|
def tile_process(self):
|
||||||
|
"""Modified from: https://github.com/ata4/esrgan-launcher
|
||||||
|
"""
|
||||||
|
batch, channel, height, width = self.img.shape
|
||||||
|
output_height = height * self.scale
|
||||||
|
output_width = width * self.scale
|
||||||
|
output_shape = (batch, channel, output_height, output_width)
|
||||||
|
|
||||||
|
# start with black image
|
||||||
|
self.output = self.img.new_zeros(output_shape)
|
||||||
|
tiles_x = math.ceil(width / self.tile_size)
|
||||||
|
tiles_y = math.ceil(height / self.tile_size)
|
||||||
|
|
||||||
|
# loop over all tiles
|
||||||
|
for y in range(tiles_y):
|
||||||
|
for x in range(tiles_x):
|
||||||
|
# extract tile from input image
|
||||||
|
ofs_x = x * self.tile_size
|
||||||
|
ofs_y = y * self.tile_size
|
||||||
|
# input tile area on total image
|
||||||
|
input_start_x = ofs_x
|
||||||
|
input_end_x = min(ofs_x + self.tile_size, width)
|
||||||
|
input_start_y = ofs_y
|
||||||
|
input_end_y = min(ofs_y + self.tile_size, height)
|
||||||
|
|
||||||
|
# input tile area on total image with padding
|
||||||
|
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
||||||
|
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
||||||
|
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
||||||
|
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
||||||
|
|
||||||
|
# input tile dimensions
|
||||||
|
input_tile_width = input_end_x - input_start_x
|
||||||
|
input_tile_height = input_end_y - input_start_y
|
||||||
|
tile_idx = y * tiles_x + x + 1
|
||||||
|
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
|
||||||
|
|
||||||
|
# upscale tile
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
output_tile = self.model(input_tile)
|
||||||
|
except Exception as error:
|
||||||
|
print('Error', error)
|
||||||
|
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
|
||||||
|
|
||||||
|
# output tile area on total image
|
||||||
|
output_start_x = input_start_x * self.scale
|
||||||
|
output_end_x = input_end_x * self.scale
|
||||||
|
output_start_y = input_start_y * self.scale
|
||||||
|
output_end_y = input_end_y * self.scale
|
||||||
|
|
||||||
|
# output tile area without padding
|
||||||
|
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
||||||
|
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
||||||
|
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
||||||
|
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
||||||
|
|
||||||
|
# put tile into output image
|
||||||
|
self.output[:, :, output_start_y:output_end_y,
|
||||||
|
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
|
||||||
|
output_start_x_tile:output_end_x_tile]
|
||||||
|
|
||||||
|
def post_process(self):
|
||||||
|
# remove extra pad
|
||||||
|
if self.mod_scale is not None:
|
||||||
|
_, _, h, w = self.output.size()
|
||||||
|
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
|
||||||
|
# remove prepad
|
||||||
|
if self.pre_pad != 0:
|
||||||
|
_, _, h, w = self.output.size()
|
||||||
|
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
def enhance(self, img, tile=False, alpha_upsampler='realesrgan'):
|
||||||
|
# img: numpy
|
||||||
|
img = img.astype(np.float32)
|
||||||
|
if np.max(img) > 255: # 16-bit image
|
||||||
|
max_range = 65535
|
||||||
|
print('\tInput is a 16-bit image')
|
||||||
|
else:
|
||||||
|
max_range = 255
|
||||||
|
img = img / max_range
|
||||||
|
if len(img.shape) == 2: # gray image
|
||||||
|
img_mode = 'L'
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
||||||
|
elif img.shape[2] == 4: # RGBA image with alpha channel
|
||||||
|
img_mode = 'RGBA'
|
||||||
|
alpha = img[:, :, 3]
|
||||||
|
img = img[:, :, 0:3]
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
if alpha_upsampler == 'realesrgan':
|
||||||
|
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
||||||
|
else:
|
||||||
|
img_mode = 'RGB'
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# ------------------- process image (without the alpha channel) ------------------- #
|
||||||
|
self.pre_process(img)
|
||||||
|
if self.tile_size > 0:
|
||||||
|
self.tile_process()
|
||||||
|
else:
|
||||||
|
self.process()
|
||||||
|
output_img = self.post_process()
|
||||||
|
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
||||||
|
if img_mode == 'L':
|
||||||
|
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
# ------------------- process the alpha channel if necessary ------------------- #
|
||||||
|
if img_mode == 'RGBA':
|
||||||
|
if alpha_upsampler == 'realesrgan':
|
||||||
|
self.pre_process(alpha)
|
||||||
|
if self.tile_size > 0:
|
||||||
|
self.tile_process()
|
||||||
|
else:
|
||||||
|
self.process()
|
||||||
|
output_alpha = self.post_process()
|
||||||
|
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
||||||
|
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
||||||
|
else:
|
||||||
|
h, w = alpha.shape[0:2]
|
||||||
|
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
# merge the alpha channel
|
||||||
|
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
||||||
|
output_img[:, :, 3] = output_alpha
|
||||||
|
|
||||||
|
# ------------------------------ return ------------------------------ #
|
||||||
|
if max_range == 65535: # 16-bit image
|
||||||
|
output = (output_img * 65535.0).round().astype(np.uint16)
|
||||||
|
else:
|
||||||
|
output = (output_img * 255.0).round().astype(np.uint8)
|
||||||
|
return output, img_mode
|
||||||
|
|
||||||
|
|
||||||
|
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
||||||
|
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||||
|
"""
|
||||||
|
if model_dir is None:
|
||||||
|
hub_dir = get_dir()
|
||||||
|
model_dir = os.path.join(hub_dir, 'checkpoints')
|
||||||
|
|
||||||
|
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
|
||||||
|
|
||||||
|
parts = urlparse(url)
|
||||||
|
filename = os.path.basename(parts.path)
|
||||||
|
if file_name is not None:
|
||||||
|
filename = file_name
|
||||||
|
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
|
||||||
|
if not os.path.exists(cached_file):
|
||||||
|
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||||
|
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
||||||
|
return cached_file
|
||||||
3
realesrgan/weights/README.md
Normal file
3
realesrgan/weights/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Weights
|
||||||
|
|
||||||
|
Put the downloaded weights to this folder.
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
basicsr
|
basicsr
|
||||||
cv2
|
|
||||||
numpy
|
numpy
|
||||||
|
opencv-python
|
||||||
torch>=1.7
|
torch>=1.7
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
|
|||||||
line_length = 120
|
line_length = 120
|
||||||
multi_line_output = 0
|
multi_line_output = 0
|
||||||
known_standard_library = pkg_resources,setuptools
|
known_standard_library = pkg_resources,setuptools
|
||||||
known_first_party = basicsr # modify it!
|
known_first_party = realesrgan
|
||||||
known_third_party = basicsr,cv2,numpy,torch
|
known_third_party = basicsr,cv2,numpy,torch
|
||||||
no_lines_before = STDLIB,LOCALFOLDER
|
no_lines_before = STDLIB,LOCALFOLDER
|
||||||
default_section = THIRDPARTY
|
default_section = THIRDPARTY
|
||||||
|
|||||||
113
setup.py
Normal file
113
setup.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
version_file = 'realesrgan/version.py'
|
||||||
|
|
||||||
|
|
||||||
|
def readme():
|
||||||
|
with open('README.md', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_hash():
|
||||||
|
|
||||||
|
def _minimal_ext_cmd(cmd):
|
||||||
|
# construct minimal environment
|
||||||
|
env = {}
|
||||||
|
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
|
||||||
|
v = os.environ.get(k)
|
||||||
|
if v is not None:
|
||||||
|
env[k] = v
|
||||||
|
# LANGUAGE is used on win32
|
||||||
|
env['LANGUAGE'] = 'C'
|
||||||
|
env['LANG'] = 'C'
|
||||||
|
env['LC_ALL'] = 'C'
|
||||||
|
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
|
||||||
|
return out
|
||||||
|
|
||||||
|
try:
|
||||||
|
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
|
||||||
|
sha = out.strip().decode('ascii')
|
||||||
|
except OSError:
|
||||||
|
sha = 'unknown'
|
||||||
|
|
||||||
|
return sha
|
||||||
|
|
||||||
|
|
||||||
|
def get_hash():
|
||||||
|
if os.path.exists('.git'):
|
||||||
|
sha = get_git_hash()[:7]
|
||||||
|
elif os.path.exists(version_file):
|
||||||
|
try:
|
||||||
|
from facexlib.version import __version__
|
||||||
|
sha = __version__.split('+')[-1]
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('Unable to get git version')
|
||||||
|
else:
|
||||||
|
sha = 'unknown'
|
||||||
|
|
||||||
|
return sha
|
||||||
|
|
||||||
|
|
||||||
|
def write_version_py():
|
||||||
|
content = """# GENERATED VERSION FILE
|
||||||
|
# TIME: {}
|
||||||
|
__version__ = '{}'
|
||||||
|
__gitsha__ = '{}'
|
||||||
|
version_info = ({})
|
||||||
|
"""
|
||||||
|
sha = get_hash()
|
||||||
|
with open('VERSION', 'r') as f:
|
||||||
|
SHORT_VERSION = f.read().strip()
|
||||||
|
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
|
||||||
|
|
||||||
|
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
|
||||||
|
with open(version_file, 'w') as f:
|
||||||
|
f.write(version_file_str)
|
||||||
|
|
||||||
|
|
||||||
|
def get_version():
|
||||||
|
with open(version_file, 'r') as f:
|
||||||
|
exec(compile(f.read(), version_file, 'exec'))
|
||||||
|
return locals()['__version__']
|
||||||
|
|
||||||
|
|
||||||
|
def get_requirements(filename='requirements.txt'):
|
||||||
|
here = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
with open(os.path.join(here, filename), 'r') as f:
|
||||||
|
requires = [line.replace('\n', '') for line in f.readlines()]
|
||||||
|
return requires
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
write_version_py()
|
||||||
|
setup(
|
||||||
|
name='realesrgan',
|
||||||
|
version=get_version(),
|
||||||
|
description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
|
||||||
|
long_description=readme(),
|
||||||
|
long_description_content_type='text/markdown',
|
||||||
|
author='Xintao Wang',
|
||||||
|
author_email='xintao.wang@outlook.com',
|
||||||
|
keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
|
||||||
|
url='https://github.com/xinntao/Real-ESRGAN',
|
||||||
|
include_package_data=True,
|
||||||
|
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
|
||||||
|
classifiers=[
|
||||||
|
'Development Status :: 4 - Beta',
|
||||||
|
'License :: OSI Approved :: Apache Software License',
|
||||||
|
'Operating System :: OS Independent',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Programming Language :: Python :: 3.7',
|
||||||
|
'Programming Language :: Python :: 3.8',
|
||||||
|
],
|
||||||
|
license='BSD-3-Clause License',
|
||||||
|
setup_requires=['cython', 'numpy'],
|
||||||
|
install_requires=get_requirements(),
|
||||||
|
zip_safe=False)
|
||||||
Reference in New Issue
Block a user