From f932289af1550ded97a41632ee39944a7a6b517c Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 1 Aug 2021 12:10:35 +0800 Subject: [PATCH] support half inference --- inference_realesrgan.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/inference_realesrgan.py b/inference_realesrgan.py index 7de2414..0418253 100644 --- a/inference_realesrgan.py +++ b/inference_realesrgan.py @@ -17,11 +17,13 @@ def main(): type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Path to the pre-trained model') + parser.add_argument('--output', type=str, default='results', help='Output folder') parser.add_argument('--scale', type=int, default=4, help='Upsample scale factor') parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image') parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing') parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding') parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border') + parser.add_argument('--half', action='store_true', help='Use half precision during inference') parser.add_argument( '--alpha_upsampler', type=str, @@ -35,8 +37,13 @@ def main(): args = parser.parse_args() upsampler = RealESRGANer( - scale=args.scale, model_path=args.model_path, tile=args.tile, tile_pad=args.tile_pad, pre_pad=args.pre_pad) - os.makedirs('results/', exist_ok=True) + scale=args.scale, + model_path=args.model_path, + tile=args.tile, + tile_pad=args.tile_pad, + pre_pad=args.pre_pad, + half=args.half) + os.makedirs(args.output, exist_ok=True) if os.path.isfile(args.input): paths = [args.input] else: @@ -107,7 +114,7 @@ def main(): extension = args.ext if img_mode == 'RGBA': # RGBA images should be saved in png format extension = 'png' - save_path = f'results/{imgname}_{args.suffix}.{extension}' + save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}') if max_range == 65535: # 16-bit image output = (output_img * 65535.0).round().astype(np.uint16) else: @@ -117,12 +124,13 @@ def main(): class RealESRGANer(): - def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10): + def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False): self.scale = scale self.tile_size = tile self.tile_pad = tile_pad self.pre_pad = pre_pad self.mod_scale = None + self.half = half # initialize model self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -135,10 +143,14 @@ class RealESRGANer(): model.load_state_dict(loadnet[keyname], strict=True) model.eval() self.model = model.to(self.device) + if self.half: + self.model = self.model.half() def pre_process(self, img): img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() # pre_pad if self.pre_pad != 0: