support finetune with paired data

This commit is contained in:
Xintao
2021-08-27 16:14:48 +08:00
parent 194c2c14b3
commit f5ccd64ce5
11 changed files with 426 additions and 7 deletions

View File

@@ -19,7 +19,7 @@ class RealESRGANModel(SRGANModel):
super(RealESRGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()
self.usm_sharpener = USMSharp().cuda()
self.queue_size = opt['queue_size']
self.queue_size = opt.get('queue_size', 180)
@torch.no_grad()
def _dequeue_and_enqueue(self):
@@ -55,7 +55,7 @@ class RealESRGANModel(SRGANModel):
@torch.no_grad()
def feed_data(self, data):
if self.is_train:
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
@@ -166,6 +166,7 @@ class RealESRGANModel(SRGANModel):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation