diff --git a/realesrgan/data/realesrgan_paired_dataset.py b/realesrgan/data/realesrgan_paired_dataset.py index c8deb33..386c8d7 100644 --- a/realesrgan/data/realesrgan_paired_dataset.py +++ b/realesrgan/data/realesrgan_paired_dataset.py @@ -59,7 +59,7 @@ class RealESRGANPairedDataset(data.Dataset): # disk backend with meta_info # Each line in the meta_info describes the relative path to an image with open(self.opt['meta_info']) as fin: - paths = [line.strip().split(' ')[0] for line in fin] + paths = [line.strip() for line in fin] self.paths = [] for path in paths: gt_path, lq_path = path.split(', ') diff --git a/setup.cfg b/setup.cfg index 16aa0f9..9cecd96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ line_length = 120 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = realesrgan -known_third_party = PIL,basicsr,cv2,numpy,torch,torchvision,tqdm +known_third_party = PIL,basicsr,cv2,numpy,pytest,torch,torchvision,tqdm,yaml no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY @@ -25,3 +25,9 @@ default_section = THIRDPARTY skip = .git,./docs/build count = quiet-level = 3 + +[aliases] +test=pytest + +[tool:pytest] +addopts=tests/ diff --git a/tests/data/demo_option_realesrgan_dataset.yml b/tests/data/demo_option_realesrgan_dataset.yml new file mode 100644 index 0000000..48e6ecc --- /dev/null +++ b/tests/data/demo_option_realesrgan_dataset.yml @@ -0,0 +1,28 @@ +name: Demo +type: RealESRGANDataset +dataroot_gt: tests/data/gt +meta_info: tests/data/meta_info_gt.txt +io_backend: + type: disk + +blur_kernel_size: 21 +kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] +kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] +sinc_prob: 1 +blur_sigma: [0.2, 3] +betag_range: [0.5, 4] +betap_range: [1, 2] + +blur_kernel_size2: 21 +kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] +kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] +sinc_prob2: 1 +blur_sigma2: [0.2, 1.5] +betag_range2: [0.5, 4] +betap_range2: [1, 2] + +final_sinc_prob: 1 + +gt_size: 128 +use_hflip: True +use_rot: False diff --git a/tests/data/demo_option_realesrgan_paired_dataset.yml b/tests/data/demo_option_realesrgan_paired_dataset.yml new file mode 100644 index 0000000..8ea9709 --- /dev/null +++ b/tests/data/demo_option_realesrgan_paired_dataset.yml @@ -0,0 +1,13 @@ +name: Demo +type: RealESRGANPairedDataset +scale: 4 +dataroot_gt: tests/data +dataroot_lq: tests/data +meta_info: tests/data/meta_info_pair.txt +io_backend: + type: disk + +phase: train +gt_size: 128 +use_hflip: True +use_rot: False diff --git a/tests/data/gt.lmdb/data.mdb b/tests/data/gt.lmdb/data.mdb new file mode 100644 index 0000000..f28ad48 Binary files /dev/null and b/tests/data/gt.lmdb/data.mdb differ diff --git a/tests/data/gt.lmdb/lock.mdb b/tests/data/gt.lmdb/lock.mdb new file mode 100644 index 0000000..37b3f72 Binary files /dev/null and b/tests/data/gt.lmdb/lock.mdb differ diff --git a/tests/data/gt.lmdb/meta_info.txt b/tests/data/gt.lmdb/meta_info.txt new file mode 100644 index 0000000..f422954 --- /dev/null +++ b/tests/data/gt.lmdb/meta_info.txt @@ -0,0 +1,2 @@ +baboon.png (480,500,3) 1 +comic.png (360,240,3) 1 diff --git a/tests/data/gt/baboon.png b/tests/data/gt/baboon.png new file mode 100644 index 0000000..c81e18d Binary files /dev/null and b/tests/data/gt/baboon.png differ diff --git a/tests/data/gt/comic.png b/tests/data/gt/comic.png new file mode 100644 index 0000000..600f548 Binary files /dev/null and b/tests/data/gt/comic.png differ diff --git a/tests/data/lq.lmdb/data.mdb b/tests/data/lq.lmdb/data.mdb new file mode 100644 index 0000000..c016215 Binary files /dev/null and b/tests/data/lq.lmdb/data.mdb differ diff --git a/tests/data/lq.lmdb/lock.mdb b/tests/data/lq.lmdb/lock.mdb new file mode 100644 index 0000000..c3b69ed Binary files /dev/null and b/tests/data/lq.lmdb/lock.mdb differ diff --git a/tests/data/lq.lmdb/meta_info.txt b/tests/data/lq.lmdb/meta_info.txt new file mode 100644 index 0000000..6dfca0d --- /dev/null +++ b/tests/data/lq.lmdb/meta_info.txt @@ -0,0 +1,2 @@ +baboon.png (120,125,3) 1 +comic.png (80,60,3) 1 diff --git a/tests/data/lq/baboon.png b/tests/data/lq/baboon.png new file mode 100644 index 0000000..bbd2012 Binary files /dev/null and b/tests/data/lq/baboon.png differ diff --git a/tests/data/lq/comic.png b/tests/data/lq/comic.png new file mode 100644 index 0000000..c4e38ab Binary files /dev/null and b/tests/data/lq/comic.png differ diff --git a/tests/data/meta_info_gt.txt b/tests/data/meta_info_gt.txt new file mode 100644 index 0000000..2234632 --- /dev/null +++ b/tests/data/meta_info_gt.txt @@ -0,0 +1,2 @@ +baboon.png +comic.png diff --git a/tests/data/meta_info_pair.txt b/tests/data/meta_info_pair.txt new file mode 100644 index 0000000..4775dda --- /dev/null +++ b/tests/data/meta_info_pair.txt @@ -0,0 +1,2 @@ +gt/baboon.png, lq/baboon.png +gt/comic.png, lq/comic.png diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..3fb051a --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,151 @@ +import pytest +import yaml + +from realesrgan.data.realesrgan_dataset import RealESRGANDataset +from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset + + +def test_realesrgan_dataset(): + + with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + dataset = RealESRGANDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 2 # whether to read correct meta info + assert dataset.kernel_list == [ + 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso' + ] # correct initialization the degradation configurations + assert dataset.betag_range2 == [0.5, 4] + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 400, 400) + assert result['kernel1'].shape == (21, 21) + assert result['kernel2'].shape == (21, 21) + assert result['sinc_kernel'].shape == (21, 21) + assert result['gt_path'] == 'tests/data/gt/baboon.png' + + # ------------------ test lmdb backend -------------------- # + opt['dataroot_gt'] = 'tests/data/gt.lmdb' + opt['io_backend']['type'] = 'lmdb' + + dataset = RealESRGANDataset(opt) + assert dataset.io_backend_opt['type'] == 'lmdb' # io backend + assert len(dataset.paths) == 2 # whether to read correct meta info + assert dataset.kernel_list == [ + 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso' + ] # correct initialization the degradation configurations + assert dataset.betag_range2 == [0.5, 4] + + # test __getitem__ + result = dataset.__getitem__(1) + # check returned keys + expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 400, 400) + assert result['kernel1'].shape == (21, 21) + assert result['kernel2'].shape == (21, 21) + assert result['sinc_kernel'].shape == (21, 21) + assert result['gt_path'] == 'comic' + + # ------------------ test with sinc_prob = 0 -------------------- # + opt['dataroot_gt'] = 'tests/data/gt.lmdb' + opt['io_backend']['type'] = 'lmdb' + opt['sinc_prob'] = 0 + opt['sinc_prob2'] = 0 + opt['final_sinc_prob'] = 0 + dataset = RealESRGANDataset(opt) + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 400, 400) + assert result['kernel1'].shape == (21, 21) + assert result['kernel2'].shape == (21, 21) + assert result['sinc_kernel'].shape == (21, 21) + assert result['gt_path'] == 'baboon' + + # ------------------ lmdb backend should have paths ends with lmdb -------------------- # + with pytest.raises(ValueError): + opt['dataroot_gt'] = 'tests/data/gt' + opt['io_backend']['type'] = 'lmdb' + dataset = RealESRGANDataset(opt) + + +def test_realesrgan_paired_dataset(): + + with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + dataset = RealESRGANPairedDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 2 # whether to read correct meta info + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'lq_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 128, 128) + assert result['lq'].shape == (3, 32, 32) + assert result['gt_path'] == 'tests/data/gt/baboon.png' + assert result['lq_path'] == 'tests/data/lq/baboon.png' + + # ------------------ test lmdb backend -------------------- # + opt['dataroot_gt'] = 'tests/data/gt.lmdb' + opt['dataroot_lq'] = 'tests/data/lq.lmdb' + opt['io_backend']['type'] = 'lmdb' + + dataset = RealESRGANPairedDataset(opt) + assert dataset.io_backend_opt['type'] == 'lmdb' # io backend + assert len(dataset) == 2 # whether to read correct meta info + + # test __getitem__ + result = dataset.__getitem__(1) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'lq_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 128, 128) + assert result['lq'].shape == (3, 32, 32) + assert result['gt_path'] == 'comic' + assert result['lq_path'] == 'comic' + + # ------------------ test paired_paths_from_folder -------------------- # + opt['dataroot_gt'] = 'tests/data/gt' + opt['dataroot_lq'] = 'tests/data/lq' + opt['io_backend'] = dict(type='disk') + opt['meta_info'] = None + + dataset = RealESRGANPairedDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 2 # whether to read correct meta info + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'lq_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 128, 128) + assert result['lq'].shape == (3, 32, 32) + + # ------------------ test normalization -------------------- # + dataset.mean = [0.5, 0.5, 0.5] + dataset.std = [0.5, 0.5, 0.5] + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'lq_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 128, 128) + assert result['lq'].shape == (3, 32, 32) diff --git a/tests/test_discriminator_arch.py b/tests/test_discriminator_arch.py new file mode 100644 index 0000000..c56a40c --- /dev/null +++ b/tests/test_discriminator_arch.py @@ -0,0 +1,19 @@ +import torch + +from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN + + +def test_unetdiscriminatorsn(): + """Test arch: UNetDiscriminatorSN.""" + + # model init and forward (cpu) + net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True) + img = torch.rand((1, 3, 32, 32), dtype=torch.float32) + output = net(img) + assert output.shape == (1, 1, 32, 32) + + # model init and forward (gpu) + if torch.cuda.is_available(): + net.cuda() + output = net(img.cuda()) + assert output.shape == (1, 1, 32, 32)