modify weight path

This commit is contained in:
Xintao
2022-09-19 01:43:22 +08:00
parent 89aa45c72d
commit 0ac8d66d39
8 changed files with 23 additions and 28 deletions

2
.gitignore vendored
View File

@@ -5,7 +5,7 @@ results/*
tb_logger/* tb_logger/*
wandb/* wandb/*
tmp/* tmp/*
realesrgan/weights/* weights/*
version.py version.py

View File

@@ -5,4 +5,4 @@ include inference_realesrgan.py
include VERSION include VERSION
include LICENSE include LICENSE
include requirements.txt include requirements.txt
include realesrgan/weights/README.md include weights/README.md

View File

@@ -29,52 +29,50 @@ class Predictor(BasePredictor):
def setup(self): def setup(self):
os.makedirs('output', exist_ok=True) os.makedirs('output', exist_ok=True)
# download weights # download weights
if not os.path.exists('realesrgan/weights/realesr-general-x4v3.pth'): if not os.path.exists('weights/realesr-general-x4v3.pth'):
os.system( os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./realesrgan/weights' 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights'
) )
if not os.path.exists('realesrgan/weights/GFPGANv1.4.pth'): if not os.path.exists('weights/GFPGANv1.4.pth'):
os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./weights')
if not os.path.exists('weights/RealESRGAN_x4plus.pth'):
os.system( os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./realesrgan/weights' 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights'
) )
if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus.pth'): if not os.path.exists('weights/RealESRGAN_x4plus_anime_6B.pth'):
os.system( os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./realesrgan/weights' 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./weights'
) )
if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'): if not os.path.exists('weights/realesr-animevideov3.pth'):
os.system( os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./realesrgan/weights' 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./weights'
)
if not os.path.exists('realesrgan/weights/realesr-animevideov3.pth'):
os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./realesrgan/weights'
) )
def choose_model(self, scale, version, tile=0): def choose_model(self, scale, version, tile=0):
half = True if torch.cuda.is_available() else False half = True if torch.cuda.is_available() else False
if version == 'General - RealESRGANplus': if version == 'General - RealESRGANplus':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
model_path = 'realesrgan/weights/RealESRGAN_x4plus.pth' model_path = 'weights/RealESRGAN_x4plus.pth'
self.upsampler = RealESRGANer( self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half) scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
elif version == 'General - v3': elif version == 'General - v3':
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'realesrgan/weights/realesr-general-x4v3.pth' model_path = 'weights/realesr-general-x4v3.pth'
self.upsampler = RealESRGANer( self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half) scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
elif version == 'Anime - anime6B': elif version == 'Anime - anime6B':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
model_path = 'realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth' model_path = 'weights/RealESRGAN_x4plus_anime_6B.pth'
self.upsampler = RealESRGANer( self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half) scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
elif version == 'AnimeVideo - v3': elif version == 'AnimeVideo - v3':
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
model_path = 'realesrgan/weights/realesr-animevideov3.pth' model_path = 'weights/realesr-animevideov3.pth'
self.upsampler = RealESRGANer( self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half) scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
self.face_enhancer = GFPGANer( self.face_enhancer = GFPGANer(
model_path='realesrgan/weights/GFPGANv1.4.pth', model_path='weights/GFPGANv1.4.pth',
upscale=scale, upscale=scale,
arch='clean', arch='clean',
channel_multiplier=2, channel_multiplier=2,

View File

@@ -34,7 +34,7 @@ The following are some demos (best view in the full screen mode).
```bash ```bash
# download model # download model
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P realesrgan/weights wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P weights
# single gpu and single process inference # single gpu and single process inference
CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2
# single gpu and multi process inference (you can use multi-processing to improve GPU utilization) # single gpu and multi process inference (you can use multi-processing to improve GPU utilization)

View File

@@ -88,13 +88,13 @@ def main():
if args.model_path is not None: if args.model_path is not None:
model_path = args.model_path model_path = args.model_path
else: else:
model_path = os.path.join('realesrgan/weights', args.model_name + '.pth') model_path = os.path.join('weights', args.model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
for url in file_url: for url in file_url:
# model_path will be updated # model_path will be updated
model_path = load_file_from_url( model_path = load_file_from_url(
url=url, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None) url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
# use dni to control the denoise strength # use dni to control the denoise strength
dni_weight = None dni_weight = None

View File

@@ -190,7 +190,7 @@ def inference_video(args, video_save_path, device=None, total_workers=1, worker_
# ---------------------- determine model paths ---------------------- # # ---------------------- determine model paths ---------------------- #
model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth') model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
model_path = os.path.join('realesrgan/weights', args.model_name + '.pth') model_path = os.path.join('weights', args.model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise ValueError(f'Model {args.model_name} does not exist.') raise ValueError(f'Model {args.model_name} does not exist.')

View File

@@ -56,13 +56,10 @@ class RealESRGANer():
assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.' assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
loadnet = self.dni(model_path[0], model_path[1], dni_weight) loadnet = self.dni(model_path[0], model_path[1], dni_weight)
else: else:
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights # if the model_path starts with https, it will first download models to the folder: weights
if model_path.startswith('https://'): if model_path.startswith('https://'):
model_path = load_file_from_url( model_path = load_file_from_url(
url=model_path, url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'),
progress=True,
file_name=None)
loadnet = torch.load(model_path, map_location=torch.device('cpu')) loadnet = torch.load(model_path, map_location=torch.device('cpu'))
# prefer to use params_ema # prefer to use params_ema