modify weight path
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -5,7 +5,7 @@ results/*
|
|||||||
tb_logger/*
|
tb_logger/*
|
||||||
wandb/*
|
wandb/*
|
||||||
tmp/*
|
tmp/*
|
||||||
realesrgan/weights/*
|
weights/*
|
||||||
|
|
||||||
version.py
|
version.py
|
||||||
|
|
||||||
|
|||||||
@@ -5,4 +5,4 @@ include inference_realesrgan.py
|
|||||||
include VERSION
|
include VERSION
|
||||||
include LICENSE
|
include LICENSE
|
||||||
include requirements.txt
|
include requirements.txt
|
||||||
include realesrgan/weights/README.md
|
include weights/README.md
|
||||||
|
|||||||
@@ -29,52 +29,50 @@ class Predictor(BasePredictor):
|
|||||||
def setup(self):
|
def setup(self):
|
||||||
os.makedirs('output', exist_ok=True)
|
os.makedirs('output', exist_ok=True)
|
||||||
# download weights
|
# download weights
|
||||||
if not os.path.exists('realesrgan/weights/realesr-general-x4v3.pth'):
|
if not os.path.exists('weights/realesr-general-x4v3.pth'):
|
||||||
os.system(
|
os.system(
|
||||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./realesrgan/weights'
|
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights'
|
||||||
)
|
)
|
||||||
if not os.path.exists('realesrgan/weights/GFPGANv1.4.pth'):
|
if not os.path.exists('weights/GFPGANv1.4.pth'):
|
||||||
|
os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./weights')
|
||||||
|
if not os.path.exists('weights/RealESRGAN_x4plus.pth'):
|
||||||
os.system(
|
os.system(
|
||||||
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./realesrgan/weights'
|
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights'
|
||||||
)
|
)
|
||||||
if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus.pth'):
|
if not os.path.exists('weights/RealESRGAN_x4plus_anime_6B.pth'):
|
||||||
os.system(
|
os.system(
|
||||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./realesrgan/weights'
|
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./weights'
|
||||||
)
|
)
|
||||||
if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'):
|
if not os.path.exists('weights/realesr-animevideov3.pth'):
|
||||||
os.system(
|
os.system(
|
||||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./realesrgan/weights'
|
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./weights'
|
||||||
)
|
|
||||||
if not os.path.exists('realesrgan/weights/realesr-animevideov3.pth'):
|
|
||||||
os.system(
|
|
||||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./realesrgan/weights'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def choose_model(self, scale, version, tile=0):
|
def choose_model(self, scale, version, tile=0):
|
||||||
half = True if torch.cuda.is_available() else False
|
half = True if torch.cuda.is_available() else False
|
||||||
if version == 'General - RealESRGANplus':
|
if version == 'General - RealESRGANplus':
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||||
model_path = 'realesrgan/weights/RealESRGAN_x4plus.pth'
|
model_path = 'weights/RealESRGAN_x4plus.pth'
|
||||||
self.upsampler = RealESRGANer(
|
self.upsampler = RealESRGANer(
|
||||||
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
||||||
elif version == 'General - v3':
|
elif version == 'General - v3':
|
||||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||||
model_path = 'realesrgan/weights/realesr-general-x4v3.pth'
|
model_path = 'weights/realesr-general-x4v3.pth'
|
||||||
self.upsampler = RealESRGANer(
|
self.upsampler = RealESRGANer(
|
||||||
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
||||||
elif version == 'Anime - anime6B':
|
elif version == 'Anime - anime6B':
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||||
model_path = 'realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'
|
model_path = 'weights/RealESRGAN_x4plus_anime_6B.pth'
|
||||||
self.upsampler = RealESRGANer(
|
self.upsampler = RealESRGANer(
|
||||||
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
||||||
elif version == 'AnimeVideo - v3':
|
elif version == 'AnimeVideo - v3':
|
||||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||||
model_path = 'realesrgan/weights/realesr-animevideov3.pth'
|
model_path = 'weights/realesr-animevideov3.pth'
|
||||||
self.upsampler = RealESRGANer(
|
self.upsampler = RealESRGANer(
|
||||||
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
|
||||||
|
|
||||||
self.face_enhancer = GFPGANer(
|
self.face_enhancer = GFPGANer(
|
||||||
model_path='realesrgan/weights/GFPGANv1.4.pth',
|
model_path='weights/GFPGANv1.4.pth',
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
arch='clean',
|
arch='clean',
|
||||||
channel_multiplier=2,
|
channel_multiplier=2,
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ The following are some demos (best view in the full screen mode).
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# download model
|
# download model
|
||||||
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P realesrgan/weights
|
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P weights
|
||||||
# single gpu and single process inference
|
# single gpu and single process inference
|
||||||
CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2
|
CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2
|
||||||
# single gpu and multi process inference (you can use multi-processing to improve GPU utilization)
|
# single gpu and multi process inference (you can use multi-processing to improve GPU utilization)
|
||||||
|
|||||||
@@ -88,13 +88,13 @@ def main():
|
|||||||
if args.model_path is not None:
|
if args.model_path is not None:
|
||||||
model_path = args.model_path
|
model_path = args.model_path
|
||||||
else:
|
else:
|
||||||
model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
|
model_path = os.path.join('weights', args.model_name + '.pth')
|
||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
for url in file_url:
|
for url in file_url:
|
||||||
# model_path will be updated
|
# model_path will be updated
|
||||||
model_path = load_file_from_url(
|
model_path = load_file_from_url(
|
||||||
url=url, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
|
url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
|
||||||
|
|
||||||
# use dni to control the denoise strength
|
# use dni to control the denoise strength
|
||||||
dni_weight = None
|
dni_weight = None
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ def inference_video(args, video_save_path, device=None, total_workers=1, worker_
|
|||||||
# ---------------------- determine model paths ---------------------- #
|
# ---------------------- determine model paths ---------------------- #
|
||||||
model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
|
model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
|
||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
|
model_path = os.path.join('weights', args.model_name + '.pth')
|
||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
raise ValueError(f'Model {args.model_name} does not exist.')
|
raise ValueError(f'Model {args.model_name} does not exist.')
|
||||||
|
|
||||||
|
|||||||
@@ -56,13 +56,10 @@ class RealESRGANer():
|
|||||||
assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
|
assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
|
||||||
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
|
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
|
||||||
else:
|
else:
|
||||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
# if the model_path starts with https, it will first download models to the folder: weights
|
||||||
if model_path.startswith('https://'):
|
if model_path.startswith('https://'):
|
||||||
model_path = load_file_from_url(
|
model_path = load_file_from_url(
|
||||||
url=model_path,
|
url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
|
||||||
model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'),
|
|
||||||
progress=True,
|
|
||||||
file_name=None)
|
|
||||||
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
|
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
|
||||||
|
|
||||||
# prefer to use params_ema
|
# prefer to use params_ema
|
||||||
|
|||||||
Reference in New Issue
Block a user