add unittest for dataset and archs
This commit is contained in:
@@ -59,7 +59,7 @@ class RealESRGANPairedDataset(data.Dataset):
|
|||||||
# disk backend with meta_info
|
# disk backend with meta_info
|
||||||
# Each line in the meta_info describes the relative path to an image
|
# Each line in the meta_info describes the relative path to an image
|
||||||
with open(self.opt['meta_info']) as fin:
|
with open(self.opt['meta_info']) as fin:
|
||||||
paths = [line.strip().split(' ')[0] for line in fin]
|
paths = [line.strip() for line in fin]
|
||||||
self.paths = []
|
self.paths = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
gt_path, lq_path = path.split(', ')
|
gt_path, lq_path = path.split(', ')
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ line_length = 120
|
|||||||
multi_line_output = 0
|
multi_line_output = 0
|
||||||
known_standard_library = pkg_resources,setuptools
|
known_standard_library = pkg_resources,setuptools
|
||||||
known_first_party = realesrgan
|
known_first_party = realesrgan
|
||||||
known_third_party = PIL,basicsr,cv2,numpy,torch,torchvision,tqdm
|
known_third_party = PIL,basicsr,cv2,numpy,pytest,torch,torchvision,tqdm,yaml
|
||||||
no_lines_before = STDLIB,LOCALFOLDER
|
no_lines_before = STDLIB,LOCALFOLDER
|
||||||
default_section = THIRDPARTY
|
default_section = THIRDPARTY
|
||||||
|
|
||||||
@@ -25,3 +25,9 @@ default_section = THIRDPARTY
|
|||||||
skip = .git,./docs/build
|
skip = .git,./docs/build
|
||||||
count =
|
count =
|
||||||
quiet-level = 3
|
quiet-level = 3
|
||||||
|
|
||||||
|
[aliases]
|
||||||
|
test=pytest
|
||||||
|
|
||||||
|
[tool:pytest]
|
||||||
|
addopts=tests/
|
||||||
|
|||||||
28
tests/data/demo_option_realesrgan_dataset.yml
Normal file
28
tests/data/demo_option_realesrgan_dataset.yml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
name: Demo
|
||||||
|
type: RealESRGANDataset
|
||||||
|
dataroot_gt: tests/data/gt
|
||||||
|
meta_info: tests/data/meta_info_gt.txt
|
||||||
|
io_backend:
|
||||||
|
type: disk
|
||||||
|
|
||||||
|
blur_kernel_size: 21
|
||||||
|
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||||
|
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||||
|
sinc_prob: 1
|
||||||
|
blur_sigma: [0.2, 3]
|
||||||
|
betag_range: [0.5, 4]
|
||||||
|
betap_range: [1, 2]
|
||||||
|
|
||||||
|
blur_kernel_size2: 21
|
||||||
|
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||||
|
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||||
|
sinc_prob2: 1
|
||||||
|
blur_sigma2: [0.2, 1.5]
|
||||||
|
betag_range2: [0.5, 4]
|
||||||
|
betap_range2: [1, 2]
|
||||||
|
|
||||||
|
final_sinc_prob: 1
|
||||||
|
|
||||||
|
gt_size: 128
|
||||||
|
use_hflip: True
|
||||||
|
use_rot: False
|
||||||
13
tests/data/demo_option_realesrgan_paired_dataset.yml
Normal file
13
tests/data/demo_option_realesrgan_paired_dataset.yml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
name: Demo
|
||||||
|
type: RealESRGANPairedDataset
|
||||||
|
scale: 4
|
||||||
|
dataroot_gt: tests/data
|
||||||
|
dataroot_lq: tests/data
|
||||||
|
meta_info: tests/data/meta_info_pair.txt
|
||||||
|
io_backend:
|
||||||
|
type: disk
|
||||||
|
|
||||||
|
phase: train
|
||||||
|
gt_size: 128
|
||||||
|
use_hflip: True
|
||||||
|
use_rot: False
|
||||||
BIN
tests/data/gt.lmdb/data.mdb
Normal file
BIN
tests/data/gt.lmdb/data.mdb
Normal file
Binary file not shown.
BIN
tests/data/gt.lmdb/lock.mdb
Normal file
BIN
tests/data/gt.lmdb/lock.mdb
Normal file
Binary file not shown.
2
tests/data/gt.lmdb/meta_info.txt
Normal file
2
tests/data/gt.lmdb/meta_info.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
baboon.png (480,500,3) 1
|
||||||
|
comic.png (360,240,3) 1
|
||||||
BIN
tests/data/gt/baboon.png
Normal file
BIN
tests/data/gt/baboon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 532 KiB |
BIN
tests/data/gt/comic.png
Normal file
BIN
tests/data/gt/comic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 195 KiB |
BIN
tests/data/lq.lmdb/data.mdb
Normal file
BIN
tests/data/lq.lmdb/data.mdb
Normal file
Binary file not shown.
BIN
tests/data/lq.lmdb/lock.mdb
Normal file
BIN
tests/data/lq.lmdb/lock.mdb
Normal file
Binary file not shown.
2
tests/data/lq.lmdb/meta_info.txt
Normal file
2
tests/data/lq.lmdb/meta_info.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
baboon.png (120,125,3) 1
|
||||||
|
comic.png (80,60,3) 1
|
||||||
BIN
tests/data/lq/baboon.png
Normal file
BIN
tests/data/lq/baboon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 35 KiB |
BIN
tests/data/lq/comic.png
Normal file
BIN
tests/data/lq/comic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
2
tests/data/meta_info_gt.txt
Normal file
2
tests/data/meta_info_gt.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
baboon.png
|
||||||
|
comic.png
|
||||||
2
tests/data/meta_info_pair.txt
Normal file
2
tests/data/meta_info_pair.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
gt/baboon.png, lq/baboon.png
|
||||||
|
gt/comic.png, lq/comic.png
|
||||||
151
tests/test_dataset.py
Normal file
151
tests/test_dataset.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from realesrgan.data.realesrgan_dataset import RealESRGANDataset
|
||||||
|
from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_realesrgan_dataset():
|
||||||
|
|
||||||
|
with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f:
|
||||||
|
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
dataset = RealESRGANDataset(opt)
|
||||||
|
assert dataset.io_backend_opt['type'] == 'disk' # io backend
|
||||||
|
assert len(dataset) == 2 # whether to read correct meta info
|
||||||
|
assert dataset.kernel_list == [
|
||||||
|
'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
|
||||||
|
] # correct initialization the degradation configurations
|
||||||
|
assert dataset.betag_range2 == [0.5, 4]
|
||||||
|
|
||||||
|
# test __getitem__
|
||||||
|
result = dataset.__getitem__(0)
|
||||||
|
# check returned keys
|
||||||
|
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
|
||||||
|
assert set(expected_keys).issubset(set(result.keys()))
|
||||||
|
# check shape and contents
|
||||||
|
assert result['gt'].shape == (3, 400, 400)
|
||||||
|
assert result['kernel1'].shape == (21, 21)
|
||||||
|
assert result['kernel2'].shape == (21, 21)
|
||||||
|
assert result['sinc_kernel'].shape == (21, 21)
|
||||||
|
assert result['gt_path'] == 'tests/data/gt/baboon.png'
|
||||||
|
|
||||||
|
# ------------------ test lmdb backend -------------------- #
|
||||||
|
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
|
||||||
|
opt['io_backend']['type'] = 'lmdb'
|
||||||
|
|
||||||
|
dataset = RealESRGANDataset(opt)
|
||||||
|
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
|
||||||
|
assert len(dataset.paths) == 2 # whether to read correct meta info
|
||||||
|
assert dataset.kernel_list == [
|
||||||
|
'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
|
||||||
|
] # correct initialization the degradation configurations
|
||||||
|
assert dataset.betag_range2 == [0.5, 4]
|
||||||
|
|
||||||
|
# test __getitem__
|
||||||
|
result = dataset.__getitem__(1)
|
||||||
|
# check returned keys
|
||||||
|
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
|
||||||
|
assert set(expected_keys).issubset(set(result.keys()))
|
||||||
|
# check shape and contents
|
||||||
|
assert result['gt'].shape == (3, 400, 400)
|
||||||
|
assert result['kernel1'].shape == (21, 21)
|
||||||
|
assert result['kernel2'].shape == (21, 21)
|
||||||
|
assert result['sinc_kernel'].shape == (21, 21)
|
||||||
|
assert result['gt_path'] == 'comic'
|
||||||
|
|
||||||
|
# ------------------ test with sinc_prob = 0 -------------------- #
|
||||||
|
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
|
||||||
|
opt['io_backend']['type'] = 'lmdb'
|
||||||
|
opt['sinc_prob'] = 0
|
||||||
|
opt['sinc_prob2'] = 0
|
||||||
|
opt['final_sinc_prob'] = 0
|
||||||
|
dataset = RealESRGANDataset(opt)
|
||||||
|
result = dataset.__getitem__(0)
|
||||||
|
# check returned keys
|
||||||
|
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
|
||||||
|
assert set(expected_keys).issubset(set(result.keys()))
|
||||||
|
# check shape and contents
|
||||||
|
assert result['gt'].shape == (3, 400, 400)
|
||||||
|
assert result['kernel1'].shape == (21, 21)
|
||||||
|
assert result['kernel2'].shape == (21, 21)
|
||||||
|
assert result['sinc_kernel'].shape == (21, 21)
|
||||||
|
assert result['gt_path'] == 'baboon'
|
||||||
|
|
||||||
|
# ------------------ lmdb backend should have paths ends with lmdb -------------------- #
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
opt['dataroot_gt'] = 'tests/data/gt'
|
||||||
|
opt['io_backend']['type'] = 'lmdb'
|
||||||
|
dataset = RealESRGANDataset(opt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_realesrgan_paired_dataset():
|
||||||
|
|
||||||
|
with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f:
|
||||||
|
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
dataset = RealESRGANPairedDataset(opt)
|
||||||
|
assert dataset.io_backend_opt['type'] == 'disk' # io backend
|
||||||
|
assert len(dataset) == 2 # whether to read correct meta info
|
||||||
|
|
||||||
|
# test __getitem__
|
||||||
|
result = dataset.__getitem__(0)
|
||||||
|
# check returned keys
|
||||||
|
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
|
||||||
|
assert set(expected_keys).issubset(set(result.keys()))
|
||||||
|
# check shape and contents
|
||||||
|
assert result['gt'].shape == (3, 128, 128)
|
||||||
|
assert result['lq'].shape == (3, 32, 32)
|
||||||
|
assert result['gt_path'] == 'tests/data/gt/baboon.png'
|
||||||
|
assert result['lq_path'] == 'tests/data/lq/baboon.png'
|
||||||
|
|
||||||
|
# ------------------ test lmdb backend -------------------- #
|
||||||
|
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
|
||||||
|
opt['dataroot_lq'] = 'tests/data/lq.lmdb'
|
||||||
|
opt['io_backend']['type'] = 'lmdb'
|
||||||
|
|
||||||
|
dataset = RealESRGANPairedDataset(opt)
|
||||||
|
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
|
||||||
|
assert len(dataset) == 2 # whether to read correct meta info
|
||||||
|
|
||||||
|
# test __getitem__
|
||||||
|
result = dataset.__getitem__(1)
|
||||||
|
# check returned keys
|
||||||
|
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
|
||||||
|
assert set(expected_keys).issubset(set(result.keys()))
|
||||||
|
# check shape and contents
|
||||||
|
assert result['gt'].shape == (3, 128, 128)
|
||||||
|
assert result['lq'].shape == (3, 32, 32)
|
||||||
|
assert result['gt_path'] == 'comic'
|
||||||
|
assert result['lq_path'] == 'comic'
|
||||||
|
|
||||||
|
# ------------------ test paired_paths_from_folder -------------------- #
|
||||||
|
opt['dataroot_gt'] = 'tests/data/gt'
|
||||||
|
opt['dataroot_lq'] = 'tests/data/lq'
|
||||||
|
opt['io_backend'] = dict(type='disk')
|
||||||
|
opt['meta_info'] = None
|
||||||
|
|
||||||
|
dataset = RealESRGANPairedDataset(opt)
|
||||||
|
assert dataset.io_backend_opt['type'] == 'disk' # io backend
|
||||||
|
assert len(dataset) == 2 # whether to read correct meta info
|
||||||
|
|
||||||
|
# test __getitem__
|
||||||
|
result = dataset.__getitem__(0)
|
||||||
|
# check returned keys
|
||||||
|
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
|
||||||
|
assert set(expected_keys).issubset(set(result.keys()))
|
||||||
|
# check shape and contents
|
||||||
|
assert result['gt'].shape == (3, 128, 128)
|
||||||
|
assert result['lq'].shape == (3, 32, 32)
|
||||||
|
|
||||||
|
# ------------------ test normalization -------------------- #
|
||||||
|
dataset.mean = [0.5, 0.5, 0.5]
|
||||||
|
dataset.std = [0.5, 0.5, 0.5]
|
||||||
|
# test __getitem__
|
||||||
|
result = dataset.__getitem__(0)
|
||||||
|
# check returned keys
|
||||||
|
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
|
||||||
|
assert set(expected_keys).issubset(set(result.keys()))
|
||||||
|
# check shape and contents
|
||||||
|
assert result['gt'].shape == (3, 128, 128)
|
||||||
|
assert result['lq'].shape == (3, 32, 32)
|
||||||
19
tests/test_discriminator_arch.py
Normal file
19
tests/test_discriminator_arch.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
|
||||||
|
|
||||||
|
|
||||||
|
def test_unetdiscriminatorsn():
|
||||||
|
"""Test arch: UNetDiscriminatorSN."""
|
||||||
|
|
||||||
|
# model init and forward (cpu)
|
||||||
|
net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
|
||||||
|
img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
|
||||||
|
output = net(img)
|
||||||
|
assert output.shape == (1, 1, 32, 32)
|
||||||
|
|
||||||
|
# model init and forward (gpu)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net.cuda()
|
||||||
|
output = net(img.cuda())
|
||||||
|
assert output.shape == (1, 1, 32, 32)
|
||||||
Reference in New Issue
Block a user