improve codes comments
This commit is contained in:
@@ -14,34 +14,24 @@ def main(args):
|
||||
|
||||
opt (dict): Configuration dict. It contains:
|
||||
n_thread (int): Thread number.
|
||||
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
|
||||
A higher value means a smaller size and longer compression time.
|
||||
Use 0 for faster CPU decompression. Default: 3, same in cv2.
|
||||
|
||||
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
|
||||
and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
|
||||
input_folder (str): Path to the input folder.
|
||||
save_folder (str): Path to save folder.
|
||||
crop_size (int): Crop size.
|
||||
step (int): Step for overlapped sliding window.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower
|
||||
than thresh_size will be dropped.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
|
||||
|
||||
Usage:
|
||||
For each folder, run this script.
|
||||
Typically, there are four folders to be processed for DIV2K dataset.
|
||||
DIV2K_train_HR
|
||||
DIV2K_train_LR_bicubic/X2
|
||||
DIV2K_train_LR_bicubic/X3
|
||||
DIV2K_train_LR_bicubic/X4
|
||||
After process, each sub_folder should have the same number of
|
||||
subimages.
|
||||
Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
|
||||
After process, each sub_folder should have the same number of subimages.
|
||||
Remember to modify opt configurations according to your settings.
|
||||
"""
|
||||
|
||||
opt = {}
|
||||
opt['n_thread'] = args.n_thread
|
||||
opt['compression_level'] = args.compression_level
|
||||
|
||||
# HR images
|
||||
opt['input_folder'] = args.input
|
||||
opt['save_folder'] = args.output
|
||||
opt['crop_size'] = args.crop_size
|
||||
@@ -68,6 +58,7 @@ def extract_subimages(opt):
|
||||
print(f'Folder {save_folder} already exists. Exit.')
|
||||
sys.exit(1)
|
||||
|
||||
# scan all images
|
||||
img_list = list(scandir(input_folder, full_path=True))
|
||||
|
||||
pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
|
||||
@@ -88,8 +79,7 @@ def worker(path, opt):
|
||||
opt (dict): Configuration dict. It contains:
|
||||
crop_size (int): Crop size.
|
||||
step (int): Step for overlapped sliding window.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower
|
||||
than thresh_size will be dropped.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
|
||||
save_folder (str): Path to save folder.
|
||||
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ def main(args):
|
||||
for img_path in img_paths:
|
||||
status = True
|
||||
if args.check:
|
||||
# read the image once for check, as some images may have errors
|
||||
try:
|
||||
img = cv2.imread(img_path)
|
||||
except Exception as error:
|
||||
@@ -20,6 +21,7 @@ def main(args):
|
||||
status = False
|
||||
print(f'Img is None: {img_path}')
|
||||
if status:
|
||||
# get the relative path
|
||||
img_name = os.path.relpath(img_path, root)
|
||||
print(img_name)
|
||||
txt_file.write(f'{img_name}\n')
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
|
||||
def main(args):
|
||||
txt_file = open(args.meta_info, 'w')
|
||||
# sca images
|
||||
img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
|
||||
img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
|
||||
|
||||
@@ -12,6 +13,7 @@ def main(args):
|
||||
f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
|
||||
|
||||
for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
|
||||
# get the relative paths
|
||||
img_name_gt = os.path.relpath(img_path_gt, args.root[0])
|
||||
img_name_lq = os.path.relpath(img_path_lq, args.root[1])
|
||||
print(f'{img_name_gt}, {img_name_lq}')
|
||||
@@ -19,7 +21,7 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""Generate meta info (txt file) for paired images.
|
||||
"""This script is used to generate meta info (txt file) for paired images.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
||||
@@ -5,7 +5,6 @@ from PIL import Image
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
# For DF2K, we consider the following three scales,
|
||||
# and the smallest image whose shortest edge is 400
|
||||
scale_list = [0.75, 0.5, 1 / 3]
|
||||
@@ -37,6 +36,9 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""Generate multi-scale versions for GT images with LANCZOS resampling.
|
||||
It is now used for DF2K dataset (DIV2K + Flickr 2K)
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
|
||||
parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')
|
||||
|
||||
@@ -1,17 +1,36 @@
|
||||
import argparse
|
||||
import torch
|
||||
import torch.onnx
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
|
||||
# An instance of your model
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
|
||||
# set the train mode to false since we will only run the forward pass.
|
||||
model.train(False)
|
||||
model.cpu().eval()
|
||||
|
||||
# An example input you would normally provide to your model's forward() method
|
||||
x = torch.rand(1, 3, 64, 64)
|
||||
def main(args):
|
||||
# An instance of the model
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
if args.params:
|
||||
keyname = 'params'
|
||||
else:
|
||||
keyname = 'params_ema'
|
||||
model.load_state_dict(torch.load(args.input)[keyname])
|
||||
# set the train mode to false since we will only run the forward pass.
|
||||
model.train(False)
|
||||
model.cpu().eval()
|
||||
|
||||
# Export the model
|
||||
with torch.no_grad():
|
||||
torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
|
||||
# An example input
|
||||
x = torch.rand(1, 3, 64, 64)
|
||||
# Export the model
|
||||
with torch.no_grad():
|
||||
torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
|
||||
print(torch_out.shape)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""Convert pytorch model to onnx models"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
|
||||
parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
|
||||
parser.add_argument('--params', action='store_false', help='Use params instead of params_ema')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user