diff --git a/requirements.txt b/requirements.txt
index 402f547..f4ed4c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ numpy
 opencv-python
 Pillow
 torch>=1.7
+tqdm
diff --git a/scripts/extract_subimages.py b/scripts/extract_subimages.py
new file mode 100644
index 0000000..9630f95
--- /dev/null
+++ b/scripts/extract_subimages.py
@@ -0,0 +1,160 @@
+import argparse
+import cv2
+import numpy as np
+import os
+import sys
+from basicsr.utils import scandir
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def main(args):
+    """A multi-thread tool to crop large images to sub-images for faster IO.
+
+    opt (dict): Configuration dict. It contains:
+        n_thread (int): Thread number.
+        compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
+            A higher value means a smaller size and longer compression time.
+            Use 0 for faster CPU decompression. Default: 3, same in cv2.
+
+        input_folder (str): Path to the input folder.
+        save_folder (str): Path to save folder.
+        crop_size (int): Crop size.
+        step (int): Step for overlapped sliding window.
+        thresh_size (int): Threshold size. Patches whose size is lower
+            than thresh_size will be dropped.
+
+    Usage:
+        For each folder, run this script.
+        Typically, there are four folders to be processed for DIV2K dataset.
+            DIV2K_train_HR
+            DIV2K_train_LR_bicubic/X2
+            DIV2K_train_LR_bicubic/X3
+            DIV2K_train_LR_bicubic/X4
+        After process, each sub_folder should have the same number of
+        subimages.
+        Remember to modify opt configurations according to your settings.
+    """
+
+    opt = {}
+    opt['n_thread'] = args.n_thread
+    opt['compression_level'] = args.compression_level
+
+    # HR images
+    opt['input_folder'] = args.input
+    opt['save_folder'] = args.output
+    opt['crop_size'] = args.crop_size
+    opt['step'] = args.step
+    opt['thresh_size'] = args.thresh_size
+    extract_subimages(opt)
+
+
+def extract_subimages(opt):
+    """Crop images to subimages.
+
+    Args:
+        opt (dict): Configuration dict. It contains:
+            input_folder (str): Path to the input folder.
+            save_folder (str): Path to save folder.
+            n_thread (int): Thread number.
+    """
+    input_folder = opt['input_folder']
+    save_folder = opt['save_folder']
+    if not osp.exists(save_folder):
+        os.makedirs(save_folder)
+        print(f'mkdir {save_folder} ...')
+    else:
+        print(f'Folder {save_folder} already exists. Exit.')
+        sys.exit(1)
+
+    img_list = list(scandir(input_folder, full_path=True))
+
+    pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
+    pool = Pool(opt['n_thread'])
+    for path in img_list:
+        # error_callback surfaces worker exceptions; without it, apply_async
+        # swallows them silently and the progress bar stalls short of total.
+        # `p=path` binds the current path (lambdas are late-binding).
+        pool.apply_async(
+            worker, args=(path, opt), callback=lambda arg: pbar.update(1),
+            error_callback=lambda err, p=path: print(f'Worker failed for {p}: {err}'))
+    pool.close()
+    pool.join()
+    pbar.close()
+    print('All processes done.')
+
+
+def worker(path, opt):
+    """Worker for each process.
+
+    Args:
+        path (str): Image path.
+        opt (dict): Configuration dict. It contains:
+            crop_size (int): Crop size.
+            step (int): Step for overlapped sliding window.
+            thresh_size (int): Threshold size. Patches whose size is lower
+                than thresh_size will be dropped.
+            save_folder (str): Path to save folder.
+            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
+
+    Returns:
+        process_info (str): Process information displayed in progress bar.
+    """
+    crop_size = opt['crop_size']
+    step = opt['step']
+    thresh_size = opt['thresh_size']
+    img_name, extension = osp.splitext(osp.basename(path))
+
+    # remove the x2, x3, x4 and x8 in the filename for DIV2K
+    img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
+
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+    if img is None:
+        # cv2.imread returns None (no exception) for unreadable/non-image files.
+        raise ValueError(f'Cannot read image: {path}')
+
+    if img.ndim == 2:
+        h, w = img.shape
+    elif img.ndim == 3:
+        h, w, c = img.shape
+    else:
+        raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')
+
+    # NOTE(review): assumes h >= crop_size and w >= crop_size; a smaller image
+    # makes h_space/w_space empty and the [-1] index below raises IndexError.
+    h_space = np.arange(0, h - crop_size + 1, step)
+    if h - (h_space[-1] + crop_size) > thresh_size:
+        h_space = np.append(h_space, h - crop_size)
+    w_space = np.arange(0, w - crop_size + 1, step)
+    if w - (w_space[-1] + crop_size) > thresh_size:
+        w_space = np.append(w_space, w - crop_size)
+
+    index = 0
+    for x in h_space:
+        for y in w_space:
+            index += 1
+            cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
+            cropped_img = np.ascontiguousarray(cropped_img)
+            cv2.imwrite(
+                osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
+                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
+    process_info = f'Processing {img_name} ...'
+    return process_info
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
+    parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
+    parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
+    parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
+    parser.add_argument(
+        '--thresh_size',
+        type=int,
+        default=0,
+        help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
+    parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
+    parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
+    args = parser.parse_args()
+
+    main(args)
diff --git a/setup.cfg b/setup.cfg
index 4dbe63d..5dcf1ab 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,6 +17,6 @@ line_length = 120
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = realesrgan
-known_third_party = PIL,basicsr,cv2,numpy,torch
+known_third_party = PIL,basicsr,cv2,numpy,torch,tqdm
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY