From 8454fd2c7a2f63eb6ab35bf6305668ef053ab2ff Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 25 Jul 2021 16:16:57 +0800 Subject: [PATCH] add pytorch2onnx --- scripts/pytorch2onnx.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 scripts/pytorch2onnx.py diff --git a/scripts/pytorch2onnx.py b/scripts/pytorch2onnx.py new file mode 100644 index 0000000..ac104b6 --- /dev/null +++ b/scripts/pytorch2onnx.py @@ -0,0 +1,17 @@ +import torch +import torch.onnx +from basicsr.archs.rrdbnet_arch import RRDBNet + +# An instance of your model +model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32) +model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema']) +# set the train mode to false since we will only run the forward pass. +model.train(False) +model.cpu().eval() + +# An example input you would normally provide to your model's forward() method +x = torch.rand(1, 3, 64, 64) + +# Export the model +with torch.no_grad(): + torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)