add inference_realesrgan_video
This commit is contained in:
@@ -3,4 +3,4 @@ from .archs import *
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .utils import *
|
||||
from .version import __version__
|
||||
from .version import *
|
||||
|
||||
@@ -2,6 +2,8 @@ import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.nn import functional as F
|
||||
@@ -38,7 +40,7 @@ class RealESRGANer():
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
|
||||
# prefer to use params_ema
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
@@ -226,3 +228,53 @@ class RealESRGANer():
|
||||
), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
return output, img_mode
|
||||
|
||||
|
||||
class PrefetchReader(threading.Thread):
|
||||
"""Prefetch images.
|
||||
|
||||
Args:
|
||||
img_list (list[str]): A image list of image paths to be read.
|
||||
num_prefetch_queue (int): Number of prefetch queue.
|
||||
"""
|
||||
|
||||
def __init__(self, img_list, num_prefetch_queue):
|
||||
super().__init__()
|
||||
self.que = queue.Queue(num_prefetch_queue)
|
||||
self.img_list = img_list
|
||||
|
||||
def run(self):
|
||||
for img_path in self.img_list:
|
||||
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||
self.que.put(img)
|
||||
|
||||
self.que.put(None)
|
||||
|
||||
def __next__(self):
|
||||
next_item = self.que.get()
|
||||
if next_item is None:
|
||||
raise StopIteration
|
||||
return next_item
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
|
||||
class IOConsumer(threading.Thread):
|
||||
|
||||
def __init__(self, opt, que, qid):
|
||||
super().__init__()
|
||||
self._queue = que
|
||||
self.qid = qid
|
||||
self.opt = opt
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
msg = self._queue.get()
|
||||
if isinstance(msg, str) and msg == 'quit':
|
||||
break
|
||||
|
||||
output = msg['output']
|
||||
save_path = msg['save_path']
|
||||
cv2.imwrite(save_path, output)
|
||||
print(f'IO worker {self.qid} is done.')
|
||||
|
||||
Reference in New Issue
Block a user