support model config during inference
This commit is contained in:
@@ -2,6 +2,7 @@ import argparse
|
|||||||
import cv2
|
import cv2
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
@@ -18,11 +19,12 @@ def main():
|
|||||||
parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
|
parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
|
||||||
parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
|
parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
|
||||||
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
|
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
|
||||||
parser.add_argument('--tile', type=int, default=800, help='Tile size, 0 for no tile during testing')
|
parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
|
||||||
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
|
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
|
||||||
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
|
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
|
||||||
parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
|
parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
|
||||||
parser.add_argument('--half', action='store_true', help='Use half precision during inference')
|
parser.add_argument('--half', action='store_true', help='Use half precision during inference')
|
||||||
|
parser.add_argument('--block', type=int, default=23, help='num_block in RRDB')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--alpha_upsampler',
|
'--alpha_upsampler',
|
||||||
type=str,
|
type=str,
|
||||||
@@ -35,9 +37,17 @@ def main():
|
|||||||
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if 'RealESRGAN_x4plus_anime_6B.pth' in args.model_path:
|
||||||
|
args.block = 6
|
||||||
|
elif 'RealESRGAN_x2plus.pth' in args.model_path:
|
||||||
|
args.netscale = 2
|
||||||
|
|
||||||
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=args.block, num_grow_ch=32, scale=args.netscale)
|
||||||
|
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
scale=args.netscale,
|
scale=args.netscale,
|
||||||
model_path=args.model_path,
|
model_path=args.model_path,
|
||||||
|
model=model,
|
||||||
tile=args.tile,
|
tile=args.tile,
|
||||||
tile_pad=args.tile_pad,
|
tile_pad=args.tile_pad,
|
||||||
pre_pad=args.pre_pad,
|
pre_pad=args.pre_pad,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||||||
|
|
||||||
class RealESRGANer():
|
class RealESRGANer():
|
||||||
|
|
||||||
def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
|
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.tile_size = tile
|
self.tile_size = tile
|
||||||
self.tile_pad = tile_pad
|
self.tile_pad = tile_pad
|
||||||
@@ -23,7 +23,8 @@ class RealESRGANer():
|
|||||||
|
|
||||||
# initialize model
|
# initialize model
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
if model is None:
|
||||||
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
|
||||||
|
|
||||||
if model_path.startswith('https://'):
|
if model_path.startswith('https://'):
|
||||||
model_path = load_file_from_url(
|
model_path = load_file_from_url(
|
||||||
|
|||||||
Reference in New Issue
Block a user