add unittest for model and utils
This commit is contained in:
@@ -4,9 +4,8 @@ import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.nn import functional as F
|
||||
from urllib.parse import urlparse
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
@@ -42,7 +41,7 @@ class RealESRGANer():
|
||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
|
||||
url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
# prefer to use params_ema
|
||||
if 'params_ema' in loadnet:
|
||||
@@ -231,25 +230,3 @@ class RealESRGANer():
|
||||
), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
return output, img_mode
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
||||
"""Load file form http url, will download models if necessary.
|
||||
|
||||
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""
|
||||
if model_dir is None:
|
||||
hub_dir = get_dir()
|
||||
model_dir = os.path.join(hub_dir, 'checkpoints')
|
||||
|
||||
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
|
||||
|
||||
parts = urlparse(url)
|
||||
filename = os.path.basename(parts.path)
|
||||
if file_name is not None:
|
||||
filename = file_name
|
||||
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
|
||||
if not os.path.exists(cached_file):
|
||||
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
||||
return cached_file
|
||||
|
||||
115
tests/data/test_realesrgan_model.yml
Normal file
115
tests/data/test_realesrgan_model.yml
Normal file
@@ -0,0 +1,115 @@
|
||||
scale: 4
|
||||
num_gpu: 1
|
||||
manual_seed: 0
|
||||
is_train: True
|
||||
dist: False
|
||||
|
||||
# ----------------- options for synthesizing training data ----------------- #
|
||||
# USM the ground-truth
|
||||
l1_gt_usm: True
|
||||
percep_gt_usm: True
|
||||
gan_gt_usm: False
|
||||
|
||||
# the first degradation process
|
||||
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||
resize_range: [0.15, 1.5]
|
||||
gaussian_noise_prob: 1
|
||||
noise_range: [1, 30]
|
||||
poisson_scale_range: [0.05, 3]
|
||||
gray_noise_prob: 1
|
||||
jpeg_range: [30, 95]
|
||||
|
||||
# the second degradation process
|
||||
second_blur_prob: 1
|
||||
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||
resize_range2: [0.3, 1.2]
|
||||
gaussian_noise_prob2: 1
|
||||
noise_range2: [1, 25]
|
||||
poisson_scale_range2: [0.05, 2.5]
|
||||
gray_noise_prob2: 1
|
||||
jpeg_range2: [30, 95]
|
||||
|
||||
gt_size: 32
|
||||
queue_size: 1
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: RRDBNet
|
||||
num_in_ch: 3
|
||||
num_out_ch: 3
|
||||
num_feat: 4
|
||||
num_block: 1
|
||||
num_grow_ch: 2
|
||||
|
||||
network_d:
|
||||
type: UNetDiscriminatorSN
|
||||
num_in_ch: 3
|
||||
num_feat: 2
|
||||
skip_connection: True
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: ~
|
||||
param_key_g: params_ema
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# training settings
|
||||
train:
|
||||
ema_decay: 0.999
|
||||
optim_g:
|
||||
type: Adam
|
||||
lr: !!float 1e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
optim_d:
|
||||
type: Adam
|
||||
lr: !!float 1e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
scheduler:
|
||||
type: MultiStepLR
|
||||
milestones: [400000]
|
||||
gamma: 0.5
|
||||
|
||||
total_iter: 400000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
pixel_opt:
|
||||
type: L1Loss
|
||||
loss_weight: 1.0
|
||||
reduction: mean
|
||||
# perceptual loss (content and style losses)
|
||||
perceptual_opt:
|
||||
type: PerceptualLoss
|
||||
layer_weights:
|
||||
# before relu
|
||||
'conv1_2': 0.1
|
||||
'conv2_2': 0.1
|
||||
'conv3_4': 1
|
||||
'conv4_4': 1
|
||||
'conv5_4': 1
|
||||
vgg_type: vgg19
|
||||
use_input_norm: true
|
||||
perceptual_weight: !!float 1.0
|
||||
style_weight: 0
|
||||
range_norm: false
|
||||
criterion: l1
|
||||
# gan loss
|
||||
gan_opt:
|
||||
type: GANLoss
|
||||
gan_type: vanilla
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
loss_weight: !!float 1e-1
|
||||
|
||||
net_d_iters: 1
|
||||
net_d_init_iters: 0
|
||||
|
||||
|
||||
# validation settings
|
||||
val:
|
||||
val_freq: !!float 5e3
|
||||
save_img: False
|
||||
75
tests/data/test_realesrnet_model.yml
Normal file
75
tests/data/test_realesrnet_model.yml
Normal file
@@ -0,0 +1,75 @@
|
||||
scale: 4
|
||||
num_gpu: 1
|
||||
manual_seed: 0
|
||||
is_train: True
|
||||
dist: False
|
||||
|
||||
# ----------------- options for synthesizing training data ----------------- #
|
||||
gt_usm: True # USM the ground-truth
|
||||
|
||||
# the first degradation process
|
||||
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
||||
resize_range: [0.15, 1.5]
|
||||
gaussian_noise_prob: 1
|
||||
noise_range: [1, 30]
|
||||
poisson_scale_range: [0.05, 3]
|
||||
gray_noise_prob: 1
|
||||
jpeg_range: [30, 95]
|
||||
|
||||
# the second degradation process
|
||||
second_blur_prob: 1
|
||||
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
||||
resize_range2: [0.3, 1.2]
|
||||
gaussian_noise_prob2: 1
|
||||
noise_range2: [1, 25]
|
||||
poisson_scale_range2: [0.05, 2.5]
|
||||
gray_noise_prob2: 1
|
||||
jpeg_range2: [30, 95]
|
||||
|
||||
gt_size: 32
|
||||
queue_size: 1
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: RRDBNet
|
||||
num_in_ch: 3
|
||||
num_out_ch: 3
|
||||
num_feat: 4
|
||||
num_block: 1
|
||||
num_grow_ch: 2
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: ~
|
||||
param_key_g: params_ema
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# training settings
|
||||
train:
|
||||
ema_decay: 0.999
|
||||
optim_g:
|
||||
type: Adam
|
||||
lr: !!float 2e-4
|
||||
weight_decay: 0
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
scheduler:
|
||||
type: MultiStepLR
|
||||
milestones: [1000000]
|
||||
gamma: 0.5
|
||||
|
||||
total_iter: 1000000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
pixel_opt:
|
||||
type: L1Loss
|
||||
loss_weight: 1.0
|
||||
reduction: mean
|
||||
|
||||
|
||||
# validation settings
|
||||
val:
|
||||
val_freq: !!float 5e3
|
||||
save_img: False
|
||||
@@ -7,7 +7,7 @@ from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
|
||||
|
||||
def test_realesrgan_dataset():
|
||||
|
||||
with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f:
|
||||
with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
dataset = RealESRGANDataset(opt)
|
||||
@@ -81,7 +81,7 @@ def test_realesrgan_dataset():
|
||||
|
||||
def test_realesrgan_paired_dataset():
|
||||
|
||||
with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f:
|
||||
with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
dataset = RealESRGANPairedDataset(opt)
|
||||
|
||||
126
tests/test_model.py
Normal file
126
tests/test_model.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
import yaml
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.data.paired_image_dataset import PairedImageDataset
|
||||
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
|
||||
|
||||
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
|
||||
from realesrgan.models.realesrgan_model import RealESRGANModel
|
||||
from realesrgan.models.realesrnet_model import RealESRNetModel
|
||||
|
||||
|
||||
def test_realesrnet_model():
|
||||
with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# build model
|
||||
model = RealESRNetModel(opt)
|
||||
# test attributes
|
||||
assert model.__class__.__name__ == 'RealESRNetModel'
|
||||
assert isinstance(model.net_g, RRDBNet)
|
||||
assert isinstance(model.cri_pix, L1Loss)
|
||||
assert isinstance(model.optimizers[0], torch.optim.Adam)
|
||||
|
||||
# prepare data
|
||||
gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
|
||||
kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
|
||||
model.feed_data(data)
|
||||
# check dequeue
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 8, 8)
|
||||
assert model.gt.shape == (1, 3, 32, 32)
|
||||
|
||||
# change probability to test if-else
|
||||
model.opt['gaussian_noise_prob'] = 0
|
||||
model.opt['gray_noise_prob'] = 0
|
||||
model.opt['second_blur_prob'] = 0
|
||||
model.opt['gaussian_noise_prob2'] = 0
|
||||
model.opt['gray_noise_prob2'] = 0
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 8, 8)
|
||||
assert model.gt.shape == (1, 3, 32, 32)
|
||||
|
||||
# ----------------- test nondist_validation -------------------- #
|
||||
# construct dataloader
|
||||
dataset_opt = dict(
|
||||
name='Demo',
|
||||
dataroot_gt='tests/data/gt',
|
||||
dataroot_lq='tests/data/lq',
|
||||
io_backend=dict(type='disk'),
|
||||
scale=4,
|
||||
phase='val')
|
||||
dataset = PairedImageDataset(dataset_opt)
|
||||
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
||||
assert model.is_train is True
|
||||
model.nondist_validation(dataloader, 1, None, False)
|
||||
assert model.is_train is True
|
||||
|
||||
|
||||
def test_realesrgan_model():
|
||||
with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# build model
|
||||
model = RealESRGANModel(opt)
|
||||
# test attributes
|
||||
assert model.__class__.__name__ == 'RealESRGANModel'
|
||||
assert isinstance(model.net_g, RRDBNet) # generator
|
||||
assert isinstance(model.net_d, UNetDiscriminatorSN) # discriminator
|
||||
assert isinstance(model.cri_pix, L1Loss)
|
||||
assert isinstance(model.cri_perceptual, PerceptualLoss)
|
||||
assert isinstance(model.cri_gan, GANLoss)
|
||||
assert isinstance(model.optimizers[0], torch.optim.Adam)
|
||||
assert isinstance(model.optimizers[1], torch.optim.Adam)
|
||||
|
||||
# prepare data
|
||||
gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
|
||||
kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
|
||||
model.feed_data(data)
|
||||
# check dequeue
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 8, 8)
|
||||
assert model.gt.shape == (1, 3, 32, 32)
|
||||
|
||||
# change probability to test if-else
|
||||
model.opt['gaussian_noise_prob'] = 0
|
||||
model.opt['gray_noise_prob'] = 0
|
||||
model.opt['second_blur_prob'] = 0
|
||||
model.opt['gaussian_noise_prob2'] = 0
|
||||
model.opt['gray_noise_prob2'] = 0
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 8, 8)
|
||||
assert model.gt.shape == (1, 3, 32, 32)
|
||||
|
||||
# ----------------- test nondist_validation -------------------- #
|
||||
# construct dataloader
|
||||
dataset_opt = dict(
|
||||
name='Demo',
|
||||
dataroot_gt='tests/data/gt',
|
||||
dataroot_lq='tests/data/lq',
|
||||
io_backend=dict(type='disk'),
|
||||
scale=4,
|
||||
phase='val')
|
||||
dataset = PairedImageDataset(dataset_opt)
|
||||
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
||||
assert model.is_train is True
|
||||
model.nondist_validation(dataloader, 1, None, False)
|
||||
assert model.is_train is True
|
||||
|
||||
# ----------------- test optimize_parameters -------------------- #
|
||||
model.feed_data(data)
|
||||
model.optimize_parameters(1)
|
||||
assert model.output.shape == (1, 3, 32, 32)
|
||||
assert isinstance(model.log_dict, dict)
|
||||
# check returned keys
|
||||
expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
|
||||
assert set(expected_keys).issubset(set(model.log_dict.keys()))
|
||||
87
tests/test_utils.py
Normal file
87
tests/test_utils.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import numpy as np
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
|
||||
from realesrgan.utils import RealESRGANer
|
||||
|
||||
|
||||
def test_realesrganer():
|
||||
# initialize with default model
|
||||
restorer = RealESRGANer(
|
||||
scale=4,
|
||||
model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
|
||||
model=None,
|
||||
tile=10,
|
||||
tile_pad=10,
|
||||
pre_pad=2,
|
||||
half=False)
|
||||
assert isinstance(restorer.model, RRDBNet)
|
||||
assert restorer.half is False
|
||||
# initialize with user-defined model
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
restorer = RealESRGANer(
|
||||
scale=4,
|
||||
model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth',
|
||||
model=model,
|
||||
tile=10,
|
||||
tile_pad=10,
|
||||
pre_pad=2,
|
||||
half=True)
|
||||
# test attribute
|
||||
assert isinstance(restorer.model, RRDBNet)
|
||||
assert restorer.half is True
|
||||
|
||||
# ------------------ test pre_process ---------------- #
|
||||
img = np.random.random((12, 12, 3)).astype(np.float32)
|
||||
restorer.pre_process(img)
|
||||
assert restorer.img.shape == (1, 3, 14, 14)
|
||||
# with modcrop
|
||||
restorer.scale = 1
|
||||
restorer.pre_process(img)
|
||||
assert restorer.img.shape == (1, 3, 16, 16)
|
||||
|
||||
# ------------------ test process ---------------- #
|
||||
restorer.process()
|
||||
assert restorer.output.shape == (1, 3, 64, 64)
|
||||
|
||||
# ------------------ test post_process ---------------- #
|
||||
restorer.mod_scale = 4
|
||||
output = restorer.post_process()
|
||||
assert output.shape == (1, 3, 60, 60)
|
||||
|
||||
# ------------------ test tile_process ---------------- #
|
||||
restorer.scale = 4
|
||||
img = np.random.random((12, 12, 3)).astype(np.float32)
|
||||
restorer.pre_process(img)
|
||||
restorer.tile_process()
|
||||
assert restorer.output.shape == (1, 3, 64, 64)
|
||||
|
||||
# ------------------ test enhance ---------------- #
|
||||
img = np.random.random((12, 12, 3)).astype(np.float32)
|
||||
result = restorer.enhance(img, outscale=2)
|
||||
assert result[0].shape == (24, 24, 3)
|
||||
assert result[1] == 'RGB'
|
||||
|
||||
# ------------------ test enhance with 16-bit image---------------- #
|
||||
img = np.random.random((4, 4, 3)).astype(np.uint16) + 512
|
||||
result = restorer.enhance(img, outscale=2)
|
||||
assert result[0].shape == (8, 8, 3)
|
||||
assert result[1] == 'RGB'
|
||||
|
||||
# ------------------ test enhance with gray image---------------- #
|
||||
img = np.random.random((4, 4)).astype(np.float32)
|
||||
result = restorer.enhance(img, outscale=2)
|
||||
assert result[0].shape == (8, 8)
|
||||
assert result[1] == 'L'
|
||||
|
||||
# ------------------ test enhance with RGBA---------------- #
|
||||
img = np.random.random((4, 4, 4)).astype(np.float32)
|
||||
result = restorer.enhance(img, outscale=2)
|
||||
assert result[0].shape == (8, 8, 4)
|
||||
assert result[1] == 'RGBA'
|
||||
|
||||
# ------------------ test enhance with RGBA, alpha_upsampler---------------- #
|
||||
restorer.tile_size = 0
|
||||
img = np.random.random((4, 4, 4)).astype(np.float32)
|
||||
result = restorer.enhance(img, outscale=2, alpha_upsampler=None)
|
||||
assert result[0].shape == (8, 8, 4)
|
||||
assert result[1] == 'RGBA'
|
||||
Reference in New Issue
Block a user