From c1669c4b0a195fe2a96179b0b7a638b3d1d2375c Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 31 Aug 2021 19:58:40 +0800 Subject: [PATCH] support model config during inference --- inference_realesrgan.py | 12 +++++++++++- realesrgan/utils.py | 5 +++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/inference_realesrgan.py b/inference_realesrgan.py index 784c817..57ec0ce 100644 --- a/inference_realesrgan.py +++ b/inference_realesrgan.py @@ -2,6 +2,7 @@ import argparse import cv2 import glob import os +from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer @@ -18,11 +19,12 @@ def main(): parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network') parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image') parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image') - parser.add_argument('--tile', type=int, default=800, help='Tile size, 0 for no tile during testing') + parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing') parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding') parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border') parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face') parser.add_argument('--half', action='store_true', help='Use half precision during inference') + parser.add_argument('--block', type=int, default=23, help='num_block in RRDB') parser.add_argument( '--alpha_upsampler', type=str, @@ -35,9 +37,17 @@ def main(): help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs') args = parser.parse_args() + if 'RealESRGAN_x4plus_anime_6B.pth' in args.model_path: + args.block = 6 + elif 'RealESRGAN_x2plus.pth' in args.model_path: + args.netscale = 2 + + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=args.block, num_grow_ch=32, scale=args.netscale) + upsampler = RealESRGANer( scale=args.netscale, model_path=args.model_path, + model=model, tile=args.tile, tile_pad=args.tile_pad, pre_pad=args.pre_pad, diff --git a/realesrgan/utils.py b/realesrgan/utils.py index 15f1957..a815cb3 100644 --- a/realesrgan/utils.py +++ b/realesrgan/utils.py @@ -13,7 +13,7 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) class RealESRGANer(): - def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False): + def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False): self.scale = scale self.tile_size = tile self.tile_pad = tile_pad @@ -23,7 +23,8 @@ class RealESRGANer(): # initialize model self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) + if model is None: + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) if model_path.startswith('https://'): model_path = load_file_from_url(