fix bug: gt_sum; fix typo

This commit is contained in:
Xintao
2021-07-27 11:06:43 +08:00
parent 8454fd2c7a
commit 492a829c14
2 changed files with 6 additions and 4 deletions

View File

@@ -18,7 +18,7 @@ class RealESRGANModel(SRGANModel):
def __init__(self, opt):
super(RealESRGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()
self.usm_shaper = USMSharp().cuda()
self.usm_sharpener = USMSharp().cuda()
self.queue_size = opt['queue_size']
@torch.no_grad()
@@ -58,7 +58,7 @@ class RealESRGANModel(SRGANModel):
if self.is_train:
# training data synthesis
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_shaper(self.gt)
self.gt_usm = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)
@@ -160,6 +160,8 @@ class RealESRGANModel(SRGANModel):
# training pair pool
self._dequeue_and_enqueue()
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
self.gt_usm = self.usm_sharpener(self.gt)
else:
self.lq = data['lq'].to(self.device)
if 'gt' in data:

View File

@@ -17,7 +17,7 @@ class RealESRNetModel(SRModel):
def __init__(self, opt):
super(RealESRNetModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()
self.usm_shaper = USMSharp().cuda()
self.usm_sharpener = USMSharp().cuda()
self.queue_size = opt['queue_size']
@torch.no_grad()
@@ -59,7 +59,7 @@ class RealESRNetModel(SRModel):
self.gt = data['gt'].to(self.device)
# USM the GT images
if self.opt['gt_usm'] is True:
self.gt = self.usm_shaper(self.gt)
self.gt = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)