add unittest for dataset and archs
This commit is contained in:
19
tests/test_discriminator_arch.py
Normal file
19
tests/test_discriminator_arch.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
|
||||
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
|
||||
|
||||
|
||||
def test_unetdiscriminatorsn():
|
||||
"""Test arch: UNetDiscriminatorSN."""
|
||||
|
||||
# model init and forward (cpu)
|
||||
net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
|
||||
img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
|
||||
output = net(img)
|
||||
assert output.shape == (1, 1, 32, 32)
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
output = net(img.cuda())
|
||||
assert output.shape == (1, 1, 32, 32)
|
||||
Reference in New Issue
Block a user