support more inference features: tile, alpha channel, gray image, 16-bit
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import cv2
|
import cv2
|
||||||
import glob
|
import glob
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
@@ -10,64 +11,233 @@ from torch.nn import functional as F
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth')
|
parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
|
||||||
parser.add_argument('--scale', type=int, default=4)
|
parser.add_argument(
|
||||||
parser.add_argument('--suffix', type=str, default='_out')
|
'--model_path',
|
||||||
parser.add_argument('--input', type=str, default='inputs', help='input image or folder')
|
type=str,
|
||||||
|
default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
|
||||||
|
help='Path to the pre-trained model')
|
||||||
|
parser.add_argument('--scale', type=int, default=4, help='Upsample scale factor')
|
||||||
|
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
|
||||||
|
parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
|
||||||
|
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
|
||||||
|
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
|
||||||
|
parser.add_argument(
|
||||||
|
'--alpha_upsampler',
|
||||||
|
type=str,
|
||||||
|
default='realesrgan',
|
||||||
|
help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
|
||||||
|
parser.add_argument(
|
||||||
|
'--extension',
|
||||||
|
type=str,
|
||||||
|
default='auto',
|
||||||
|
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
upsampler = RealESRGANer(
|
||||||
# set up model
|
scale=args.scale, model_path=args.model_path, tile=args.tile, tile_pad=args.tile_pad, pre_pad=args.pre_pad)
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale)
|
os.makedirs('results/', exist_ok=True)
|
||||||
loadnet = torch.load(args.model_path)
|
if os.path.isfile(args.input):
|
||||||
|
paths = [args.input]
|
||||||
|
else:
|
||||||
|
paths = sorted(glob.glob(os.path.join(args.input, '*')))
|
||||||
|
|
||||||
|
for idx, path in enumerate(paths):
|
||||||
|
imgname, extension = os.path.splitext(os.path.basename(path))
|
||||||
|
print('Testing', idx, imgname)
|
||||||
|
|
||||||
|
# ------------------------------ read image ------------------------------ #
|
||||||
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
|
||||||
|
if np.max(img) > 255: # 16-bit image
|
||||||
|
max_range = 65535
|
||||||
|
print('\tInput is a 16-bit image')
|
||||||
|
else:
|
||||||
|
max_range = 255
|
||||||
|
img = img / max_range
|
||||||
|
if len(img.shape) == 2: # gray image
|
||||||
|
img_mode = 'L'
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
||||||
|
elif img.shape[2] == 4: # RGBA image with alpha channel
|
||||||
|
img_mode = 'RGBA'
|
||||||
|
alpha = img[:, :, 3]
|
||||||
|
img = img[:, :, 0:3]
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
if args.alpha_upsampler == 'realesrgan':
|
||||||
|
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
||||||
|
else:
|
||||||
|
img_mode = 'RGB'
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# ------------------- process image (without the alpha channel) ------------------- #
|
||||||
|
upsampler.pre_process(img)
|
||||||
|
if args.tile:
|
||||||
|
upsampler.tile_process()
|
||||||
|
else:
|
||||||
|
upsampler.process()
|
||||||
|
output_img = upsampler.post_process()
|
||||||
|
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
||||||
|
if img_mode == 'L':
|
||||||
|
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
# ------------------- process the alpha channel if necessary ------------------- #
|
||||||
|
if img_mode == 'RGBA':
|
||||||
|
if args.alpha_upsampler == 'realesrgan':
|
||||||
|
upsampler.pre_process(alpha)
|
||||||
|
if args.tile:
|
||||||
|
upsampler.tile_process()
|
||||||
|
else:
|
||||||
|
upsampler.process()
|
||||||
|
output_alpha = upsampler.post_process()
|
||||||
|
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
||||||
|
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
||||||
|
else:
|
||||||
|
h, w = alpha.shape[0:2]
|
||||||
|
output_alpha = cv2.resize(alpha, (w * args.scale, h * args.scale), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
# merge the alpha channel
|
||||||
|
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
||||||
|
output_img[:, :, 3] = output_alpha
|
||||||
|
|
||||||
|
# ------------------------------ save image ------------------------------ #
|
||||||
|
if args.extension == 'auto':
|
||||||
|
extension = extension[1:]
|
||||||
|
else:
|
||||||
|
extension == args.extension
|
||||||
|
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||||
|
extension = 'png'
|
||||||
|
save_path = f'results/{imgname}_{args.suffix}.{extension}'
|
||||||
|
if max_range == 65535: # 16-bit image
|
||||||
|
output = (output_img * 65535.0).round().astype(np.uint16)
|
||||||
|
else:
|
||||||
|
output = (output_img * 255.0).round().astype(np.uint8)
|
||||||
|
cv2.imwrite(save_path, output)
|
||||||
|
|
||||||
|
|
||||||
|
class RealESRGANer():
|
||||||
|
|
||||||
|
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10):
|
||||||
|
self.scale = scale
|
||||||
|
self.tile_size = tile
|
||||||
|
self.tile_pad = tile_pad
|
||||||
|
self.pre_pad = pre_pad
|
||||||
|
self.mod_scale = None
|
||||||
|
|
||||||
|
# initialize model
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
||||||
|
loadnet = torch.load(model_path)
|
||||||
if 'params_ema' in loadnet:
|
if 'params_ema' in loadnet:
|
||||||
keyname = 'params_ema'
|
keyname = 'params_ema'
|
||||||
else:
|
else:
|
||||||
keyname = 'params'
|
keyname = 'params'
|
||||||
model.load_state_dict(loadnet[keyname], strict=True)
|
model.load_state_dict(loadnet[keyname], strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
model = model.to(device)
|
self.model = model.to(self.device)
|
||||||
|
|
||||||
os.makedirs('results/', exist_ok=True)
|
def pre_process(self, img):
|
||||||
for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
|
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||||
imgname = os.path.splitext(os.path.basename(path))[0]
|
self.img = img.unsqueeze(0).to(self.device)
|
||||||
print('Testing', idx, imgname)
|
|
||||||
# read image
|
|
||||||
img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
|
|
||||||
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
|
||||||
img = img.unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
if args.scale == 2:
|
# pre_pad
|
||||||
mod_scale = 2
|
if self.pre_pad != 0:
|
||||||
elif args.scale == 1:
|
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
||||||
mod_scale = 4
|
# mod pad
|
||||||
else:
|
if self.scale == 2:
|
||||||
mod_scale = None
|
self.mod_scale = 2
|
||||||
if mod_scale is not None:
|
elif self.scale == 1:
|
||||||
h_pad, w_pad = 0, 0
|
self.mod_scale = 4
|
||||||
_, _, h, w = img.size()
|
if self.mod_scale is not None:
|
||||||
if (h % mod_scale != 0):
|
self.mod_pad_h, self.mod_pad_w = 0, 0
|
||||||
h_pad = (mod_scale - h % mod_scale)
|
_, _, h, w = self.img.size()
|
||||||
if (w % mod_scale != 0):
|
if (h % self.mod_scale != 0):
|
||||||
w_pad = (mod_scale - w % mod_scale)
|
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
|
||||||
img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
|
if (w % self.mod_scale != 0):
|
||||||
|
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
|
||||||
|
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
||||||
|
|
||||||
|
def process(self):
|
||||||
try:
|
try:
|
||||||
# inference
|
# inference
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(img)
|
self.output = self.model(self.img)
|
||||||
# remove extra pad
|
|
||||||
if mod_scale is not None:
|
|
||||||
_, _, h, w = output.size()
|
|
||||||
output = output[:, :, 0:h - h_pad, 0:w - w_pad]
|
|
||||||
# save image
|
|
||||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
|
||||||
output = (output * 255.0).round().astype(np.uint8)
|
|
||||||
cv2.imwrite(f'results/{imgname}_{args.suffix}.png', output)
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print('Error', error)
|
print('Error', error)
|
||||||
|
|
||||||
|
def tile_process(self):
|
||||||
|
"""Modified from: https://github.com/ata4/esrgan-launcher
|
||||||
|
"""
|
||||||
|
batch, channel, height, width = self.img.shape
|
||||||
|
output_height = height * self.scale
|
||||||
|
output_width = width * self.scale
|
||||||
|
output_shape = (batch, channel, output_height, output_width)
|
||||||
|
|
||||||
|
# start with black image
|
||||||
|
self.output = self.img.new_zeros(output_shape)
|
||||||
|
tiles_x = math.ceil(width / self.tile_size)
|
||||||
|
tiles_y = math.ceil(height / self.tile_size)
|
||||||
|
|
||||||
|
# loop over all tiles
|
||||||
|
for y in range(tiles_y):
|
||||||
|
for x in range(tiles_x):
|
||||||
|
# extract tile from input image
|
||||||
|
ofs_x = x * self.tile_size
|
||||||
|
ofs_y = y * self.tile_size
|
||||||
|
# input tile area on total image
|
||||||
|
input_start_x = ofs_x
|
||||||
|
input_end_x = min(ofs_x + self.tile_size, width)
|
||||||
|
input_start_y = ofs_y
|
||||||
|
input_end_y = min(ofs_y + self.tile_size, height)
|
||||||
|
|
||||||
|
# input tile area on total image with padding
|
||||||
|
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
||||||
|
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
||||||
|
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
||||||
|
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
||||||
|
|
||||||
|
# input tile dimensions
|
||||||
|
input_tile_width = input_end_x - input_start_x
|
||||||
|
input_tile_height = input_end_y - input_start_y
|
||||||
|
tile_idx = y * tiles_x + x + 1
|
||||||
|
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
|
||||||
|
|
||||||
|
# upscale tile
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
output_tile = self.model(input_tile)
|
||||||
|
except Exception as error:
|
||||||
|
print('Error', error)
|
||||||
|
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
|
||||||
|
|
||||||
|
# output tile area on total image
|
||||||
|
output_start_x = input_start_x * self.scale
|
||||||
|
output_end_x = input_end_x * self.scale
|
||||||
|
output_start_y = input_start_y * self.scale
|
||||||
|
output_end_y = input_end_y * self.scale
|
||||||
|
|
||||||
|
# output tile area without padding
|
||||||
|
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
||||||
|
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
||||||
|
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
||||||
|
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
||||||
|
|
||||||
|
# put tile into output image
|
||||||
|
self.output[:, :, output_start_y:output_end_y,
|
||||||
|
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
|
||||||
|
output_start_x_tile:output_end_x_tile]
|
||||||
|
|
||||||
|
def post_process(self):
|
||||||
|
# remove extra pad
|
||||||
|
if self.mod_scale is not None:
|
||||||
|
_, _, h, w = self.output.size()
|
||||||
|
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
|
||||||
|
# remove prepad
|
||||||
|
if self.pre_pad != 0:
|
||||||
|
_, _, h, w = self.output.size()
|
||||||
|
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
BIN
inputs/tree_alpha_16bit.png
Normal file
BIN
inputs/tree_alpha_16bit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 373 KiB |
BIN
inputs/wolf_gray.jpg
Normal file
BIN
inputs/wolf_gray.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 52 KiB |
Reference in New Issue
Block a user