support denoise strength for realesr-general-x4v3
This commit is contained in:
@@ -29,6 +29,7 @@ class RealESRGANer():
|
||||
def __init__(self,
|
||||
scale,
|
||||
model_path,
|
||||
dni_weight=None,
|
||||
model=None,
|
||||
tile=0,
|
||||
tile_pad=10,
|
||||
@@ -49,22 +50,44 @@ class RealESRGANer():
|
||||
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||
else:
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
|
||||
|
||||
if isinstance(model_path, list):
|
||||
# dni
|
||||
assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
|
||||
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
|
||||
else:
|
||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path,
|
||||
model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'),
|
||||
progress=True,
|
||||
file_name=None)
|
||||
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
|
||||
|
||||
# prefer to use params_ema
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
else:
|
||||
keyname = 'params'
|
||||
model.load_state_dict(loadnet[keyname], strict=True)
|
||||
|
||||
model.eval()
|
||||
self.model = model.to(self.device)
|
||||
if self.half:
|
||||
self.model = self.model.half()
|
||||
|
||||
def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
|
||||
"""Deep network interpolation.
|
||||
|
||||
``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
|
||||
"""
|
||||
net_a = torch.load(net_a, map_location=torch.device(loc))
|
||||
net_b = torch.load(net_b, map_location=torch.device(loc))
|
||||
for k, v_a in net_a[key].items():
|
||||
net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
|
||||
return net_a
|
||||
|
||||
def pre_process(self, img):
|
||||
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user