support half inference
This commit is contained in:
@@ -17,11 +17,13 @@ def main():
|
||||
type=str,
|
||||
default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
|
||||
help='Path to the pre-trained model')
|
||||
parser.add_argument('--output', type=str, default='results', help='Output folder')
|
||||
parser.add_argument('--scale', type=int, default=4, help='Upsample scale factor')
|
||||
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
|
||||
parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
|
||||
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
|
||||
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
|
||||
parser.add_argument('--half', action='store_true', help='Use half precision during inference')
|
||||
parser.add_argument(
|
||||
'--alpha_upsampler',
|
||||
type=str,
|
||||
@@ -35,8 +37,13 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=args.scale, model_path=args.model_path, tile=args.tile, tile_pad=args.tile_pad, pre_pad=args.pre_pad)
|
||||
os.makedirs('results/', exist_ok=True)
|
||||
scale=args.scale,
|
||||
model_path=args.model_path,
|
||||
tile=args.tile,
|
||||
tile_pad=args.tile_pad,
|
||||
pre_pad=args.pre_pad,
|
||||
half=args.half)
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
if os.path.isfile(args.input):
|
||||
paths = [args.input]
|
||||
else:
|
||||
@@ -107,7 +114,7 @@ def main():
|
||||
extension = args.ext
|
||||
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||
extension = 'png'
|
||||
save_path = f'results/{imgname}_{args.suffix}.{extension}'
|
||||
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
|
||||
if max_range == 65535: # 16-bit image
|
||||
output = (output_img * 65535.0).round().astype(np.uint16)
|
||||
else:
|
||||
@@ -117,12 +124,13 @@ def main():
|
||||
|
||||
class RealESRGANer():
|
||||
|
||||
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10):
|
||||
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
|
||||
self.scale = scale
|
||||
self.tile_size = tile
|
||||
self.tile_pad = tile_pad
|
||||
self.pre_pad = pre_pad
|
||||
self.mod_scale = None
|
||||
self.half = half
|
||||
|
||||
# initialize model
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
@@ -135,10 +143,14 @@ class RealESRGANer():
|
||||
model.load_state_dict(loadnet[keyname], strict=True)
|
||||
model.eval()
|
||||
self.model = model.to(self.device)
|
||||
if self.half:
|
||||
self.model = self.model.half()
|
||||
|
||||
def pre_process(self, img):
|
||||
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||
self.img = img.unsqueeze(0).to(self.device)
|
||||
if self.half:
|
||||
self.img = self.img.half()
|
||||
|
||||
# pre_pad
|
||||
if self.pre_pad != 0:
|
||||
|
||||
Reference in New Issue
Block a user