Added GPU selection feature to python inference (#321)

* Added GPU selection feature to python inference

* pylint pep8 fixes

* pep8 fixes
This commit is contained in:
Mert Cobanov
2022-05-24 15:24:49 +03:00
committed by GitHub
parent bc77ca5666
commit 6b15fc6936
2 changed files with 20 additions and 3 deletions

View File

@@ -26,7 +26,16 @@ class RealESRGANer():
half (float): Whether to use half precision during inference. Default: False.
"""
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None):
def __init__(self,
scale,
model_path,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
@@ -35,7 +44,11 @@ class RealESRGANer():
self.half = half
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
if gpu_id:
self.device = torch.device(
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'):
model_path = load_file_from_url(