add unittest for model and utils

This commit is contained in:
Xintao
2021-11-28 19:54:19 +08:00
parent 1d180efaf3
commit 42110857ef
8 changed files with 407 additions and 27 deletions

View File

@@ -4,9 +4,8 @@ import numpy as np
import os
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from torch.hub import download_url_to_file, get_dir
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F
from urllib.parse import urlparse
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -42,7 +41,7 @@ class RealESRGANer():
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'):
model_path = load_file_from_url(
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
loadnet = torch.load(model_path)
# prefer to use params_ema
if 'params_ema' in loadnet:
@@ -231,25 +230,3 @@ class RealESRGANer():
), interpolation=cv2.INTER_LANCZOS4)
return output, img_mode
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file