support half inference

This commit is contained in:
Xintao
2021-08-01 12:10:35 +08:00
parent f59a0c66ec
commit f932289af1

View File

@@ -17,11 +17,13 @@ def main():
type=str,
default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
help='Path to the pre-trained model')
parser.add_argument('--output', type=str, default='results', help='Output folder')
parser.add_argument('--scale', type=int, default=4, help='Upsample scale factor')
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
parser.add_argument('--half', action='store_true', help='Use half precision during inference')
parser.add_argument(
'--alpha_upsampler',
type=str,
@@ -35,8 +37,13 @@ def main():
args = parser.parse_args()
upsampler = RealESRGANer(
scale=args.scale, model_path=args.model_path, tile=args.tile, tile_pad=args.tile_pad, pre_pad=args.pre_pad)
os.makedirs('results/', exist_ok=True)
scale=args.scale,
model_path=args.model_path,
tile=args.tile,
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=args.half)
os.makedirs(args.output, exist_ok=True)
if os.path.isfile(args.input):
paths = [args.input]
else:
@@ -107,7 +114,7 @@ def main():
extension = args.ext
if img_mode == 'RGBA': # RGBA images should be saved in png format
extension = 'png'
save_path = f'results/{imgname}_{args.suffix}.{extension}'
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
if max_range == 65535: # 16-bit image
output = (output_img * 65535.0).round().astype(np.uint16)
else:
@@ -117,12 +124,13 @@ def main():
class RealESRGANer():
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10):
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -135,10 +143,14 @@ class RealESRGANer():
model.load_state_dict(loadnet[keyname], strict=True)
model.eval()
self.model = model.to(self.device)
if self.half:
self.model = self.model.half()
def pre_process(self, img):
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
if self.half:
self.img = self.img.half()
# pre_pad
if self.pre_pad != 0: