add unittest for model and utils
This commit is contained in:
126
tests/test_model.py
Normal file
126
tests/test_model.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
import yaml
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.data.paired_image_dataset import PairedImageDataset
|
||||
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
|
||||
|
||||
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
|
||||
from realesrgan.models.realesrgan_model import RealESRGANModel
|
||||
from realesrgan.models.realesrnet_model import RealESRNetModel
|
||||
|
||||
|
||||
def test_realesrnet_model():
|
||||
with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# build model
|
||||
model = RealESRNetModel(opt)
|
||||
# test attributes
|
||||
assert model.__class__.__name__ == 'RealESRNetModel'
|
||||
assert isinstance(model.net_g, RRDBNet)
|
||||
assert isinstance(model.cri_pix, L1Loss)
|
||||
assert isinstance(model.optimizers[0], torch.optim.Adam)
|
||||
|
||||
# prepare data
|
||||
gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
|
||||
kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
|
||||
model.feed_data(data)
|
||||
# check dequeue
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 8, 8)
|
||||
assert model.gt.shape == (1, 3, 32, 32)
|
||||
|
||||
# change probability to test if-else
|
||||
model.opt['gaussian_noise_prob'] = 0
|
||||
model.opt['gray_noise_prob'] = 0
|
||||
model.opt['second_blur_prob'] = 0
|
||||
model.opt['gaussian_noise_prob2'] = 0
|
||||
model.opt['gray_noise_prob2'] = 0
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 8, 8)
|
||||
assert model.gt.shape == (1, 3, 32, 32)
|
||||
|
||||
# ----------------- test nondist_validation -------------------- #
|
||||
# construct dataloader
|
||||
dataset_opt = dict(
|
||||
name='Demo',
|
||||
dataroot_gt='tests/data/gt',
|
||||
dataroot_lq='tests/data/lq',
|
||||
io_backend=dict(type='disk'),
|
||||
scale=4,
|
||||
phase='val')
|
||||
dataset = PairedImageDataset(dataset_opt)
|
||||
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
||||
assert model.is_train is True
|
||||
model.nondist_validation(dataloader, 1, None, False)
|
||||
assert model.is_train is True
|
||||
|
||||
|
||||
def test_realesrgan_model():
|
||||
with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# build model
|
||||
model = RealESRGANModel(opt)
|
||||
# test attributes
|
||||
assert model.__class__.__name__ == 'RealESRGANModel'
|
||||
assert isinstance(model.net_g, RRDBNet) # generator
|
||||
assert isinstance(model.net_d, UNetDiscriminatorSN) # discriminator
|
||||
assert isinstance(model.cri_pix, L1Loss)
|
||||
assert isinstance(model.cri_perceptual, PerceptualLoss)
|
||||
assert isinstance(model.cri_gan, GANLoss)
|
||||
assert isinstance(model.optimizers[0], torch.optim.Adam)
|
||||
assert isinstance(model.optimizers[1], torch.optim.Adam)
|
||||
|
||||
# prepare data
|
||||
gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
|
||||
kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
|
||||
data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
|
||||
model.feed_data(data)
|
||||
# check dequeue
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 8, 8)
|
||||
assert model.gt.shape == (1, 3, 32, 32)
|
||||
|
||||
# change probability to test if-else
|
||||
model.opt['gaussian_noise_prob'] = 0
|
||||
model.opt['gray_noise_prob'] = 0
|
||||
model.opt['second_blur_prob'] = 0
|
||||
model.opt['gaussian_noise_prob2'] = 0
|
||||
model.opt['gray_noise_prob2'] = 0
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 8, 8)
|
||||
assert model.gt.shape == (1, 3, 32, 32)
|
||||
|
||||
# ----------------- test nondist_validation -------------------- #
|
||||
# construct dataloader
|
||||
dataset_opt = dict(
|
||||
name='Demo',
|
||||
dataroot_gt='tests/data/gt',
|
||||
dataroot_lq='tests/data/lq',
|
||||
io_backend=dict(type='disk'),
|
||||
scale=4,
|
||||
phase='val')
|
||||
dataset = PairedImageDataset(dataset_opt)
|
||||
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
||||
assert model.is_train is True
|
||||
model.nondist_validation(dataloader, 1, None, False)
|
||||
assert model.is_train is True
|
||||
|
||||
# ----------------- test optimize_parameters -------------------- #
|
||||
model.feed_data(data)
|
||||
model.optimize_parameters(1)
|
||||
assert model.output.shape == (1, 3, 32, 32)
|
||||
assert isinstance(model.log_dict, dict)
|
||||
# check returned keys
|
||||
expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
|
||||
assert set(expected_keys).issubset(set(model.log_dict.keys()))
|
||||
Reference in New Issue
Block a user