modify weight path

This commit is contained in:
Xintao
2022-09-19 01:43:22 +08:00
parent 89aa45c72d
commit 0ac8d66d39
8 changed files with 23 additions and 28 deletions

View File

@@ -29,52 +29,50 @@ class Predictor(BasePredictor):
def setup(self):
os.makedirs('output', exist_ok=True)
# download weights
if not os.path.exists('realesrgan/weights/realesr-general-x4v3.pth'):
if not os.path.exists('weights/realesr-general-x4v3.pth'):
os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./realesrgan/weights'
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights'
)
if not os.path.exists('realesrgan/weights/GFPGANv1.4.pth'):
if not os.path.exists('weights/GFPGANv1.4.pth'):
os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./weights')
if not os.path.exists('weights/RealESRGAN_x4plus.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./realesrgan/weights'
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights'
)
if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus.pth'):
if not os.path.exists('weights/RealESRGAN_x4plus_anime_6B.pth'):
os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./realesrgan/weights'
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./weights'
)
if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'):
if not os.path.exists('weights/realesr-animevideov3.pth'):
os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./realesrgan/weights'
)
if not os.path.exists('realesrgan/weights/realesr-animevideov3.pth'):
os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./realesrgan/weights'
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./weights'
)
def choose_model(self, scale, version, tile=0):
half = True if torch.cuda.is_available() else False
if version == 'General - RealESRGANplus':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
model_path = 'realesrgan/weights/RealESRGAN_x4plus.pth'
model_path = 'weights/RealESRGAN_x4plus.pth'
self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
elif version == 'General - v3':
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'realesrgan/weights/realesr-general-x4v3.pth'
model_path = 'weights/realesr-general-x4v3.pth'
self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
elif version == 'Anime - anime6B':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
model_path = 'realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'
model_path = 'weights/RealESRGAN_x4plus_anime_6B.pth'
self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
elif version == 'AnimeVideo - v3':
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
model_path = 'realesrgan/weights/realesr-animevideov3.pth'
model_path = 'weights/realesr-animevideov3.pth'
self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
self.face_enhancer = GFPGANer(
model_path='realesrgan/weights/GFPGANv1.4.pth',
model_path='weights/GFPGANv1.4.pth',
upscale=scale,
arch='clean',
channel_multiplier=2,