add unittest for dataset and archs

This commit is contained in:
Xintao
2021-11-28 15:59:14 +08:00
parent 7dd860a881
commit 1d180efaf3
18 changed files with 227 additions and 2 deletions

View File

@@ -59,7 +59,7 @@ class RealESRGANPairedDataset(data.Dataset):
# disk backend with meta_info
# Each line in the meta_info describes the relative path to an image
with open(self.opt['meta_info']) as fin:
paths = [line.strip().split(' ')[0] for line in fin]
paths = [line.strip() for line in fin]
self.paths = []
for path in paths:
gt_path, lq_path = path.split(', ')

View File

@@ -17,7 +17,7 @@ line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = realesrgan
known_third_party = PIL,basicsr,cv2,numpy,torch,torchvision,tqdm
known_third_party = PIL,basicsr,cv2,numpy,pytest,torch,torchvision,tqdm,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
@@ -25,3 +25,9 @@ default_section = THIRDPARTY
skip = .git,./docs/build
count =
quiet-level = 3
[aliases]
test=pytest
[tool:pytest]
addopts=tests/

View File

@@ -0,0 +1,28 @@
name: Demo
type: RealESRGANDataset
dataroot_gt: tests/data/gt
meta_info: tests/data/meta_info_gt.txt
io_backend:
type: disk
blur_kernel_size: 21
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob: 1
blur_sigma: [0.2, 3]
betag_range: [0.5, 4]
betap_range: [1, 2]
blur_kernel_size2: 21
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob2: 1
blur_sigma2: [0.2, 1.5]
betag_range2: [0.5, 4]
betap_range2: [1, 2]
final_sinc_prob: 1
gt_size: 128
use_hflip: True
use_rot: False

View File

@@ -0,0 +1,13 @@
name: Demo
type: RealESRGANPairedDataset
scale: 4
dataroot_gt: tests/data
dataroot_lq: tests/data
meta_info: tests/data/meta_info_pair.txt
io_backend:
type: disk
phase: train
gt_size: 128
use_hflip: True
use_rot: False

BIN
tests/data/gt.lmdb/data.mdb Normal file

Binary file not shown.

BIN
tests/data/gt.lmdb/lock.mdb Normal file

Binary file not shown.

View File

@@ -0,0 +1,2 @@
baboon.png (480,500,3) 1
comic.png (360,240,3) 1

BIN
tests/data/gt/baboon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 532 KiB

BIN
tests/data/gt/comic.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 195 KiB

BIN
tests/data/lq.lmdb/data.mdb Normal file

Binary file not shown.

BIN
tests/data/lq.lmdb/lock.mdb Normal file

Binary file not shown.

View File

@@ -0,0 +1,2 @@
baboon.png (120,125,3) 1
comic.png (80,60,3) 1

BIN
tests/data/lq/baboon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

BIN
tests/data/lq/comic.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

View File

@@ -0,0 +1,2 @@
baboon.png
comic.png

View File

@@ -0,0 +1,2 @@
gt/baboon.png, lq/baboon.png
gt/comic.png, lq/comic.png

151
tests/test_dataset.py Normal file
View File

@@ -0,0 +1,151 @@
import pytest
import yaml
from realesrgan.data.realesrgan_dataset import RealESRGANDataset
from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
def test_realesrgan_dataset():
with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f:
opt = yaml.load(f, Loader=yaml.FullLoader)
dataset = RealESRGANDataset(opt)
assert dataset.io_backend_opt['type'] == 'disk' # io backend
assert len(dataset) == 2 # whether to read correct meta info
assert dataset.kernel_list == [
'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
] # correct initialization the degradation configurations
assert dataset.betag_range2 == [0.5, 4]
# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 400, 400)
assert result['kernel1'].shape == (21, 21)
assert result['kernel2'].shape == (21, 21)
assert result['sinc_kernel'].shape == (21, 21)
assert result['gt_path'] == 'tests/data/gt/baboon.png'
# ------------------ test lmdb backend -------------------- #
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
opt['io_backend']['type'] = 'lmdb'
dataset = RealESRGANDataset(opt)
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
assert len(dataset.paths) == 2 # whether to read correct meta info
assert dataset.kernel_list == [
'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
] # correct initialization the degradation configurations
assert dataset.betag_range2 == [0.5, 4]
# test __getitem__
result = dataset.__getitem__(1)
# check returned keys
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 400, 400)
assert result['kernel1'].shape == (21, 21)
assert result['kernel2'].shape == (21, 21)
assert result['sinc_kernel'].shape == (21, 21)
assert result['gt_path'] == 'comic'
# ------------------ test with sinc_prob = 0 -------------------- #
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
opt['io_backend']['type'] = 'lmdb'
opt['sinc_prob'] = 0
opt['sinc_prob2'] = 0
opt['final_sinc_prob'] = 0
dataset = RealESRGANDataset(opt)
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 400, 400)
assert result['kernel1'].shape == (21, 21)
assert result['kernel2'].shape == (21, 21)
assert result['sinc_kernel'].shape == (21, 21)
assert result['gt_path'] == 'baboon'
# ------------------ lmdb backend should have paths ends with lmdb -------------------- #
with pytest.raises(ValueError):
opt['dataroot_gt'] = 'tests/data/gt'
opt['io_backend']['type'] = 'lmdb'
dataset = RealESRGANDataset(opt)
def test_realesrgan_paired_dataset():
with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f:
opt = yaml.load(f, Loader=yaml.FullLoader)
dataset = RealESRGANPairedDataset(opt)
assert dataset.io_backend_opt['type'] == 'disk' # io backend
assert len(dataset) == 2 # whether to read correct meta info
# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 128, 128)
assert result['lq'].shape == (3, 32, 32)
assert result['gt_path'] == 'tests/data/gt/baboon.png'
assert result['lq_path'] == 'tests/data/lq/baboon.png'
# ------------------ test lmdb backend -------------------- #
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
opt['dataroot_lq'] = 'tests/data/lq.lmdb'
opt['io_backend']['type'] = 'lmdb'
dataset = RealESRGANPairedDataset(opt)
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
assert len(dataset) == 2 # whether to read correct meta info
# test __getitem__
result = dataset.__getitem__(1)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 128, 128)
assert result['lq'].shape == (3, 32, 32)
assert result['gt_path'] == 'comic'
assert result['lq_path'] == 'comic'
# ------------------ test paired_paths_from_folder -------------------- #
opt['dataroot_gt'] = 'tests/data/gt'
opt['dataroot_lq'] = 'tests/data/lq'
opt['io_backend'] = dict(type='disk')
opt['meta_info'] = None
dataset = RealESRGANPairedDataset(opt)
assert dataset.io_backend_opt['type'] == 'disk' # io backend
assert len(dataset) == 2 # whether to read correct meta info
# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 128, 128)
assert result['lq'].shape == (3, 32, 32)
# ------------------ test normalization -------------------- #
dataset.mean = [0.5, 0.5, 0.5]
dataset.std = [0.5, 0.5, 0.5]
# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 128, 128)
assert result['lq'].shape == (3, 32, 32)

View File

@@ -0,0 +1,19 @@
import torch
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
def test_unetdiscriminatorsn():
"""Test arch: UNetDiscriminatorSN."""
# model init and forward (cpu)
net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
output = net(img)
assert output.shape == (1, 1, 32, 32)
# model init and forward (gpu)
if torch.cuda.is_available():
net.cuda()
output = net(img.cuda())
assert output.shape == (1, 1, 32, 32)