20 lines
561 B
Python
20 lines
561 B
Python
import torch
|
|
|
|
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
|
|
|
|
|
|
def test_unetdiscriminatorsn():
    """Test arch: UNetDiscriminatorSN.

    Builds a small discriminator (3 input channels, 4 base features, skip
    connections enabled) and checks that a (1, 3, 32, 32) random batch yields
    a (1, 1, 32, 32) output. The check runs on CPU, then again on GPU when
    CUDA is available.
    """
    expected_shape = (1, 1, 32, 32)
    model = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
    batch = torch.rand((1, 3, 32, 32), dtype=torch.float32)

    # CPU forward pass.
    assert model(batch).shape == expected_shape

    # GPU forward pass; skipped on machines without CUDA.
    if not torch.cuda.is_available():
        return
    model.cuda()  # moves parameters in place; input must follow
    assert model(batch.cuda()).shape == expected_shape
|