add extract_subimages
This commit is contained in:
@@ -5,3 +5,4 @@ numpy
|
||||
opencv-python
|
||||
Pillow
|
||||
torch>=1.7
|
||||
tqdm
|
||||
|
||||
151
scripts/extract_subimages.py
Normal file
151
scripts/extract_subimages.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
from basicsr.utils import scandir
|
||||
from multiprocessing import Pool
|
||||
from os import path as osp
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main(args):
|
||||
"""A multi-thread tool to crop large images to sub-images for faster IO.
|
||||
|
||||
opt (dict): Configuration dict. It contains:
|
||||
n_thread (int): Thread number.
|
||||
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
|
||||
A higher value means a smaller size and longer compression time.
|
||||
Use 0 for faster CPU decompression. Default: 3, same in cv2.
|
||||
|
||||
input_folder (str): Path to the input folder.
|
||||
save_folder (str): Path to save folder.
|
||||
crop_size (int): Crop size.
|
||||
step (int): Step for overlapped sliding window.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower
|
||||
than thresh_size will be dropped.
|
||||
|
||||
Usage:
|
||||
For each folder, run this script.
|
||||
Typically, there are four folders to be processed for DIV2K dataset.
|
||||
DIV2K_train_HR
|
||||
DIV2K_train_LR_bicubic/X2
|
||||
DIV2K_train_LR_bicubic/X3
|
||||
DIV2K_train_LR_bicubic/X4
|
||||
After process, each sub_folder should have the same number of
|
||||
subimages.
|
||||
Remember to modify opt configurations according to your settings.
|
||||
"""
|
||||
|
||||
opt = {}
|
||||
opt['n_thread'] = args.n_thread
|
||||
opt['compression_level'] = args.compression_level
|
||||
|
||||
# HR images
|
||||
opt['input_folder'] = args.input
|
||||
opt['save_folder'] = args.output
|
||||
opt['crop_size'] = args.crop_size
|
||||
opt['step'] = args.step
|
||||
opt['thresh_size'] = args.thresh_size
|
||||
extract_subimages(opt)
|
||||
|
||||
|
||||
def extract_subimages(opt):
|
||||
"""Crop images to subimages.
|
||||
|
||||
Args:
|
||||
opt (dict): Configuration dict. It contains:
|
||||
input_folder (str): Path to the input folder.
|
||||
save_folder (str): Path to save folder.
|
||||
n_thread (int): Thread number.
|
||||
"""
|
||||
input_folder = opt['input_folder']
|
||||
save_folder = opt['save_folder']
|
||||
if not osp.exists(save_folder):
|
||||
os.makedirs(save_folder)
|
||||
print(f'mkdir {save_folder} ...')
|
||||
else:
|
||||
print(f'Folder {save_folder} already exists. Exit.')
|
||||
sys.exit(1)
|
||||
|
||||
img_list = list(scandir(input_folder, full_path=True))
|
||||
|
||||
pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
|
||||
pool = Pool(opt['n_thread'])
|
||||
for path in img_list:
|
||||
pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
|
||||
pool.close()
|
||||
pool.join()
|
||||
pbar.close()
|
||||
print('All processes done.')
|
||||
|
||||
|
||||
def worker(path, opt):
|
||||
"""Worker for each process.
|
||||
|
||||
Args:
|
||||
path (str): Image path.
|
||||
opt (dict): Configuration dict. It contains:
|
||||
crop_size (int): Crop size.
|
||||
step (int): Step for overlapped sliding window.
|
||||
thresh_size (int): Threshold size. Patches whose size is lower
|
||||
than thresh_size will be dropped.
|
||||
save_folder (str): Path to save folder.
|
||||
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
|
||||
|
||||
Returns:
|
||||
process_info (str): Process information displayed in progress bar.
|
||||
"""
|
||||
crop_size = opt['crop_size']
|
||||
step = opt['step']
|
||||
thresh_size = opt['thresh_size']
|
||||
img_name, extension = osp.splitext(osp.basename(path))
|
||||
|
||||
# remove the x2, x3, x4 and x8 in the filename for DIV2K
|
||||
img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
|
||||
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if img.ndim == 2:
|
||||
h, w = img.shape
|
||||
elif img.ndim == 3:
|
||||
h, w, c = img.shape
|
||||
else:
|
||||
raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')
|
||||
|
||||
h_space = np.arange(0, h - crop_size + 1, step)
|
||||
if h - (h_space[-1] + crop_size) > thresh_size:
|
||||
h_space = np.append(h_space, h - crop_size)
|
||||
w_space = np.arange(0, w - crop_size + 1, step)
|
||||
if w - (w_space[-1] + crop_size) > thresh_size:
|
||||
w_space = np.append(w_space, w - crop_size)
|
||||
|
||||
index = 0
|
||||
for x in h_space:
|
||||
for y in w_space:
|
||||
index += 1
|
||||
cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
|
||||
cropped_img = np.ascontiguousarray(cropped_img)
|
||||
cv2.imwrite(
|
||||
osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
|
||||
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
|
||||
process_info = f'Processing {img_name} ...'
|
||||
return process_info
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
|
||||
parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
|
||||
parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
|
||||
parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
|
||||
parser.add_argument(
|
||||
'--thresh_size',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
|
||||
parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
|
||||
parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -17,6 +17,6 @@ line_length = 120
|
||||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = realesrgan
|
||||
known_third_party = PIL,basicsr,cv2,numpy,torch
|
||||
known_third_party = PIL,basicsr,cv2,numpy,torch,tqdm
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
Reference in New Issue
Block a user