Added GPU selection feature to python inference (#321)
* Added GPU selection feature to python inference * pylint pep8 fixes * pep8 fixes
This commit is contained in:
@@ -39,6 +39,9 @@ def main():
|
|||||||
type=str,
|
type=str,
|
||||||
default='auto',
|
default='auto',
|
||||||
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
||||||
|
parser.add_argument(
|
||||||
|
'-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# determine models according to model names
|
# determine models according to model names
|
||||||
@@ -71,7 +74,8 @@ def main():
|
|||||||
tile=args.tile,
|
tile=args.tile,
|
||||||
tile_pad=args.tile_pad,
|
tile_pad=args.tile_pad,
|
||||||
pre_pad=args.pre_pad,
|
pre_pad=args.pre_pad,
|
||||||
half=not args.fp32)
|
half=not args.fp32,
|
||||||
|
gpu_id=args.gpu_id)
|
||||||
|
|
||||||
if args.face_enhance: # Use GFPGAN for face enhancement
|
if args.face_enhance: # Use GFPGAN for face enhancement
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
|
|||||||
@@ -26,7 +26,16 @@ class RealESRGANer():
|
|||||||
half (float): Whether to use half precision during inference. Default: False.
|
half (float): Whether to use half precision during inference. Default: False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None):
|
def __init__(self,
|
||||||
|
scale,
|
||||||
|
model_path,
|
||||||
|
model=None,
|
||||||
|
tile=0,
|
||||||
|
tile_pad=10,
|
||||||
|
pre_pad=10,
|
||||||
|
half=False,
|
||||||
|
device=None,
|
||||||
|
gpu_id=None):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.tile_size = tile
|
self.tile_size = tile
|
||||||
self.tile_pad = tile_pad
|
self.tile_pad = tile_pad
|
||||||
@@ -35,7 +44,11 @@ class RealESRGANer():
|
|||||||
self.half = half
|
self.half = half
|
||||||
|
|
||||||
# initialize model
|
# initialize model
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
if gpu_id:
|
||||||
|
self.device = torch.device(
|
||||||
|
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||||
|
else:
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
||||||
if model_path.startswith('https://'):
|
if model_path.startswith('https://'):
|
||||||
model_path = load_file_from_url(
|
model_path = load_file_from_url(
|
||||||
|
|||||||
Reference in New Issue
Block a user