fix bug: gt_sum; fix typo
This commit is contained in:
@@ -18,7 +18,7 @@ class RealESRGANModel(SRGANModel):
|
|||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(RealESRGANModel, self).__init__(opt)
|
super(RealESRGANModel, self).__init__(opt)
|
||||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||||
self.usm_shaper = USMSharp().cuda()
|
self.usm_sharpener = USMSharp().cuda()
|
||||||
self.queue_size = opt['queue_size']
|
self.queue_size = opt['queue_size']
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -58,7 +58,7 @@ class RealESRGANModel(SRGANModel):
|
|||||||
if self.is_train:
|
if self.is_train:
|
||||||
# training data synthesis
|
# training data synthesis
|
||||||
self.gt = data['gt'].to(self.device)
|
self.gt = data['gt'].to(self.device)
|
||||||
self.gt_usm = self.usm_shaper(self.gt)
|
self.gt_usm = self.usm_sharpener(self.gt)
|
||||||
|
|
||||||
self.kernel1 = data['kernel1'].to(self.device)
|
self.kernel1 = data['kernel1'].to(self.device)
|
||||||
self.kernel2 = data['kernel2'].to(self.device)
|
self.kernel2 = data['kernel2'].to(self.device)
|
||||||
@@ -160,6 +160,8 @@ class RealESRGANModel(SRGANModel):
|
|||||||
|
|
||||||
# training pair pool
|
# training pair pool
|
||||||
self._dequeue_and_enqueue()
|
self._dequeue_and_enqueue()
|
||||||
|
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
||||||
|
self.gt_usm = self.usm_sharpener(self.gt)
|
||||||
else:
|
else:
|
||||||
self.lq = data['lq'].to(self.device)
|
self.lq = data['lq'].to(self.device)
|
||||||
if 'gt' in data:
|
if 'gt' in data:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class RealESRNetModel(SRModel):
|
|||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(RealESRNetModel, self).__init__(opt)
|
super(RealESRNetModel, self).__init__(opt)
|
||||||
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
self.jpeger = DiffJPEG(differentiable=False).cuda()
|
||||||
self.usm_shaper = USMSharp().cuda()
|
self.usm_sharpener = USMSharp().cuda()
|
||||||
self.queue_size = opt['queue_size']
|
self.queue_size = opt['queue_size']
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -59,7 +59,7 @@ class RealESRNetModel(SRModel):
|
|||||||
self.gt = data['gt'].to(self.device)
|
self.gt = data['gt'].to(self.device)
|
||||||
# USM the GT images
|
# USM the GT images
|
||||||
if self.opt['gt_usm'] is True:
|
if self.opt['gt_usm'] is True:
|
||||||
self.gt = self.usm_shaper(self.gt)
|
self.gt = self.usm_sharpener(self.gt)
|
||||||
|
|
||||||
self.kernel1 = data['kernel1'].to(self.device)
|
self.kernel1 = data['kernel1'].to(self.device)
|
||||||
self.kernel2 = data['kernel2'].to(self.device)
|
self.kernel2 = data['kernel2'].to(self.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user