diff --git a/realesrgan/utils.py b/realesrgan/utils.py index 8dad829..a30ac1c 100644 --- a/realesrgan/utils.py +++ b/realesrgan/utils.py @@ -4,9 +4,8 @@ import numpy as np import os import torch from basicsr.archs.rrdbnet_arch import RRDBNet -from torch.hub import download_url_to_file, get_dir +from basicsr.utils.download_util import load_file_from_url from torch.nn import functional as F -from urllib.parse import urlparse ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -42,7 +41,7 @@ class RealESRGANer(): # if the model_path starts with https, it will first download models to the folder: realesrgan/weights if model_path.startswith('https://'): model_path = load_file_from_url( - url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None) + url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None) loadnet = torch.load(model_path) # prefer to use params_ema if 'params_ema' in loadnet: @@ -231,25 +230,3 @@ class RealESRGANer(): ), interpolation=cv2.INTER_LANCZOS4) return output, img_mode - - -def load_file_from_url(url, model_dir=None, progress=True, file_name=None): - """Load file form http url, will download models if necessary. - - Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py - """ - if model_dir is None: - hub_dir = get_dir() - model_dir = os.path.join(hub_dir, 'checkpoints') - - os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) - - parts = urlparse(url) - filename = os.path.basename(parts.path) - if file_name is not None: - filename = file_name - cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) - if not os.path.exists(cached_file): - print(f'Downloading: "{url}" to {cached_file}\n') - download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) - return cached_file diff --git a/tests/data/demo_option_realesrgan_dataset.yml b/tests/data/test_realesrgan_dataset.yml similarity index 100% rename from tests/data/demo_option_realesrgan_dataset.yml rename to tests/data/test_realesrgan_dataset.yml diff --git a/tests/data/test_realesrgan_model.yml b/tests/data/test_realesrgan_model.yml new file mode 100644 index 0000000..1cbdab2 --- /dev/null +++ b/tests/data/test_realesrgan_model.yml @@ -0,0 +1,115 @@ +scale: 4 +num_gpu: 1 +manual_seed: 0 +is_train: True +dist: False + +# ----------------- options for synthesizing training data ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 1 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 1 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 1 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 1 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 1 +jpeg_range2: [30, 95] + +gt_size: 32 +queue_size: 1 + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 4 + num_block: 1 + num_grow_ch: 2 + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 2 + skip_connection: True + +# path +path: + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + + +# validation settings +val: + val_freq: !!float 5e3 + save_img: False diff --git a/tests/data/demo_option_realesrgan_paired_dataset.yml b/tests/data/test_realesrgan_paired_dataset.yml similarity index 100% rename from tests/data/demo_option_realesrgan_paired_dataset.yml rename to tests/data/test_realesrgan_paired_dataset.yml diff --git a/tests/data/test_realesrnet_model.yml b/tests/data/test_realesrnet_model.yml new file mode 100644 index 0000000..06ceb26 --- /dev/null +++ b/tests/data/test_realesrnet_model.yml @@ -0,0 +1,75 @@ +scale: 4 +num_gpu: 1 +manual_seed: 0 +is_train: True +dist: False + +# ----------------- options for synthesizing training data ----------------- # +gt_usm: True # USM the ground-truth + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 1 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 1 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 1 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 1 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 1 +jpeg_range2: [30, 95] + +gt_size: 32 +queue_size: 1 + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 4 + num_block: 1 + num_grow_ch: 2 + +# path +path: + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1000000] + gamma: 0.5 + + total_iter: 1000000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + + +# validation settings +val: + val_freq: !!float 5e3 + save_img: False diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 3fb051a..715b408 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -7,7 +7,7 @@ from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset def test_realesrgan_dataset(): - with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f: + with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f: opt = yaml.load(f, Loader=yaml.FullLoader) dataset = RealESRGANDataset(opt) @@ -81,7 +81,7 @@ def test_realesrgan_dataset(): def test_realesrgan_paired_dataset(): - with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f: + with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r') as f: opt = yaml.load(f, Loader=yaml.FullLoader) dataset = RealESRGANPairedDataset(opt) diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..c20bb1d --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,126 @@ +import torch +import yaml +from basicsr.archs.rrdbnet_arch import RRDBNet +from basicsr.data.paired_image_dataset import PairedImageDataset +from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss + +from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN +from realesrgan.models.realesrgan_model import RealESRGANModel +from realesrgan.models.realesrnet_model import RealESRNetModel + + +def test_realesrnet_model(): + with open('tests/data/test_realesrnet_model.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + # build model + model = RealESRNetModel(opt) + # test attributes + assert model.__class__.__name__ == 'RealESRNetModel' + assert isinstance(model.net_g, RRDBNet) + assert isinstance(model.cri_pix, L1Loss) + assert isinstance(model.optimizers[0], torch.optim.Adam) + + # prepare data + gt = torch.rand((1, 3, 32, 32), dtype=torch.float32) + kernel1 = torch.rand((1, 5, 5), dtype=torch.float32) + kernel2 = torch.rand((1, 5, 5), dtype=torch.float32) + sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32) + data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel) + model.feed_data(data) + # check dequeue + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 8, 8) + assert model.gt.shape == (1, 3, 32, 32) + + # change probability to test if-else + model.opt['gaussian_noise_prob'] = 0 + model.opt['gray_noise_prob'] = 0 + model.opt['second_blur_prob'] = 0 + model.opt['gaussian_noise_prob2'] = 0 + model.opt['gray_noise_prob2'] = 0 + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 8, 8) + assert model.gt.shape == (1, 3, 32, 32) + + # ----------------- test nondist_validation -------------------- # + # construct dataloader + dataset_opt = dict( + name='Demo', + dataroot_gt='tests/data/gt', + dataroot_lq='tests/data/lq', + io_backend=dict(type='disk'), + scale=4, + phase='val') + dataset = PairedImageDataset(dataset_opt) + dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + assert model.is_train is True + model.nondist_validation(dataloader, 1, None, False) + assert model.is_train is True + + +def test_realesrgan_model(): + with open('tests/data/test_realesrgan_model.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + # build model + model = RealESRGANModel(opt) + # test attributes + assert model.__class__.__name__ == 'RealESRGANModel' + assert isinstance(model.net_g, RRDBNet) # generator + assert isinstance(model.net_d, UNetDiscriminatorSN) # discriminator + assert isinstance(model.cri_pix, L1Loss) + assert isinstance(model.cri_perceptual, PerceptualLoss) + assert isinstance(model.cri_gan, GANLoss) + assert isinstance(model.optimizers[0], torch.optim.Adam) + assert isinstance(model.optimizers[1], torch.optim.Adam) + + # prepare data + gt = torch.rand((1, 3, 32, 32), dtype=torch.float32) + kernel1 = torch.rand((1, 5, 5), dtype=torch.float32) + kernel2 = torch.rand((1, 5, 5), dtype=torch.float32) + sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32) + data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel) + model.feed_data(data) + # check dequeue + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 8, 8) + assert model.gt.shape == (1, 3, 32, 32) + + # change probability to test if-else + model.opt['gaussian_noise_prob'] = 0 + model.opt['gray_noise_prob'] = 0 + model.opt['second_blur_prob'] = 0 + model.opt['gaussian_noise_prob2'] = 0 + model.opt['gray_noise_prob2'] = 0 + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 8, 8) + assert model.gt.shape == (1, 3, 32, 32) + + # ----------------- test nondist_validation -------------------- # + # construct dataloader + dataset_opt = dict( + name='Demo', + dataroot_gt='tests/data/gt', + dataroot_lq='tests/data/lq', + io_backend=dict(type='disk'), + scale=4, + phase='val') + dataset = PairedImageDataset(dataset_opt) + dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + assert model.is_train is True + model.nondist_validation(dataloader, 1, None, False) + assert model.is_train is True + + # ----------------- test optimize_parameters -------------------- # + model.feed_data(data) + model.optimize_parameters(1) + assert model.output.shape == (1, 3, 32, 32) + assert isinstance(model.log_dict, dict) + # check returned keys + expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake'] + assert set(expected_keys).issubset(set(model.log_dict.keys())) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..7919b74 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,87 @@ +import numpy as np +from basicsr.archs.rrdbnet_arch import RRDBNet + +from realesrgan.utils import RealESRGANer + + +def test_realesrganer(): + # initialize with default model + restorer = RealESRGANer( + scale=4, + model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth', + model=None, + tile=10, + tile_pad=10, + pre_pad=2, + half=False) + assert isinstance(restorer.model, RRDBNet) + assert restorer.half is False + # initialize with user-defined model + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + restorer = RealESRGANer( + scale=4, + model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth', + model=model, + tile=10, + tile_pad=10, + pre_pad=2, + half=True) + # test attribute + assert isinstance(restorer.model, RRDBNet) + assert restorer.half is True + + # ------------------ test pre_process ---------------- # + img = np.random.random((12, 12, 3)).astype(np.float32) + restorer.pre_process(img) + assert restorer.img.shape == (1, 3, 14, 14) + # with modcrop + restorer.scale = 1 + restorer.pre_process(img) + assert restorer.img.shape == (1, 3, 16, 16) + + # ------------------ test process ---------------- # + restorer.process() + assert restorer.output.shape == (1, 3, 64, 64) + + # ------------------ test post_process ---------------- # + restorer.mod_scale = 4 + output = restorer.post_process() + assert output.shape == (1, 3, 60, 60) + + # ------------------ test tile_process ---------------- # + restorer.scale = 4 + img = np.random.random((12, 12, 3)).astype(np.float32) + restorer.pre_process(img) + restorer.tile_process() + assert restorer.output.shape == (1, 3, 64, 64) + + # ------------------ test enhance ---------------- # + img = np.random.random((12, 12, 3)).astype(np.float32) + result = restorer.enhance(img, outscale=2) + assert result[0].shape == (24, 24, 3) + assert result[1] == 'RGB' + + # ------------------ test enhance with 16-bit image---------------- # + img = np.random.random((4, 4, 3)).astype(np.uint16) + 512 + result = restorer.enhance(img, outscale=2) + assert result[0].shape == (8, 8, 3) + assert result[1] == 'RGB' + + # ------------------ test enhance with gray image---------------- # + img = np.random.random((4, 4)).astype(np.float32) + result = restorer.enhance(img, outscale=2) + assert result[0].shape == (8, 8) + assert result[1] == 'L' + + # ------------------ test enhance with RGBA---------------- # + img = np.random.random((4, 4, 4)).astype(np.float32) + result = restorer.enhance(img, outscale=2) + assert result[0].shape == (8, 8, 4) + assert result[1] == 'RGBA' + + # ------------------ test enhance with RGBA, alpha_upsampler---------------- # + restorer.tile_size = 0 + img = np.random.random((4, 4, 4)).astype(np.float32) + result = restorer.enhance(img, outscale=2, alpha_upsampler=None) + assert result[0].shape == (8, 8, 4) + assert result[1] == 'RGBA'