fix colorspace bug & support multi-gpu and multi-processing (#312)

* fix colorspace bug of ffmpeg stream, add multi-gpu and multi-processing suport for inference_realesrgan_video.py

* fix code format

Co-authored-by: yanzewu <yanzewu@tencent.com>
This commit is contained in:
wyz
2022-05-04 13:09:51 +08:00
committed by GitHub
parent 8041099021
commit 8cb9bd403e
3 changed files with 288 additions and 223 deletions

View File

@@ -35,13 +35,20 @@ The following are some demos (best view in the full screen mode).
```bash ```bash
# download model # download model
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P realesrgan/weights wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P realesrgan/weights
# inference # single gpu and single process inference
python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 --stream CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2
# single gpu and multi process inference (you can use multi-processing to improve GPU utilization)
CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 --num_process_per_gpu 2
# multi gpu and multi process inference
CUDA_VISIBLE_DEVICES=0,1,2,3 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 --num_process_per_gpu 2
``` ```
```console ```console
Usage: Usage:
--stream with this option, the enhanced frames are sent directly to a ffmpeg stream, --num_process_per_gpu The total number of process is num_gpu * num_process_per_gpu. The bottleneck of
avoiding storing large (usually tens of GB) intermediate results. the program lies on the IO, so the GPUs are usually not fully utilized. To alleviate
this issue, you can use multi-processing by setting this parameter. As long as it
does not exceed the CUDA memory
--extract_frame_first If you encounter ffmpeg error when using multi-processing, you can turn this option on.
``` ```
### NCNN Executable File ### NCNN Executable File

View File

@@ -4,111 +4,235 @@ import glob
import mimetypes import mimetypes
import numpy as np import numpy as np
import os import os
import queue
import shutil import shutil
import subprocess
import torch import torch
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.logger import AvgTimer from os import path as osp
from tqdm import tqdm from tqdm import tqdm
from realesrgan import IOConsumer, PrefetchReader, RealESRGANer from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan.archs.srvgg_arch import SRVGGNetCompact
try:
import ffmpeg
except ImportError:
import pip
pip.main(['install', '--user', 'ffmpeg-python'])
import ffmpeg
def get_video_meta_info(video_path):
ret = {}
probe = ffmpeg.probe(video_path)
video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
ret['width'] = video_streams[0]['width']
ret['height'] = video_streams[0]['height']
ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
ret['nb_frames'] = int(video_streams[0]['nb_frames'])
return ret
def get_sub_video(args, num_process, process_idx):
if num_process == 1:
return args.input
meta = get_video_meta_info(args.input)
duration = int(meta['nb_frames'] / meta['fps'])
part_time = duration // num_process
print(f'duration: {duration}, part_time: {part_time}')
os.makedirs(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'), exist_ok=True)
out_path = osp.join(args.output, f'{args.video_name}_inp_tmp_videos', f'{process_idx:03d}.mp4')
cmd = [
args.ffmpeg_bin, f'-i {args.input}', '-ss', f'{part_time * process_idx}',
f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '', '-async 1', out_path, '-y'
]
print(' '.join(cmd))
subprocess.call(' '.join(cmd), shell=True)
return out_path
class Reader:
def __init__(self, args, total_workers=1, worker_idx=0):
self.args = args
input_type = mimetypes.guess_type(args.input)[0]
self.input_type = 'folder' if input_type is None else input_type
self.paths = [] # for image&folder type
self.audio = None
self.input_fps = None
if self.input_type.startswith('video'):
video_path = get_sub_video(args, total_workers, worker_idx)
self.stream_reader = (
ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
loglevel='error').run_async(
pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
meta = get_video_meta_info(video_path)
self.width = meta['width']
self.height = meta['height']
self.input_fps = meta['fps']
self.audio = meta['audio']
self.nb_frames = meta['nb_frames']
def get_frames(args, extract_frames=False):
# input can be a video file / a folder of frames / an image
is_video = False
if mimetypes.guess_type(args.input)[0].startswith('video'): # is a video file
is_video = True
video_name = os.path.splitext(os.path.basename(args.input))[0]
if extract_frames:
frame_folder = os.path.join('tmp_frames', video_name)
os.makedirs(frame_folder, exist_ok=True)
# use ffmpeg to extract frames
os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {frame_folder}/frame%08d.png')
# get image path list
paths = sorted(glob.glob(os.path.join(frame_folder, '*')))
else: else:
paths = [] if self.input_type.startswith('image'):
# get input video fps self.paths = [args.input]
if args.fps is None: else:
import ffmpeg paths = sorted(glob.glob(os.path.join(args.input, '*')))
probe = ffmpeg.probe(args.input) tot_frames = len(paths)
video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)
args.fps = eval(video_streams[0]['avg_frame_rate']) self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]
elif mimetypes.guess_type(args.input)[0].startswith('image'): # is an image file
paths = [args.input] self.nb_frames = len(self.paths)
assert self.nb_frames > 0, 'empty folder'
from PIL import Image
tmp_img = Image.open(self.paths[0])
self.width, self.height = tmp_img.size
self.idx = 0
def get_resolution(self):
return self.height, self.width
def get_fps(self):
if self.args.fps is not None:
return self.args.fps
elif self.input_fps is not None:
return self.input_fps
return 24
def get_audio(self):
return self.audio
def __len__(self):
return self.nb_frames
def get_frame_from_stream(self):
img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
if not img_bytes:
return None
img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
return img
def get_frame_from_list(self):
if self.idx >= self.nb_frames:
return None
img = cv2.imread(self.paths[self.idx])
self.idx += 1
return img
def get_frame(self):
if self.input_type.startswith('video'):
return self.get_frame_from_stream()
else:
return self.get_frame_from_list()
def close(self):
if self.input_type.startswith('video'):
self.stream_reader.stdin.close()
self.stream_reader.wait()
class Writer:
def __init__(self, args, audio, height, width, video_save_path, fps):
out_width, out_height = int(width * args.outscale), int(height * args.outscale)
if out_height > 2160:
print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
'We highly recommend to decrease the outscale(aka, -s).')
if audio is not None:
self.stream_writer = (
ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
framerate=fps).output(
audio,
video_save_path,
pix_fmt='yuv420p',
vcodec='libx264',
loglevel='error',
acodec='copy').overwrite_output().run_async(
pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
else:
self.stream_writer = (
ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
framerate=fps).output(
video_save_path, pix_fmt='yuv420p', vcodec='libx264',
loglevel='error').overwrite_output().run_async(
pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
def write_frame(self, frame):
frame = frame.astype(np.uint8).tobytes()
self.stream_writer.stdin.write(frame)
def close(self):
self.stream_writer.stdin.close()
self.stream_writer.wait()
def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
# ---------------------- determine models according to model names ---------------------- #
args.model_name = args.model_name.split('.pth')[0]
if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']: # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name in ['RealESRGAN_x4plus_anime_6B']: # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
netscale = 2
elif args.model_name in ['realesr-animevideov3']: # x4 VGG-style model (XS size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
netscale = 4
else: else:
paths = sorted(glob.glob(os.path.join(args.input, '*'))) raise NotImplementedError
assert len(paths) > 0, 'the input folder is empty'
if args.fps is None: # ---------------------- determine model paths ---------------------- #
args.fps = 24 model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
if not os.path.isfile(model_path):
model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
if not os.path.isfile(model_path):
raise ValueError(f'Model {args.model_name} does not exist.')
return is_video, paths # restorer
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
model=model,
tile=args.tile,
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=not args.fp32,
device=device,
)
if 'anime' in args.model_name and args.face_enhance:
print('face_enhance is not supported in anime models, we turned this option off for you. '
'if you insist on turning it on, please manually comment the relevant lines of code.')
args.face_enhance = False
def inference_stream(args, upsampler, face_enhancer): if args.face_enhance: # Use GFPGAN for face enhancement
try: from gfpgan import GFPGANer
import ffmpeg face_enhancer = GFPGANer(
except ImportError: model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
import pip upscale=args.outscale,
pip.main(['install', '--user', 'ffmpeg-python']) arch='clean',
import ffmpeg channel_multiplier=2,
bg_upsampler=upsampler) # TODO support custom device
is_video, paths = get_frames(args, extract_frames=False)
video_name = os.path.splitext(os.path.basename(args.input))[0]
video_save_path = os.path.join(args.output, f'{video_name}_{args.suffix}.mp4')
# decoder
if is_video:
# get height and width
probe = ffmpeg.probe(args.input)
video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
width = video_streams[0]['width']
height = video_streams[0]['height']
# set up frame decoder
decoder = (
ffmpeg.input(args.input).output('pipe:', format='rawvideo', pix_fmt='rgb24', loglevel='warning').run_async(
pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
else: else:
from PIL import Image face_enhancer = None
tmp_img = Image.open(paths[0])
width, height = tmp_img.size
idx = 0
out_width, out_height = int(width * args.outscale), int(height * args.outscale) reader = Reader(args, total_workers, worker_idx)
if out_height > 2160: audio = reader.get_audio()
print('You are generating video that is larger than 4K, which will be very slow due to IO speed.', height, width = reader.get_resolution()
'We highly recommend to decrease the outscale(aka, -s).') fps = reader.get_fps()
# encoder writer = Writer(args, audio, height, width, video_save_path, fps)
if is_video:
audio = ffmpeg.input(args.input).audio
encoder = (
ffmpeg.input(
'pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=args.fps).output(
audio, video_save_path, pix_fmt='yuv420p', vcodec='libx264', loglevel='info',
acodec='copy').overwrite_output().run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
else:
encoder = (
ffmpeg.input(
'pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}',
framerate=args.fps).output(video_save_path, pix_fmt='yuv420p', vcodec='libx264',
loglevel='info').overwrite_output().run_async(
pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
pbar = tqdm(total=len(reader), unit='frame', desc='inference')
while True: while True:
if is_video: img = reader.get_frame()
img_bytes = decoder.stdout.read(width * height * 3) # 3 bytes for one pixel if img is None:
if not img_bytes: break
break
img = np.frombuffer(img_bytes, np.uint8).reshape([height, width, 3])
else:
if idx >= len(paths):
break
img = cv2.imread(paths[idx])
idx += 1
try: try:
if args.face_enhance: if args.face_enhance:
@@ -119,86 +243,60 @@ def inference_stream(args, upsampler, face_enhancer):
print('Error', error) print('Error', error)
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.') print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
else: else:
output = output.astype(np.uint8).tobytes() writer.write_frame(output)
encoder.stdin.write(output)
torch.cuda.synchronize()
if is_video:
decoder.stdin.close()
decoder.wait()
encoder.stdin.close()
encoder.wait()
def inference_frames(args, upsampler, face_enhancer):
is_video, paths = get_frames(args, extract_frames=True)
video_name = os.path.splitext(os.path.basename(args.input))[0]
# for saving restored frames
save_frame_folder = os.path.join(args.output, video_name, 'frames_tmpout')
os.makedirs(save_frame_folder, exist_ok=True)
timer = AvgTimer()
timer.start()
pbar = tqdm(total=len(paths), unit='frame', desc='inference')
# set up prefetch reader
reader = PrefetchReader(paths, num_prefetch_queue=4)
reader.start()
que = queue.Queue()
consumers = [IOConsumer(args, que, f'IO_{i}') for i in range(args.consumer)]
for consumer in consumers:
consumer.start()
for idx, (path, img) in enumerate(zip(paths, reader)):
imgname, extension = os.path.splitext(os.path.basename(path))
if len(img.shape) == 3 and img.shape[2] == 4:
img_mode = 'RGBA'
else:
img_mode = None
try:
if args.face_enhance:
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
else:
output, _ = upsampler.enhance(img, outscale=args.outscale)
except RuntimeError as error:
print('Error', error)
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
else:
if args.ext == 'auto':
extension = extension[1:]
else:
extension = args.ext
if img_mode == 'RGBA': # RGBA images should be saved in png format
extension = 'png'
save_path = os.path.join(save_frame_folder, f'{imgname}_out.{extension}')
que.put({'output': output, 'save_path': save_path})
torch.cuda.synchronize(device)
pbar.update(1) pbar.update(1)
torch.cuda.synchronize()
timer.record()
avg_fps = 1. / (timer.get_avg_time() + 1e-7)
pbar.set_description(f'idx {idx}, fps {avg_fps:.2f}')
for _ in range(args.consumer): reader.close()
que.put('quit') writer.close()
for consumer in consumers:
consumer.join()
pbar.close()
# merge frames to video
video_save_path = os.path.join(args.output, f'{video_name}_{args.suffix}.mp4') def run(args):
os.system(f'ffmpeg -r {args.fps} -i {save_frame_folder}/frame%08d_out.{extension} -i {args.input}' args.video_name = osp.splitext(os.path.basename(args.input))[0]
f' -map 0:v:0 -map 1:a:0 -c:a copy -c:v libx264 -r {args.fps} -pix_fmt yuv420p {video_save_path}') video_save_path = osp.join(args.output, f'{args.video_name}_{args.suffix}.mp4')
# delete tmp file
shutil.rmtree(save_frame_folder) if args.extract_frame_first:
frame_folder = os.path.join('tmp_frames', video_name) tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
if os.path.isdir(frame_folder): os.makedirs(tmp_frames_folder, exist_ok=True)
shutil.rmtree(frame_folder) os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {tmp_frames_folder}/frame%08d.png')
args.input = tmp_frames_folder
num_gpus = torch.cuda.device_count()
num_process = num_gpus * args.num_process_per_gpu
if num_process == 1:
inference_video(args, video_save_path)
return
ctx = torch.multiprocessing.get_context('spawn')
pool = ctx.Pool(num_process)
os.makedirs(osp.join(args.output, f'{args.video_name}_out_tmp_videos'), exist_ok=True)
pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
for i in range(num_process):
sub_video_save_path = osp.join(args.output, f'{args.video_name}_out_tmp_videos', f'{i:03d}.mp4')
pool.apply_async(
inference_video,
args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
# combine sub videos
# prepare vidlist.txt
with open(f'{args.output}/{args.video_name}_vidlist.txt', 'w') as f:
for i in range(num_process):
f.write(f'file \'{args.video_name}_out_tmp_videos/{i:03d}.mp4\'\n')
cmd = [
args.ffmpeg_bin, '-f', 'concat', '-safe', '0', '-i', f'{args.output}/{args.video_name}_vidlist.txt', '-c',
'copy', f'{video_save_path}'
]
print(' '.join(cmd))
subprocess.call(cmd)
shutil.rmtree(osp.join(args.output, f'{args.video_name}_out_tmp_videos'))
if osp.exists(osp.join(args.output, f'{args.video_name}_inp_tmp_videos')):
shutil.rmtree(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'))
os.remove(f'{args.output}/{args.video_name}_vidlist.txt')
def main(): def main():
@@ -226,9 +324,9 @@ def main():
parser.add_argument( parser.add_argument(
'--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).') '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
parser.add_argument('--fps', type=float, default=None, help='FPS of the output video') parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
parser.add_argument('--consumer', type=int, default=4, help='Number of IO consumers')
parser.add_argument('--stream', action='store_true')
parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg') parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg')
parser.add_argument('--extract_frame_first', action='store_true')
parser.add_argument('--num_process_per_gpu', type=int, default=1)
parser.add_argument( parser.add_argument(
'--alpha_upsampler', '--alpha_upsampler',
@@ -243,61 +341,21 @@ def main():
args = parser.parse_args() args = parser.parse_args()
args.input = args.input.rstrip('/').rstrip('\\') args.input = args.input.rstrip('/').rstrip('\\')
# ---------------------- determine models according to model names ---------------------- #
args.model_name = args.model_name.split('.pth')[0]
if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']: # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name in ['RealESRGAN_x4plus_anime_6B']: # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
netscale = 2
elif args.model_name in ['realesr-animevideov3']: # x4 VGG-style model (XS size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
netscale = 4
# ---------------------- determine model paths ---------------------- #
model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
if not os.path.isfile(model_path):
model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
if not os.path.isfile(model_path):
raise ValueError(f'Model {args.model_name} does not exist.')
# restorer
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
model=model,
tile=args.tile,
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=not args.fp32)
if 'anime' in args.model_name and args.face_enhance:
print('face_enhance is not supported in anime models, we turned this option off for you. '
'if you insist on turning it on, please manually comment the relevant lines of code.')
args.face_enhance = False
if args.face_enhance: # Use GFPGAN for face enhancement
from gfpgan import GFPGANer
face_enhancer = GFPGANer(
model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
upscale=args.outscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=upsampler)
else:
face_enhancer = None
os.makedirs(args.output, exist_ok=True) os.makedirs(args.output, exist_ok=True)
if args.stream: if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):
inference_stream(args, upsampler, face_enhancer) is_video = True
else: else:
inference_frames(args, upsampler, face_enhancer) is_video = False
if args.extract_frame_first and not is_video:
args.extract_frame_first = False
run(args)
if args.extract_frame_first:
tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
shutil.rmtree(tmp_frames_folder)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -26,7 +26,7 @@ class RealESRGANer():
half (float): Whether to use half precision during inference. Default: False. half (float): Whether to use half precision during inference. Default: False.
""" """
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False): def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None):
self.scale = scale self.scale = scale
self.tile_size = tile self.tile_size = tile
self.tile_pad = tile_pad self.tile_pad = tile_pad
@@ -35,7 +35,7 @@ class RealESRGANer():
self.half = half self.half = half
# initialize model # initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'): if model_path.startswith('https://'):
model_path = load_file_from_url( model_path = load_file_from_url(