Move RealESRGANer out of the inference script into the realesrgan package.

Summary of what this patch does:
  * inference_realesrgan.py: drops the inline RealESRGANer class and the
    per-image read / upsample / alpha-merge / dtype-conversion code from
    main(); main() now reads the image with cv2.imread and delegates to
    RealESRGANer.enhance(), which returns (output, img_mode).
  * realesrgan/utils.py (new file): holds RealESRGANer, the new enhance()
    method, and load_file_from_url(). A model_path starting with 'https://'
    is downloaded into realesrgan/weights before torch.load is called.
  * realesrgan/__init__.py: re-exports realesrgan.utils.
  * realesrgan/{archs,data,models}/__init__.py: dynamic imports now use
    package-qualified module names ('realesrgan.archs.<name>' etc.).

Review fix relative to the change as first written:
  main() passes alpha_upsampler=args.alpha_upsampler to enhance(). The removed
  code honoured args.alpha_upsampler; calling enhance(img) alone would always
  use enhance()'s default 'realesrgan' and ignore the command-line choice.

NOTE(review): this patch was restored from a copy whose line breaks and
indentation had been collapsed. Line counts match every hunk header, but the
exact leading whitespace, the two-space gap before inline comments, and the
continuation-line alignment are reconstructed -- confirm with
`git apply --check` before relying on it. The `index` blob IDs are those of
the original change and do not reflect the review fix above.

diff --git a/inference_realesrgan.py b/inference_realesrgan.py
index 0418253..fadec7f 100644
--- a/inference_realesrgan.py
+++ b/inference_realesrgan.py
@@ -1,12 +1,9 @@
 import argparse
 import cv2
 import glob
-import math
-import numpy as np
 import os
-import torch
-from basicsr.archs.rrdbnet_arch import RRDBNet
-from torch.nn import functional as F
+
+from realesrgan import RealESRGANer
 
 
 def main():
@@ -53,61 +50,8 @@ def main():
         imgname, extension = os.path.splitext(os.path.basename(path))
         print('Testing', idx, imgname)
 
-        # ------------------------------ read image ------------------------------ #
-        img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
-        if np.max(img) > 255:  # 16-bit image
-            max_range = 65535
-            print('\tInput is a 16-bit image')
-        else:
-            max_range = 255
-        img = img / max_range
-        if len(img.shape) == 2:  # gray image
-            img_mode = 'L'
-            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
-        elif img.shape[2] == 4:  # RGBA image with alpha channel
-            img_mode = 'RGBA'
-            alpha = img[:, :, 3]
-            img = img[:, :, 0:3]
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            if args.alpha_upsampler == 'realesrgan':
-                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
-        else:
-            img_mode = 'RGB'
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-
-        # ------------------- process image (without the alpha channel) ------------------- #
-        upsampler.pre_process(img)
-        if args.tile:
-            upsampler.tile_process()
-        else:
-            upsampler.process()
-        output_img = upsampler.post_process()
-        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
-        if img_mode == 'L':
-            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
-
-        # ------------------- process the alpha channel if necessary ------------------- #
-        if img_mode == 'RGBA':
-            if args.alpha_upsampler == 'realesrgan':
-                upsampler.pre_process(alpha)
-                if args.tile:
-                    upsampler.tile_process()
-                else:
-                    upsampler.process()
-                output_alpha = upsampler.post_process()
-                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
-                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
-            else:
-                h, w = alpha.shape[0:2]
-                output_alpha = cv2.resize(alpha, (w * args.scale, h * args.scale), interpolation=cv2.INTER_LINEAR)
-
-            # merge the alpha channel
-            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
-            output_img[:, :, 3] = output_alpha
-
-        # ------------------------------ save image ------------------------------ #
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        output, img_mode = upsampler.enhance(img, alpha_upsampler=args.alpha_upsampler)
         if args.ext == 'auto':
             extension = extension[1:]
         else:
@@ -115,141 +59,8 @@ def main():
         if img_mode == 'RGBA':  # RGBA images should be saved in png format
             extension = 'png'
         save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
-        if max_range == 65535:  # 16-bit image
-            output = (output_img * 65535.0).round().astype(np.uint16)
-        else:
-            output = (output_img * 255.0).round().astype(np.uint8)
         cv2.imwrite(save_path, output)
 
 
-class RealESRGANer():
-
-    def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
-        self.scale = scale
-        self.tile_size = tile
-        self.tile_pad = tile_pad
-        self.pre_pad = pre_pad
-        self.mod_scale = None
-        self.half = half
-
-        # initialize model
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
-        loadnet = torch.load(model_path)
-        if 'params_ema' in loadnet:
-            keyname = 'params_ema'
-        else:
-            keyname = 'params'
-        model.load_state_dict(loadnet[keyname], strict=True)
-        model.eval()
-        self.model = model.to(self.device)
-        if self.half:
-            self.model = self.model.half()
-
-    def pre_process(self, img):
-        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
-        self.img = img.unsqueeze(0).to(self.device)
-        if self.half:
-            self.img = self.img.half()
-
-        # pre_pad
-        if self.pre_pad != 0:
-            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
-        # mod pad
-        if self.scale == 2:
-            self.mod_scale = 2
-        elif self.scale == 1:
-            self.mod_scale = 4
-        if self.mod_scale is not None:
-            self.mod_pad_h, self.mod_pad_w = 0, 0
-            _, _, h, w = self.img.size()
-            if (h % self.mod_scale != 0):
-                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
-            if (w % self.mod_scale != 0):
-                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
-            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
-
-    def process(self):
-        try:
-            # inference
-            with torch.no_grad():
-                self.output = self.model(self.img)
-        except Exception as error:
-            print('Error', error)
-
-    def tile_process(self):
-        """Modified from: https://github.com/ata4/esrgan-launcher
-        """
-        batch, channel, height, width = self.img.shape
-        output_height = height * self.scale
-        output_width = width * self.scale
-        output_shape = (batch, channel, output_height, output_width)
-
-        # start with black image
-        self.output = self.img.new_zeros(output_shape)
-        tiles_x = math.ceil(width / self.tile_size)
-        tiles_y = math.ceil(height / self.tile_size)
-
-        # loop over all tiles
-        for y in range(tiles_y):
-            for x in range(tiles_x):
-                # extract tile from input image
-                ofs_x = x * self.tile_size
-                ofs_y = y * self.tile_size
-                # input tile area on total image
-                input_start_x = ofs_x
-                input_end_x = min(ofs_x + self.tile_size, width)
-                input_start_y = ofs_y
-                input_end_y = min(ofs_y + self.tile_size, height)
-
-                # input tile area on total image with padding
-                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
-                input_end_x_pad = min(input_end_x + self.tile_pad, width)
-                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
-                input_end_y_pad = min(input_end_y + self.tile_pad, height)
-
-                # input tile dimensions
-                input_tile_width = input_end_x - input_start_x
-                input_tile_height = input_end_y - input_start_y
-                tile_idx = y * tiles_x + x + 1
-                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
-
-                # upscale tile
-                try:
-                    with torch.no_grad():
-                        output_tile = self.model(input_tile)
-                except Exception as error:
-                    print('Error', error)
-                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
-
-                # output tile area on total image
-                output_start_x = input_start_x * self.scale
-                output_end_x = input_end_x * self.scale
-                output_start_y = input_start_y * self.scale
-                output_end_y = input_end_y * self.scale
-
-                # output tile area without padding
-                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
-                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
-                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
-                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
-
-                # put tile into output image
-                self.output[:, :, output_start_y:output_end_y,
-                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
-                                                                       output_start_x_tile:output_end_x_tile]
-
-    def post_process(self):
-        # remove extra pad
-        if self.mod_scale is not None:
-            _, _, h, w = self.output.size()
-            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
-        # remove prepad
-        if self.pre_pad != 0:
-            _, _, h, w = self.output.size()
-            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
-        return self.output
-
-
 if __name__ == '__main__':
     main()
diff --git a/realesrgan/__init__.py b/realesrgan/__init__.py
index 36731a0..4ccac57 100644
--- a/realesrgan/__init__.py
+++ b/realesrgan/__init__.py
@@ -2,4 +2,5 @@
 from .archs import *
 from .data import *
 from .models import *
+from .utils import *
 from .version import __gitsha__, __version__
diff --git a/realesrgan/archs/__init__.py b/realesrgan/archs/__init__.py
index 4ec725e..f3fbbf3 100644
--- a/realesrgan/archs/__init__.py
+++ b/realesrgan/archs/__init__.py
@@ -7,4 +7,4 @@ from os import path as osp
 arch_folder = osp.dirname(osp.abspath(__file__))
 arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
 # import all the arch modules
-_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
+_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
diff --git a/realesrgan/data/__init__.py b/realesrgan/data/__init__.py
index 3b8afa6..a3f8fdd 100644
--- a/realesrgan/data/__init__.py
+++ b/realesrgan/data/__init__.py
@@ -7,4 +7,4 @@ from os import path as osp
 data_folder = osp.dirname(osp.abspath(__file__))
 dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
 # import all the dataset modules
-_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
+_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
diff --git a/realesrgan/models/__init__.py b/realesrgan/models/__init__.py
index a7ce9a2..0be7105 100644
--- a/realesrgan/models/__init__.py
+++ b/realesrgan/models/__init__.py
@@ -7,4 +7,4 @@ from os import path as osp
 model_folder = osp.dirname(osp.abspath(__file__))
 model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
 # import all the model modules
-_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
+_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
diff --git a/realesrgan/utils.py b/realesrgan/utils.py
new file mode 100644
index 0000000..06b2261
--- /dev/null
+++ b/realesrgan/utils.py
@@ -0,0 +1,226 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from torch.hub import download_url_to_file, get_dir
+from torch.nn import functional as F
+from urllib.parse import urlparse
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer():
+
+    def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
+        self.scale = scale
+        self.tile_size = tile
+        self.tile_pad = tile_pad
+        self.pre_pad = pre_pad
+        self.mod_scale = None
+        self.half = half
+
+        # initialize model
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
+
+        if model_path.startswith('https://'):
+            model_path = load_file_from_url(
+                url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
+        loadnet = torch.load(model_path)
+        if 'params_ema' in loadnet:
+            keyname = 'params_ema'
+        else:
+            keyname = 'params'
+        model.load_state_dict(loadnet[keyname], strict=True)
+        model.eval()
+        self.model = model.to(self.device)
+        if self.half:
+            self.model = self.model.half()
+
+    def pre_process(self, img):
+        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+        self.img = img.unsqueeze(0).to(self.device)
+        if self.half:
+            self.img = self.img.half()
+
+        # pre_pad
+        if self.pre_pad != 0:
+            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+        # mod pad
+        if self.scale == 2:
+            self.mod_scale = 2
+        elif self.scale == 1:
+            self.mod_scale = 4
+        if self.mod_scale is not None:
+            self.mod_pad_h, self.mod_pad_w = 0, 0
+            _, _, h, w = self.img.size()
+            if (h % self.mod_scale != 0):
+                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+            if (w % self.mod_scale != 0):
+                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+    def process(self):
+        try:
+            # inference
+            with torch.no_grad():
+                self.output = self.model(self.img)
+        except Exception as error:
+            print('Error', error)
+
+    def tile_process(self):
+        """Modified from: https://github.com/ata4/esrgan-launcher
+        """
+        batch, channel, height, width = self.img.shape
+        output_height = height * self.scale
+        output_width = width * self.scale
+        output_shape = (batch, channel, output_height, output_width)
+
+        # start with black image
+        self.output = self.img.new_zeros(output_shape)
+        tiles_x = math.ceil(width / self.tile_size)
+        tiles_y = math.ceil(height / self.tile_size)
+
+        # loop over all tiles
+        for y in range(tiles_y):
+            for x in range(tiles_x):
+                # extract tile from input image
+                ofs_x = x * self.tile_size
+                ofs_y = y * self.tile_size
+                # input tile area on total image
+                input_start_x = ofs_x
+                input_end_x = min(ofs_x + self.tile_size, width)
+                input_start_y = ofs_y
+                input_end_y = min(ofs_y + self.tile_size, height)
+
+                # input tile area on total image with padding
+                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+                input_end_x_pad = min(input_end_x + self.tile_pad, width)
+                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+                input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+                # input tile dimensions
+                input_tile_width = input_end_x - input_start_x
+                input_tile_height = input_end_y - input_start_y
+                tile_idx = y * tiles_x + x + 1
+                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+                # upscale tile
+                try:
+                    with torch.no_grad():
+                        output_tile = self.model(input_tile)
+                except Exception as error:
+                    print('Error', error)
+                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+                # output tile area on total image
+                output_start_x = input_start_x * self.scale
+                output_end_x = input_end_x * self.scale
+                output_start_y = input_start_y * self.scale
+                output_end_y = input_end_y * self.scale
+
+                # output tile area without padding
+                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+                # put tile into output image
+                self.output[:, :, output_start_y:output_end_y,
+                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+                                                                       output_start_x_tile:output_end_x_tile]
+
+    def post_process(self):
+        # remove extra pad
+        if self.mod_scale is not None:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+        # remove prepad
+        if self.pre_pad != 0:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+        return self.output
+
+    def enhance(self, img, tile=False, alpha_upsampler='realesrgan'):
+        # img: numpy
+        img = img.astype(np.float32)
+        if np.max(img) > 255:  # 16-bit image
+            max_range = 65535
+            print('\tInput is a 16-bit image')
+        else:
+            max_range = 255
+        img = img / max_range
+        if len(img.shape) == 2:  # gray image
+            img_mode = 'L'
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+        elif img.shape[2] == 4:  # RGBA image with alpha channel
+            img_mode = 'RGBA'
+            alpha = img[:, :, 3]
+            img = img[:, :, 0:3]
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            if alpha_upsampler == 'realesrgan':
+                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+        else:
+            img_mode = 'RGB'
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        # ------------------- process image (without the alpha channel) ------------------- #
+        self.pre_process(img)
+        if self.tile_size > 0:
+            self.tile_process()
+        else:
+            self.process()
+        output_img = self.post_process()
+        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+        if img_mode == 'L':
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+        # ------------------- process the alpha channel if necessary ------------------- #
+        if img_mode == 'RGBA':
+            if alpha_upsampler == 'realesrgan':
+                self.pre_process(alpha)
+                if self.tile_size > 0:
+                    self.tile_process()
+                else:
+                    self.process()
+                output_alpha = self.post_process()
+                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+            else:
+                h, w = alpha.shape[0:2]
+                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+            # merge the alpha channel
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+            output_img[:, :, 3] = output_alpha
+
+        # ------------------------------ return ------------------------------ #
+        if max_range == 65535:  # 16-bit image
+            output = (output_img * 65535.0).round().astype(np.uint16)
+        else:
+            output = (output_img * 255.0).round().astype(np.uint8)
+        return output, img_mode
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+    """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+    """
+    if model_dir is None:
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
+
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    if file_name is not None:
+        filename = file_name
+    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+    return cached_file