From 9baa0b3d0072c8c72f0c77aaf96b883bde77d18f Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 14:46:43 +0800 Subject: [PATCH 1/9] regroup files --- .github/workflows/pylint.yml | 4 ++-- {archs => realesrgan/archs}/__init__.py | 0 {archs => realesrgan/archs}/discriminator_arch.py | 0 {data => realesrgan/data}/__init__.py | 0 {data => realesrgan/data}/realesrgan_dataset.py | 0 {models => realesrgan/models}/__init__.py | 0 {models => realesrgan/models}/realesrgan_model.py | 0 {models => realesrgan/models}/realesrnet_model.py | 0 train.py => realesrgan/train.py | 6 +++--- 9 files changed, 5 insertions(+), 5 deletions(-) rename {archs => realesrgan/archs}/__init__.py (100%) rename {archs => realesrgan/archs}/discriminator_arch.py (100%) rename {data => realesrgan/data}/__init__.py (100%) rename {data => realesrgan/data}/realesrgan_dataset.py (100%) rename {models => realesrgan/models}/__init__.py (100%) rename {models => realesrgan/models}/realesrgan_model.py (100%) rename {models => realesrgan/models}/realesrnet_model.py (100%) rename train.py => realesrgan/train.py (61%) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 1d8b501..d754f53 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -26,5 +26,5 @@ jobs: - name: Lint run: | flake8 . - isort --check-only --diff data/ models/ inference_realesrgan.py - yapf -r -d data/ models/ inference_realesrgan.py + isort --check-only --diff realesrgan/ scripts/ inference_realesrgan.py setup.py + yapf -r -d realesrgan/ scripts/ inference_realesrgan.py setup.py diff --git a/archs/__init__.py b/realesrgan/archs/__init__.py similarity index 100% rename from archs/__init__.py rename to realesrgan/archs/__init__.py diff --git a/archs/discriminator_arch.py b/realesrgan/archs/discriminator_arch.py similarity index 100% rename from archs/discriminator_arch.py rename to realesrgan/archs/discriminator_arch.py diff --git a/data/__init__.py b/realesrgan/data/__init__.py similarity index 100% rename from data/__init__.py rename to realesrgan/data/__init__.py diff --git a/data/realesrgan_dataset.py b/realesrgan/data/realesrgan_dataset.py similarity index 100% rename from data/realesrgan_dataset.py rename to realesrgan/data/realesrgan_dataset.py diff --git a/models/__init__.py b/realesrgan/models/__init__.py similarity index 100% rename from models/__init__.py rename to realesrgan/models/__init__.py diff --git a/models/realesrgan_model.py b/realesrgan/models/realesrgan_model.py similarity index 100% rename from models/realesrgan_model.py rename to realesrgan/models/realesrgan_model.py diff --git a/models/realesrnet_model.py b/realesrgan/models/realesrnet_model.py similarity index 100% rename from models/realesrnet_model.py rename to realesrgan/models/realesrnet_model.py diff --git a/train.py b/realesrgan/train.py similarity index 61% rename from train.py rename to realesrgan/train.py index bd52322..d985a79 100644 --- a/train.py +++ b/realesrgan/train.py @@ -1,9 +1,9 @@ import os.path as osp from basicsr.train import train_pipeline -import archs # noqa: F401 -import data # noqa: F401 -import models # noqa: F401 +import realesrgan.archs # noqa: F401 +import realesrgan.data # noqa: F401 +import realesrgan.models # noqa: F401 if __name__ == '__main__': root_path = osp.abspath(osp.join(__file__, osp.pardir)) From 94ae626008ee389f6e63f75af920c44dd0ddf7ba Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 14:50:32 +0800 Subject: [PATCH 2/9] regroup --- realesrgan/train.py | 7 ++++--- setup.cfg | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/realesrgan/train.py b/realesrgan/train.py index d985a79..e9ad4dd 100644 --- a/realesrgan/train.py +++ b/realesrgan/train.py @@ -1,9 +1,10 @@ +# flake8: noqa import os.path as osp from basicsr.train import train_pipeline -import realesrgan.archs # noqa: F401 -import realesrgan.data # noqa: F401 -import realesrgan.models # noqa: F401 +from .archs import * +from .data import * +from .models import * if __name__ == '__main__': root_path = osp.abspath(osp.join(__file__, osp.pardir)) diff --git a/setup.cfg b/setup.cfg index 4eb6529..2293ad7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true line_length = 120 multi_line_output = 0 known_standard_library = pkg_resources,setuptools -known_first_party = basicsr # modify it! +known_first_party = realesrgan known_third_party = basicsr,cv2,numpy,torch no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY From 52eab16d11ed5008ecade8f8a8c3d29ec707a082 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 15:01:08 +0800 Subject: [PATCH 3/9] update .gitignore --- .gitignore | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.gitignore b/.gitignore index b6c4e54..b240c56 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,11 @@ +# ignored folders +datasets/* +experiments/* +results/* +tb_logger/* +wandb/* +tmp/* + .vscode # Byte-compiled / optimized / DLL files From 064df9956bc97aa21d23fe9ee170a718858b45a5 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 15:13:53 +0800 Subject: [PATCH 4/9] add setup.py --- MANIFEST.in | 7 +++ VERSION | 1 + realesrgan/__init__.py | 5 ++ realesrgan/weights/README.md | 3 + setup.py | 113 +++++++++++++++++++++++++++++++++++ 5 files changed, 129 insertions(+) create mode 100644 MANIFEST.in create mode 100644 VERSION create mode 100644 realesrgan/__init__.py create mode 100644 realesrgan/weights/README.md create mode 100644 setup.py diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..11233df --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,7 @@ +include assets/* +include inputs/* +include scripts/*.py +include inference_realesrgan.py +include VERSION +include requirements.txt +include realesrgan/weights/README.md diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..0ea3a94 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.2.0 diff --git a/realesrgan/__init__.py b/realesrgan/__init__.py new file mode 100644 index 0000000..36731a0 --- /dev/null +++ b/realesrgan/__init__.py @@ -0,0 +1,5 @@ +# flake8: noqa +from .archs import * +from .data import * +from .models import * +from .version import __gitsha__, __version__ diff --git a/realesrgan/weights/README.md b/realesrgan/weights/README.md new file mode 100644 index 0000000..4d7b7e6 --- /dev/null +++ b/realesrgan/weights/README.md @@ -0,0 +1,3 @@ +# Weights + +Put the downloaded weights to this folder. diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e68d193 --- /dev/null +++ b/setup.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import time + +version_file = 'realesrgan/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + elif os.path.exists(version_file): + try: + from facexlib.version import __version__ + sha = __version__.split('+')[-1] + except ImportError: + raise ImportError('Unable to get git version') + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def get_requirements(filename='requirements.txt'): + here = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(here, filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + write_version_py() + setup( + name='realesrgan', + version=get_version(), + description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration', + long_description=readme(), + long_description_content_type='text/markdown', + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan', + url='https://github.com/xinntao/Real-ESRGAN', + include_package_data=True, + packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='BSD-3-Clause License', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + zip_safe=False) From bef5e3cabdc07ecc8920a59652563cf1c30a608e Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 15:27:55 +0800 Subject: [PATCH 5/9] update for running --- realesrgan/train.py | 8 ++++---- requirements.txt | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/realesrgan/train.py b/realesrgan/train.py index e9ad4dd..8a9cec9 100644 --- a/realesrgan/train.py +++ b/realesrgan/train.py @@ -2,10 +2,10 @@ import os.path as osp from basicsr.train import train_pipeline -from .archs import * -from .data import * -from .models import * +import realesrgan.archs +import realesrgan.data +import realesrgan.models if __name__ == '__main__': - root_path = osp.abspath(osp.join(__file__, osp.pardir)) + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) train_pipeline(root_path) diff --git a/requirements.txt b/requirements.txt index 0ea6378..f0a64cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ basicsr -cv2 numpy +opencv-python torch>=1.7 From 1f83ce543287f2717d4ee1e49eb968ac0fc20d6e Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 15:29:33 +0800 Subject: [PATCH 6/9] update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b240c56..0200900 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ tb_logger/* wandb/* tmp/* +version.py .vscode # Byte-compiled / optimized / DLL files From 18ebf723f29a95a528b69914541d16869b776810 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 16:12:56 +0800 Subject: [PATCH 7/9] adaption for pypi --- inference_realesrgan.py | 197 +---------------------------- realesrgan/__init__.py | 1 + realesrgan/archs/__init__.py | 2 +- realesrgan/data/__init__.py | 2 +- realesrgan/models/__init__.py | 2 +- realesrgan/utils.py | 226 ++++++++++++++++++++++++++++++++++ 6 files changed, 234 insertions(+), 196 deletions(-) create mode 100644 realesrgan/utils.py diff --git a/inference_realesrgan.py b/inference_realesrgan.py index 0418253..fadec7f 100644 --- a/inference_realesrgan.py +++ b/inference_realesrgan.py @@ -1,12 +1,9 @@ import argparse import cv2 import glob -import math -import numpy as np import os -import torch -from basicsr.archs.rrdbnet_arch import RRDBNet -from torch.nn import functional as F + +from realesrgan import RealESRGANer def main(): @@ -53,61 +50,8 @@ def main(): imgname, extension = os.path.splitext(os.path.basename(path)) print('Testing', idx, imgname) - # ------------------------------ read image ------------------------------ # - img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) - if np.max(img) > 255: # 16-bit image - max_range = 65535 - print('\tInput is a 16-bit image') - else: - max_range = 255 - img = img / max_range - if len(img.shape) == 2: # gray image - img_mode = 'L' - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) - elif img.shape[2] == 4: # RGBA image with alpha channel - img_mode = 'RGBA' - alpha = img[:, :, 3] - img = img[:, :, 0:3] - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - if args.alpha_upsampler == 'realesrgan': - alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) - else: - img_mode = 'RGB' - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - - # ------------------- process image (without the alpha channel) ------------------- # - upsampler.pre_process(img) - if args.tile: - upsampler.tile_process() - else: - upsampler.process() - output_img = upsampler.post_process() - output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() - output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) - if img_mode == 'L': - output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) - - # ------------------- process the alpha channel if necessary ------------------- # - if img_mode == 'RGBA': - if args.alpha_upsampler == 'realesrgan': - upsampler.pre_process(alpha) - if args.tile: - upsampler.tile_process() - else: - upsampler.process() - output_alpha = upsampler.post_process() - output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() - output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) - output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) - else: - h, w = alpha.shape[0:2] - output_alpha = cv2.resize(alpha, (w * args.scale, h * args.scale), interpolation=cv2.INTER_LINEAR) - - # merge the alpha channel - output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) - output_img[:, :, 3] = output_alpha - - # ------------------------------ save image ------------------------------ # + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + output, img_mode = upsampler.enhance(img) if args.ext == 'auto': extension = extension[1:] else: @@ -115,141 +59,8 @@ def main(): if img_mode == 'RGBA': # RGBA images should be saved in png format extension = 'png' save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}') - if max_range == 65535: # 16-bit image - output = (output_img * 65535.0).round().astype(np.uint16) - else: - output = (output_img * 255.0).round().astype(np.uint8) cv2.imwrite(save_path, output) -class RealESRGANer(): - - def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False): - self.scale = scale - self.tile_size = tile - self.tile_pad = tile_pad - self.pre_pad = pre_pad - self.mod_scale = None - self.half = half - - # initialize model - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) - loadnet = torch.load(model_path) - if 'params_ema' in loadnet: - keyname = 'params_ema' - else: - keyname = 'params' - model.load_state_dict(loadnet[keyname], strict=True) - model.eval() - self.model = model.to(self.device) - if self.half: - self.model = self.model.half() - - def pre_process(self, img): - img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() - self.img = img.unsqueeze(0).to(self.device) - if self.half: - self.img = self.img.half() - - # pre_pad - if self.pre_pad != 0: - self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') - # mod pad - if self.scale == 2: - self.mod_scale = 2 - elif self.scale == 1: - self.mod_scale = 4 - if self.mod_scale is not None: - self.mod_pad_h, self.mod_pad_w = 0, 0 - _, _, h, w = self.img.size() - if (h % self.mod_scale != 0): - self.mod_pad_h = (self.mod_scale - h % self.mod_scale) - if (w % self.mod_scale != 0): - self.mod_pad_w = (self.mod_scale - w % self.mod_scale) - self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') - - def process(self): - try: - # inference - with torch.no_grad(): - self.output = self.model(self.img) - except Exception as error: - print('Error', error) - - def tile_process(self): - """Modified from: https://github.com/ata4/esrgan-launcher - """ - batch, channel, height, width = self.img.shape - output_height = height * self.scale - output_width = width * self.scale - output_shape = (batch, channel, output_height, output_width) - - # start with black image - self.output = self.img.new_zeros(output_shape) - tiles_x = math.ceil(width / self.tile_size) - tiles_y = math.ceil(height / self.tile_size) - - # loop over all tiles - for y in range(tiles_y): - for x in range(tiles_x): - # extract tile from input image - ofs_x = x * self.tile_size - ofs_y = y * self.tile_size - # input tile area on total image - input_start_x = ofs_x - input_end_x = min(ofs_x + self.tile_size, width) - input_start_y = ofs_y - input_end_y = min(ofs_y + self.tile_size, height) - - # input tile area on total image with padding - input_start_x_pad = max(input_start_x - self.tile_pad, 0) - input_end_x_pad = min(input_end_x + self.tile_pad, width) - input_start_y_pad = max(input_start_y - self.tile_pad, 0) - input_end_y_pad = min(input_end_y + self.tile_pad, height) - - # input tile dimensions - input_tile_width = input_end_x - input_start_x - input_tile_height = input_end_y - input_start_y - tile_idx = y * tiles_x + x + 1 - input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] - - # upscale tile - try: - with torch.no_grad(): - output_tile = self.model(input_tile) - except Exception as error: - print('Error', error) - print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') - - # output tile area on total image - output_start_x = input_start_x * self.scale - output_end_x = input_end_x * self.scale - output_start_y = input_start_y * self.scale - output_end_y = input_end_y * self.scale - - # output tile area without padding - output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale - output_end_x_tile = output_start_x_tile + input_tile_width * self.scale - output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale - output_end_y_tile = output_start_y_tile + input_tile_height * self.scale - - # put tile into output image - self.output[:, :, output_start_y:output_end_y, - output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, - output_start_x_tile:output_end_x_tile] - - def post_process(self): - # remove extra pad - if self.mod_scale is not None: - _, _, h, w = self.output.size() - self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] - # remove prepad - if self.pre_pad != 0: - _, _, h, w = self.output.size() - self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] - return self.output - - if __name__ == '__main__': main() diff --git a/realesrgan/__init__.py b/realesrgan/__init__.py index 36731a0..4ccac57 100644 --- a/realesrgan/__init__.py +++ b/realesrgan/__init__.py @@ -2,4 +2,5 @@ from .archs import * from .data import * from .models import * +from .utils import * from .version import __gitsha__, __version__ diff --git a/realesrgan/archs/__init__.py b/realesrgan/archs/__init__.py index 4ec725e..f3fbbf3 100644 --- a/realesrgan/archs/__init__.py +++ b/realesrgan/archs/__init__.py @@ -7,4 +7,4 @@ from os import path as osp arch_folder = osp.dirname(osp.abspath(__file__)) arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] # import all the arch modules -_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] +_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames] diff --git a/realesrgan/data/__init__.py b/realesrgan/data/__init__.py index 3b8afa6..a3f8fdd 100644 --- a/realesrgan/data/__init__.py +++ b/realesrgan/data/__init__.py @@ -7,4 +7,4 @@ from os import path as osp data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] # import all the dataset modules -_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames] +_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames] diff --git a/realesrgan/models/__init__.py b/realesrgan/models/__init__.py index a7ce9a2..0be7105 100644 --- a/realesrgan/models/__init__.py +++ b/realesrgan/models/__init__.py @@ -7,4 +7,4 @@ from os import path as osp model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] # import all the model modules -_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames] +_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames] diff --git a/realesrgan/utils.py b/realesrgan/utils.py new file mode 100644 index 0000000..06b2261 --- /dev/null +++ b/realesrgan/utils.py @@ -0,0 +1,226 @@ +import cv2 +import math +import numpy as np +import os +import torch +from basicsr.archs.rrdbnet_arch import RRDBNet +from torch.hub import download_url_to_file, get_dir +from torch.nn import functional as F +from urllib.parse import urlparse + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class RealESRGANer(): + + def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) + + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None) + loadnet = torch.load(model_path) + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def pre_process(self, img): + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + try: + # inference + with torch.no_grad(): + self.output = self.model(self.img) + except Exception as error: + print('Error', error) + + def tile_process(self): + """Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except Exception as error: + print('Error', error) + print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + def enhance(self, img, tile=False, alpha_upsampler='realesrgan'): + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 255: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img = self.post_process() + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + return output, img_mode + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + """ + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file From 32a4fa17723bce05db21afa67a473dc4285f8786 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 16:20:01 +0800 Subject: [PATCH 8/9] add publish-pip action --- .github/workflows/publish-pip.yml | 30 ++++++++++++++++++++++++++++++ MANIFEST.in | 1 + 2 files changed, 31 insertions(+) create mode 100644 .github/workflows/publish-pip.yml diff --git a/.github/workflows/publish-pip.yml b/.github/workflows/publish-pip.yml new file mode 100644 index 0000000..06047f7 --- /dev/null +++ b/.github/workflows/publish-pip.yml @@ -0,0 +1,30 @@ +name: PyPI Publish + +on: push + +jobs: + build-n-publish: + runs-on: ubuntu-latest + if: startsWith(github.event.ref, 'refs/tags') + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Upgrade pip + run: pip install pip --upgrade + - name: Install PyTorch (cpu) + run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install dependencies + run: pip install -r requirements.txt + - name: Build and install + run: rm -rf .eggs && pip install -e . + - name: Build for distribution + # remove bdist_wheel for pip installation with compiling cuda extensions + run: python setup.py sdist + - name: Publish distribution to PyPI + uses: pypa/gh-action-pypi-publish@master + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/MANIFEST.in b/MANIFEST.in index 11233df..b18403e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,5 +3,6 @@ include inputs/* include scripts/*.py include inference_realesrgan.py include VERSION +include LICENSE include requirements.txt include realesrgan/weights/README.md From 4356ba057841b3ef42733a4bea56d80d36618279 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 8 Aug 2021 16:26:09 +0800 Subject: [PATCH 9/9] update readme --- README.md | 1 + Training.md | 10 +++++----- options/train_realesrgan_x4plus.yml | 4 ++-- options/train_realesrnet_x4plus.yml | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 644450f..97f671c 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ This executable file is based on the wonderful [Tencent/ncnn](https://github.com # We use BasicSR for both training and inference pip install basicsr pip install -r requirements.txt + python setup.py develop ``` ## :zap: Quick Inference diff --git a/Training.md b/Training.md index 50454ae..64704e1 100644 --- a/Training.md +++ b/Training.md @@ -44,7 +44,7 @@ DF2K_HR_sub/000001_s003.png name: DF2K+OST type: RealESRGANDataset dataroot_gt: datasets/DF2K # modify to the root path of your folder - meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt + meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt io_backend: type: disk ``` @@ -76,12 +76,12 @@ DF2K_HR_sub/000001_s003.png 1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training: ```bash CUDA_VISIBLE_DEVICES=0,1,2,3 \ - python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug + python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug ``` 1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary. ```bash CUDA_VISIBLE_DEVICES=0,1,2,3 \ - python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume + python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume ``` ## Train Real-ESRGAN @@ -91,10 +91,10 @@ DF2K_HR_sub/000001_s003.png 1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training: ```bash CUDA_VISIBLE_DEVICES=0,1,2,3 \ - python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug + python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug ``` 1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary. ```bash CUDA_VISIBLE_DEVICES=0,1,2,3 \ - python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume + python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume ``` diff --git a/options/train_realesrgan_x4plus.yml b/options/train_realesrgan_x4plus.yml index 940a777..25bb5a2 100644 --- a/options/train_realesrgan_x4plus.yml +++ b/options/train_realesrgan_x4plus.yml @@ -39,7 +39,7 @@ datasets: name: DF2K+OST type: RealESRGANDataset dataroot_gt: datasets/DF2K - meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt io_backend: type: disk @@ -100,7 +100,7 @@ network_d: # path path: # use the pre-trained Real-ESRNet model - pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth + pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/models/net_g_1000000.pth param_key_g: params_ema strict_load_g: true resume_state: ~ diff --git a/options/train_realesrnet_x4plus.yml b/options/train_realesrnet_x4plus.yml index 400c580..2e13c39 100644 --- a/options/train_realesrnet_x4plus.yml +++ b/options/train_realesrnet_x4plus.yml @@ -36,7 +36,7 @@ datasets: name: DF2K+OST type: RealESRGANDataset dataroot_gt: datasets/DF2K - meta_info: data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt io_backend: type: disk