From 492a829c1414735f1a705861cce6ce40e8d70194 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 27 Jul 2021 11:06:43 +0800 Subject: [PATCH] fix bug: gt_sum; fix typo --- models/realesrgan_model.py | 6 ++++-- models/realesrnet_model.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/models/realesrgan_model.py b/models/realesrgan_model.py index 180103e..5b1268e 100644 --- a/models/realesrgan_model.py +++ b/models/realesrgan_model.py @@ -18,7 +18,7 @@ class RealESRGANModel(SRGANModel): def __init__(self, opt): super(RealESRGANModel, self).__init__(opt) self.jpeger = DiffJPEG(differentiable=False).cuda() - self.usm_shaper = USMSharp().cuda() + self.usm_sharpener = USMSharp().cuda() self.queue_size = opt['queue_size'] @torch.no_grad() @@ -58,7 +58,7 @@ class RealESRGANModel(SRGANModel): if self.is_train: # training data synthesis self.gt = data['gt'].to(self.device) - self.gt_usm = self.usm_shaper(self.gt) + self.gt_usm = self.usm_sharpener(self.gt) self.kernel1 = data['kernel1'].to(self.device) self.kernel2 = data['kernel2'].to(self.device) @@ -160,6 +160,8 @@ class RealESRGANModel(SRGANModel): # training pair pool self._dequeue_and_enqueue() + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + self.gt_usm = self.usm_sharpener(self.gt) else: self.lq = data['lq'].to(self.device) if 'gt' in data: diff --git a/models/realesrnet_model.py b/models/realesrnet_model.py index e92e833..1b5651d 100644 --- a/models/realesrnet_model.py +++ b/models/realesrnet_model.py @@ -17,7 +17,7 @@ class RealESRNetModel(SRModel): def __init__(self, opt): super(RealESRNetModel, self).__init__(opt) self.jpeger = DiffJPEG(differentiable=False).cuda() - self.usm_shaper = USMSharp().cuda() + self.usm_sharpener = USMSharp().cuda() self.queue_size = opt['queue_size'] @torch.no_grad() @@ -59,7 +59,7 @@ class RealESRNetModel(SRModel): self.gt = data['gt'].to(self.device) # USM the GT images if self.opt['gt_usm'] is True: - self.gt = self.usm_shaper(self.gt) + self.gt = self.usm_sharpener(self.gt) self.kernel1 = data['kernel1'].to(self.device) self.kernel2 = data['kernel2'].to(self.device)