From 0ac8d66d39223f5ff568dc8747aaf8285e3133dc Mon Sep 17 00:00:00 2001 From: Xintao Date: Mon, 19 Sep 2022 01:43:22 +0800 Subject: [PATCH] modify weight path --- .gitignore | 2 +- MANIFEST.in | 2 +- cog_predict.py | 32 +++++++++++------------ docs/anime_video_model.md | 2 +- inference_realesrgan.py | 4 +-- inference_realesrgan_video.py | 2 +- realesrgan/utils.py | 7 ++--- {realesrgan/weights => weights}/README.md | 0 8 files changed, 23 insertions(+), 28 deletions(-) rename {realesrgan/weights => weights}/README.md (100%) diff --git a/.gitignore b/.gitignore index d5b51f8..bb86ed0 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ results/* tb_logger/* wandb/* tmp/* -realesrgan/weights/* +weights/* version.py diff --git a/MANIFEST.in b/MANIFEST.in index b18403e..b87c827 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,4 +5,4 @@ include inference_realesrgan.py include VERSION include LICENSE include requirements.txt -include realesrgan/weights/README.md +include weights/README.md diff --git a/cog_predict.py b/cog_predict.py index 2be4da9..fa0f89d 100644 --- a/cog_predict.py +++ b/cog_predict.py @@ -29,52 +29,50 @@ class Predictor(BasePredictor): def setup(self): os.makedirs('output', exist_ok=True) # download weights - if not os.path.exists('realesrgan/weights/realesr-general-x4v3.pth'): + if not os.path.exists('weights/realesr-general-x4v3.pth'): os.system( - 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./realesrgan/weights' + 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights' ) - if not os.path.exists('realesrgan/weights/GFPGANv1.4.pth'): + if not os.path.exists('weights/GFPGANv1.4.pth'): + os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./weights') + if not os.path.exists('weights/RealESRGAN_x4plus.pth'): os.system( - 'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./realesrgan/weights' + 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights' ) - if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus.pth'): + if not os.path.exists('weights/RealESRGAN_x4plus_anime_6B.pth'): os.system( - 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./realesrgan/weights' + 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./weights' ) - if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'): + if not os.path.exists('weights/realesr-animevideov3.pth'): os.system( - 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./realesrgan/weights' - ) - if not os.path.exists('realesrgan/weights/realesr-animevideov3.pth'): - os.system( - 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./realesrgan/weights' + 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./weights' ) def choose_model(self, scale, version, tile=0): half = True if torch.cuda.is_available() else False if version == 'General - RealESRGANplus': model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) - model_path = 'realesrgan/weights/RealESRGAN_x4plus.pth' + model_path = 'weights/RealESRGAN_x4plus.pth' self.upsampler = RealESRGANer( scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half) elif version == 'General - v3': model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - model_path = 'realesrgan/weights/realesr-general-x4v3.pth' + model_path = 'weights/realesr-general-x4v3.pth' self.upsampler = RealESRGANer( scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half) elif version == 'Anime - anime6B': model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - model_path = 'realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth' + model_path = 'weights/RealESRGAN_x4plus_anime_6B.pth' self.upsampler = RealESRGANer( scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half) elif version == 'AnimeVideo - v3': model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') - model_path = 'realesrgan/weights/realesr-animevideov3.pth' + model_path = 'weights/realesr-animevideov3.pth' self.upsampler = RealESRGANer( scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half) self.face_enhancer = GFPGANer( - model_path='realesrgan/weights/GFPGANv1.4.pth', + model_path='weights/GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2, diff --git a/docs/anime_video_model.md b/docs/anime_video_model.md index 79bb04b..0ad5c85 100644 --- a/docs/anime_video_model.md +++ b/docs/anime_video_model.md @@ -34,7 +34,7 @@ The following are some demos (best view in the full screen mode). ```bash # download model -wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P realesrgan/weights +wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P weights # single gpu and single process inference CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 # single gpu and multi process inference (you can use multi-processing to improve GPU utilization) diff --git a/inference_realesrgan.py b/inference_realesrgan.py index 8595009..0a8cc43 100644 --- a/inference_realesrgan.py +++ b/inference_realesrgan.py @@ -88,13 +88,13 @@ def main(): if args.model_path is not None: model_path = args.model_path else: - model_path = os.path.join('realesrgan/weights', args.model_name + '.pth') + model_path = os.path.join('weights', args.model_name + '.pth') if not os.path.isfile(model_path): ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) for url in file_url: # model_path will be updated model_path = load_file_from_url( - url=url, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None) + url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None) # use dni to control the denoise strength dni_weight = None diff --git a/inference_realesrgan_video.py b/inference_realesrgan_video.py index 3b38369..692b876 100644 --- a/inference_realesrgan_video.py +++ b/inference_realesrgan_video.py @@ -190,7 +190,7 @@ def inference_video(args, video_save_path, device=None, total_workers=1, worker_ # ---------------------- determine model paths ---------------------- # model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth') if not os.path.isfile(model_path): - model_path = os.path.join('realesrgan/weights', args.model_name + '.pth') + model_path = os.path.join('weights', args.model_name + '.pth') if not os.path.isfile(model_path): raise ValueError(f'Model {args.model_name} does not exist.') diff --git a/realesrgan/utils.py b/realesrgan/utils.py index e409360..67e5232 100644 --- a/realesrgan/utils.py +++ b/realesrgan/utils.py @@ -56,13 +56,10 @@ class RealESRGANer(): assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.' loadnet = self.dni(model_path[0], model_path[1], dni_weight) else: - # if the model_path starts with https, it will first download models to the folder: realesrgan/weights + # if the model_path starts with https, it will first download models to the folder: weights if model_path.startswith('https://'): model_path = load_file_from_url( - url=model_path, - model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), - progress=True, - file_name=None) + url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None) loadnet = torch.load(model_path, map_location=torch.device('cpu')) # prefer to use params_ema diff --git a/realesrgan/weights/README.md b/weights/README.md similarity index 100% rename from realesrgan/weights/README.md rename to weights/README.md