diff --git a/inference_realesrgan.py b/inference_realesrgan.py index 7bd92a4..46e84ee 100644 --- a/inference_realesrgan.py +++ b/inference_realesrgan.py @@ -1,6 +1,7 @@ import argparse import cv2 import glob +import math import numpy as np import os import torch @@ -10,64 +11,233 @@ from torch.nn import functional as F def main(): parser = argparse.ArgumentParser() - parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth') - parser.add_argument('--scale', type=int, default=4) - parser.add_argument('--suffix', type=str, default='_out') - parser.add_argument('--input', type=str, default='inputs', help='input image or folder') + parser.add_argument('--input', type=str, default='inputs', help='Input image or folder') + parser.add_argument( + '--model_path', + type=str, + default='experiments/pretrained_models/RealESRGAN_x4plus.pth', + help='Path to the pre-trained model') + parser.add_argument('--scale', type=int, default=4, help='Upsample scale factor') + parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image') + parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing') + parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding') + parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border') + parser.add_argument( + '--alpha_upsampler', + type=str, + default='realesrgan', + help='The upsampler for the alpha channels. Options: realesrgan | bicubic') + parser.add_argument( + '--extension', + type=str, + default='auto', + help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs') args = parser.parse_args() - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - # set up model - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale) - loadnet = torch.load(args.model_path) - if 'params_ema' in loadnet: - keyname = 'params_ema' - else: - keyname = 'params' - model.load_state_dict(loadnet[keyname], strict=True) - model.eval() - model = model.to(device) - + upsampler = RealESRGANer( + scale=args.scale, model_path=args.model_path, tile=args.tile, tile_pad=args.tile_pad, pre_pad=args.pre_pad) os.makedirs('results/', exist_ok=True) - for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))): - imgname = os.path.splitext(os.path.basename(path))[0] + if os.path.isfile(args.input): + paths = [args.input] + else: + paths = sorted(glob.glob(os.path.join(args.input, '*'))) + + for idx, path in enumerate(paths): + imgname, extension = os.path.splitext(os.path.basename(path)) print('Testing', idx, imgname) - # read image - img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255. - img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() - img = img.unsqueeze(0).to(device) - if args.scale == 2: - mod_scale = 2 - elif args.scale == 1: - mod_scale = 4 + # ------------------------------ read image ------------------------------ # + img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) + if np.max(img) > 255: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') else: - mod_scale = None - if mod_scale is not None: - h_pad, w_pad = 0, 0 - _, _, h, w = img.size() - if (h % mod_scale != 0): - h_pad = (mod_scale - h % mod_scale) - if (w % mod_scale != 0): - w_pad = (mod_scale - w % mod_scale) - img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect') + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if args.alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + # ------------------- process image (without the alpha channel) ------------------- # + upsampler.pre_process(img) + if args.tile: + upsampler.tile_process() + else: + upsampler.process() + output_img = upsampler.post_process() + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if args.alpha_upsampler == 'realesrgan': + upsampler.pre_process(alpha) + if args.tile: + upsampler.tile_process() + else: + upsampler.process() + output_alpha = upsampler.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * args.scale, h * args.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ save image ------------------------------ # + if args.extension == 'auto': + extension = extension[1:] + else: + extension == args.extension + if img_mode == 'RGBA': # RGBA images should be saved in png format + extension = 'png' + save_path = f'results/{imgname}_{args.suffix}.{extension}' + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + cv2.imwrite(save_path, output) + + +class RealESRGANer(): + + def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + + # initialize model + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) + loadnet = torch.load(model_path) + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + + def pre_process(self, img): + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): try: # inference with torch.no_grad(): - output = model(img) - # remove extra pad - if mod_scale is not None: - _, _, h, w = output.size() - output = output[:, :, 0:h - h_pad, 0:w - w_pad] - # save image - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) - output = (output * 255.0).round().astype(np.uint8) - cv2.imwrite(f'results/{imgname}_{args.suffix}.png', output) + self.output = self.model(self.img) except Exception as error: print('Error', error) + def tile_process(self): + """Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except Exception as error: + print('Error', error) + print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + if __name__ == '__main__': main() diff --git a/inputs/tree_alpha_16bit.png b/inputs/tree_alpha_16bit.png new file mode 100644 index 0000000..ca7c2aa Binary files /dev/null and b/inputs/tree_alpha_16bit.png differ diff --git a/inputs/wolf_gray.jpg b/inputs/wolf_gray.jpg new file mode 100644 index 0000000..614766b Binary files /dev/null and b/inputs/wolf_gray.jpg differ