modify weight path
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -5,7 +5,7 @@ results/*
|
||||
tb_logger/*
|
||||
wandb/*
|
||||
tmp/*
|
||||
realesrgan/weights/*
|
||||
weights/*
|
||||
|
||||
version.py
|
||||
|
||||
|
||||
@@ -5,4 +5,4 @@ include inference_realesrgan.py
|
||||
include VERSION
|
||||
include LICENSE
|
||||
include requirements.txt
|
||||
include realesrgan/weights/README.md
|
||||
include weights/README.md
|
||||
|
||||
@@ -29,52 +29,50 @@ class Predictor(BasePredictor):
|
||||
def setup(self):
|
||||
os.makedirs('output', exist_ok=True)
|
||||
# download weights
|
||||
if not os.path.exists('realesrgan/weights/realesr-general-x4v3.pth'):
|
||||
if not os.path.exists('weights/realesr-general-x4v3.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./realesrgan/weights'
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights'
|
||||
)
|
||||
if not os.path.exists('realesrgan/weights/GFPGANv1.4.pth'):
|
||||
if not os.path.exists('weights/GFPGANv1.4.pth'):
|
||||
os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./weights')
|
||||
if not os.path.exists('weights/RealESRGAN_x4plus.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./realesrgan/weights'
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights'
|
||||
)
|
||||
if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus.pth'):
|
||||
if not os.path.exists('weights/RealESRGAN_x4plus_anime_6B.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./realesrgan/weights'
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./weights'
|
||||
)
|
||||
if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'):
|
||||
if not os.path.exists('weights/realesr-animevideov3.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./realesrgan/weights'
|
||||
)
|
||||
if not os.path.exists('realesrgan/weights/realesr-animevideov3.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./realesrgan/weights'
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./weights'
|
||||
)
|
||||
|
||||
def choose_model(self, scale, version, tile=0):
|
||||
half = True if torch.cuda.is_available() else False
|
||||
if version == 'General - RealESRGANplus':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
model_path = 'realesrgan/weights/RealESRGAN_x4plus.pth'
|
||||
model_path = 'weights/RealESRGAN_x4plus.pth'
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
||||
elif version == 'General - v3':
|
||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
model_path = 'realesrgan/weights/realesr-general-x4v3.pth'
|
||||
model_path = 'weights/realesr-general-x4v3.pth'
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
||||
elif version == 'Anime - anime6B':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
model_path = 'realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'
|
||||
model_path = 'weights/RealESRGAN_x4plus_anime_6B.pth'
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
||||
elif version == 'AnimeVideo - v3':
|
||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||
model_path = 'realesrgan/weights/realesr-animevideov3.pth'
|
||||
model_path = 'weights/realesr-animevideov3.pth'
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
||||
|
||||
self.face_enhancer = GFPGANer(
|
||||
model_path='realesrgan/weights/GFPGANv1.4.pth',
|
||||
model_path='weights/GFPGANv1.4.pth',
|
||||
upscale=scale,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
|
||||
@@ -34,7 +34,7 @@ The following are some demos (best view in the full screen mode).
|
||||
|
||||
```bash
|
||||
# download model
|
||||
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P realesrgan/weights
|
||||
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P weights
|
||||
# single gpu and single process inference
|
||||
CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2
|
||||
# single gpu and multi process inference (you can use multi-processing to improve GPU utilization)
|
||||
|
||||
@@ -88,13 +88,13 @@ def main():
|
||||
if args.model_path is not None:
|
||||
model_path = args.model_path
|
||||
else:
|
||||
model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
|
||||
model_path = os.path.join('weights', args.model_name + '.pth')
|
||||
if not os.path.isfile(model_path):
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
for url in file_url:
|
||||
# model_path will be updated
|
||||
model_path = load_file_from_url(
|
||||
url=url, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
|
||||
url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
|
||||
|
||||
# use dni to control the denoise strength
|
||||
dni_weight = None
|
||||
|
||||
@@ -190,7 +190,7 @@ def inference_video(args, video_save_path, device=None, total_workers=1, worker_
|
||||
# ---------------------- determine model paths ---------------------- #
|
||||
model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
|
||||
if not os.path.isfile(model_path):
|
||||
model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
|
||||
model_path = os.path.join('weights', args.model_name + '.pth')
|
||||
if not os.path.isfile(model_path):
|
||||
raise ValueError(f'Model {args.model_name} does not exist.')
|
||||
|
||||
|
||||
@@ -56,13 +56,10 @@ class RealESRGANer():
|
||||
assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
|
||||
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
|
||||
else:
|
||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
||||
# if the model_path starts with https, it will first download models to the folder: weights
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path,
|
||||
model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'),
|
||||
progress=True,
|
||||
file_name=None)
|
||||
url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
|
||||
|
||||
# prefer to use params_ema
|
||||
|
||||
Reference in New Issue
Block a user