add unittest for dataset and archs
This commit is contained in:
28
tests/data/demo_option_realesrgan_dataset.yml
Normal file
28
tests/data/demo_option_realesrgan_dataset.yml
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Demo
|
||||
type: RealESRGANDataset
|
||||
dataroot_gt: tests/data/gt
|
||||
meta_info: tests/data/meta_info_gt.txt
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
blur_kernel_size: 21
|
||||
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob: 1
|
||||
blur_sigma: [0.2, 3]
|
||||
betag_range: [0.5, 4]
|
||||
betap_range: [1, 2]
|
||||
|
||||
blur_kernel_size2: 21
|
||||
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
||||
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
||||
sinc_prob2: 1
|
||||
blur_sigma2: [0.2, 1.5]
|
||||
betag_range2: [0.5, 4]
|
||||
betap_range2: [1, 2]
|
||||
|
||||
final_sinc_prob: 1
|
||||
|
||||
gt_size: 128
|
||||
use_hflip: True
|
||||
use_rot: False
|
||||
13
tests/data/demo_option_realesrgan_paired_dataset.yml
Normal file
13
tests/data/demo_option_realesrgan_paired_dataset.yml
Normal file
@@ -0,0 +1,13 @@
|
||||
name: Demo
|
||||
type: RealESRGANPairedDataset
|
||||
scale: 4
|
||||
dataroot_gt: tests/data
|
||||
dataroot_lq: tests/data
|
||||
meta_info: tests/data/meta_info_pair.txt
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
phase: train
|
||||
gt_size: 128
|
||||
use_hflip: True
|
||||
use_rot: False
|
||||
BIN
tests/data/gt.lmdb/data.mdb
Normal file
BIN
tests/data/gt.lmdb/data.mdb
Normal file
Binary file not shown.
BIN
tests/data/gt.lmdb/lock.mdb
Normal file
BIN
tests/data/gt.lmdb/lock.mdb
Normal file
Binary file not shown.
2
tests/data/gt.lmdb/meta_info.txt
Normal file
2
tests/data/gt.lmdb/meta_info.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
baboon.png (480,500,3) 1
|
||||
comic.png (360,240,3) 1
|
||||
BIN
tests/data/gt/baboon.png
Normal file
BIN
tests/data/gt/baboon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 532 KiB |
BIN
tests/data/gt/comic.png
Normal file
BIN
tests/data/gt/comic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 195 KiB |
BIN
tests/data/lq.lmdb/data.mdb
Normal file
BIN
tests/data/lq.lmdb/data.mdb
Normal file
Binary file not shown.
BIN
tests/data/lq.lmdb/lock.mdb
Normal file
BIN
tests/data/lq.lmdb/lock.mdb
Normal file
Binary file not shown.
2
tests/data/lq.lmdb/meta_info.txt
Normal file
2
tests/data/lq.lmdb/meta_info.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
baboon.png (120,125,3) 1
|
||||
comic.png (80,60,3) 1
|
||||
BIN
tests/data/lq/baboon.png
Normal file
BIN
tests/data/lq/baboon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 35 KiB |
BIN
tests/data/lq/comic.png
Normal file
BIN
tests/data/lq/comic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
2
tests/data/meta_info_gt.txt
Normal file
2
tests/data/meta_info_gt.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
baboon.png
|
||||
comic.png
|
||||
2
tests/data/meta_info_pair.txt
Normal file
2
tests/data/meta_info_pair.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
gt/baboon.png, lq/baboon.png
|
||||
gt/comic.png, lq/comic.png
|
||||
151
tests/test_dataset.py
Normal file
151
tests/test_dataset.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from realesrgan.data.realesrgan_dataset import RealESRGANDataset
|
||||
from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
|
||||
|
||||
|
||||
def test_realesrgan_dataset():
|
||||
|
||||
with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
dataset = RealESRGANDataset(opt)
|
||||
assert dataset.io_backend_opt['type'] == 'disk' # io backend
|
||||
assert len(dataset) == 2 # whether to read correct meta info
|
||||
assert dataset.kernel_list == [
|
||||
'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
|
||||
] # correct initialization the degradation configurations
|
||||
assert dataset.betag_range2 == [0.5, 4]
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 400, 400)
|
||||
assert result['kernel1'].shape == (21, 21)
|
||||
assert result['kernel2'].shape == (21, 21)
|
||||
assert result['sinc_kernel'].shape == (21, 21)
|
||||
assert result['gt_path'] == 'tests/data/gt/baboon.png'
|
||||
|
||||
# ------------------ test lmdb backend -------------------- #
|
||||
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
|
||||
opt['io_backend']['type'] = 'lmdb'
|
||||
|
||||
dataset = RealESRGANDataset(opt)
|
||||
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
|
||||
assert len(dataset.paths) == 2 # whether to read correct meta info
|
||||
assert dataset.kernel_list == [
|
||||
'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
|
||||
] # correct initialization the degradation configurations
|
||||
assert dataset.betag_range2 == [0.5, 4]
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(1)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 400, 400)
|
||||
assert result['kernel1'].shape == (21, 21)
|
||||
assert result['kernel2'].shape == (21, 21)
|
||||
assert result['sinc_kernel'].shape == (21, 21)
|
||||
assert result['gt_path'] == 'comic'
|
||||
|
||||
# ------------------ test with sinc_prob = 0 -------------------- #
|
||||
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
|
||||
opt['io_backend']['type'] = 'lmdb'
|
||||
opt['sinc_prob'] = 0
|
||||
opt['sinc_prob2'] = 0
|
||||
opt['final_sinc_prob'] = 0
|
||||
dataset = RealESRGANDataset(opt)
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 400, 400)
|
||||
assert result['kernel1'].shape == (21, 21)
|
||||
assert result['kernel2'].shape == (21, 21)
|
||||
assert result['sinc_kernel'].shape == (21, 21)
|
||||
assert result['gt_path'] == 'baboon'
|
||||
|
||||
# ------------------ lmdb backend should have paths ends with lmdb -------------------- #
|
||||
with pytest.raises(ValueError):
|
||||
opt['dataroot_gt'] = 'tests/data/gt'
|
||||
opt['io_backend']['type'] = 'lmdb'
|
||||
dataset = RealESRGANDataset(opt)
|
||||
|
||||
|
||||
def test_realesrgan_paired_dataset():
|
||||
|
||||
with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
dataset = RealESRGANPairedDataset(opt)
|
||||
assert dataset.io_backend_opt['type'] == 'disk' # io backend
|
||||
assert len(dataset) == 2 # whether to read correct meta info
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 128, 128)
|
||||
assert result['lq'].shape == (3, 32, 32)
|
||||
assert result['gt_path'] == 'tests/data/gt/baboon.png'
|
||||
assert result['lq_path'] == 'tests/data/lq/baboon.png'
|
||||
|
||||
# ------------------ test lmdb backend -------------------- #
|
||||
opt['dataroot_gt'] = 'tests/data/gt.lmdb'
|
||||
opt['dataroot_lq'] = 'tests/data/lq.lmdb'
|
||||
opt['io_backend']['type'] = 'lmdb'
|
||||
|
||||
dataset = RealESRGANPairedDataset(opt)
|
||||
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
|
||||
assert len(dataset) == 2 # whether to read correct meta info
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(1)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 128, 128)
|
||||
assert result['lq'].shape == (3, 32, 32)
|
||||
assert result['gt_path'] == 'comic'
|
||||
assert result['lq_path'] == 'comic'
|
||||
|
||||
# ------------------ test paired_paths_from_folder -------------------- #
|
||||
opt['dataroot_gt'] = 'tests/data/gt'
|
||||
opt['dataroot_lq'] = 'tests/data/lq'
|
||||
opt['io_backend'] = dict(type='disk')
|
||||
opt['meta_info'] = None
|
||||
|
||||
dataset = RealESRGANPairedDataset(opt)
|
||||
assert dataset.io_backend_opt['type'] == 'disk' # io backend
|
||||
assert len(dataset) == 2 # whether to read correct meta info
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 128, 128)
|
||||
assert result['lq'].shape == (3, 32, 32)
|
||||
|
||||
# ------------------ test normalization -------------------- #
|
||||
dataset.mean = [0.5, 0.5, 0.5]
|
||||
dataset.std = [0.5, 0.5, 0.5]
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 128, 128)
|
||||
assert result['lq'].shape == (3, 32, 32)
|
||||
19
tests/test_discriminator_arch.py
Normal file
19
tests/test_discriminator_arch.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
|
||||
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
|
||||
|
||||
|
||||
def test_unetdiscriminatorsn():
|
||||
"""Test arch: UNetDiscriminatorSN."""
|
||||
|
||||
# model init and forward (cpu)
|
||||
net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
|
||||
img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
|
||||
output = net(img)
|
||||
assert output.shape == (1, 1, 32, 32)
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
output = net(img.cuda())
|
||||
assert output.shape == (1, 1, 32, 32)
|
||||
Reference in New Issue
Block a user