add SRVGGNetCompact arch, update inference

This commit is contained in:
Xintao
2021-12-12 13:29:21 +08:00
parent 3e0085aeda
commit 696e1a6741
7 changed files with 139 additions and 62 deletions

View File

@@ -3,7 +3,6 @@ import math
import numpy as np
import os
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F
@@ -16,7 +15,7 @@ class RealESRGANer():
Args:
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
model (nn.Module): The defined network. Default: None.
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
input images into tiles, and then process each of them. Finally, they will be merged into one image.
0 denotes for do not use tile. Default: 0.
@@ -35,9 +34,6 @@ class RealESRGANer():
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if model is None:
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'):
model_path = load_file_from_url(