Support outscale; Add RealESRGANx2 model; Version 0.2.1

This commit is contained in:
Xintao
2021-08-08 21:30:51 +08:00
parent 5745599813
commit 64ad194dda
4 changed files with 36 additions and 18 deletions

View File

@@ -15,6 +15,7 @@ We extend the powerful ESRGAN to a practical restoration application (namely, Re
:triangular_flag_on_post: **Updates** :triangular_flag_on_post: **Updates**
- :white_check_mark: Support arbitrary output scales with `--outscale` (the network output is further resized to the requested scale with `LANCZOS4` interpolation). Add the *RealESRGAN_x2plus.pth* model.
- :white_check_mark: [The inference code](inference_realesrgan.py) supports: 1) **tile** options; 2) images with **alpha channel**; 3) **gray** images; 4) **16-bit** images. - :white_check_mark: [The inference code](inference_realesrgan.py) supports: 1) **tile** options; 2) images with **alpha channel**; 3) **gray** images; 4) **16-bit** images.
- :white_check_mark: The training codes have been released. A detailed guide can be found in [Training.md](Training.md). - :white_check_mark: The training codes have been released. A detailed guide can be found in [Training.md](Training.md).
@@ -124,6 +125,7 @@ Results are in the `results` folder
- [RealESRGAN-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth) - [RealESRGAN-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
- [RealESRNet-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth) - [RealESRNet-x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth)
- [RealESRGAN-x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.0/RealESRGAN_x2plus.pth)
- [official ESRGAN-x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) - [official ESRGAN-x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth)
## :computer: Training ## :computer: Training

View File

@@ -1 +1 @@
0.2.0 0.2.1

View File

@@ -15,7 +15,8 @@ def main():
default='experiments/pretrained_models/RealESRGAN_x4plus.pth', default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
help='Path to the pre-trained model') help='Path to the pre-trained model')
parser.add_argument('--output', type=str, default='results', help='Output folder') parser.add_argument('--output', type=str, default='results', help='Output folder')
parser.add_argument('--scale', type=int, default=4, help='Upsample scale factor') parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image') parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing') parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding') parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
@@ -34,7 +35,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
upsampler = RealESRGANer( upsampler = RealESRGANer(
scale=args.scale, scale=args.netscale,
model_path=args.model_path, model_path=args.model_path,
tile=args.tile, tile=args.tile,
tile_pad=args.tile_pad, tile_pad=args.tile_pad,
@@ -51,15 +52,25 @@ def main():
print('Testing', idx, imgname) print('Testing', idx, imgname)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
output, img_mode = upsampler.enhance(img) h, w = img.shape[0:2]
if args.ext == 'auto': if max(h, w) > 1000 and args.netscale == 4:
extension = extension[1:] print('WARNING: The input image is large, try X2 model for better performance.')
if max(h, w) < 500 and args.netscale == 2:
print('WARNING: The input image is small, try X4 model for better performance.')
try:
output, img_mode = upsampler.enhance(img, outscale=args.outscale)
except Exception as error:
print('Error', error)
else: else:
extension = args.ext if args.ext == 'auto':
if img_mode == 'RGBA': # RGBA images should be saved in png format extension = extension[1:]
extension = 'png' else:
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}') extension = args.ext
cv2.imwrite(save_path, output) if img_mode == 'RGBA': # RGBA images should be saved in png format
extension = 'png'
save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
cv2.imwrite(save_path, output)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -63,12 +63,7 @@ class RealESRGANer():
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
def process(self): def process(self):
try: self.output = self.model(self.img)
# inference
with torch.no_grad():
self.output = self.model(self.img)
except Exception as error:
print('Error', error)
def tile_process(self): def tile_process(self):
"""Modified from: https://github.com/ata4/esrgan-launcher """Modified from: https://github.com/ata4/esrgan-launcher
@@ -143,7 +138,9 @@ class RealESRGANer():
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
return self.output return self.output
def enhance(self, img, tile=False, alpha_upsampler='realesrgan'): @torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
h_input, w_input = img.shape[0:2]
# img: numpy # img: numpy
img = img.astype(np.float32) img = img.astype(np.float32)
if np.max(img) > 255: # 16-bit image if np.max(img) > 255: # 16-bit image
@@ -203,6 +200,14 @@ class RealESRGANer():
output = (output_img * 65535.0).round().astype(np.uint16) output = (output_img * 65535.0).round().astype(np.uint16)
else: else:
output = (output_img * 255.0).round().astype(np.uint8) output = (output_img * 255.0).round().astype(np.uint8)
if outscale is not None and outscale != float(self.scale):
output = cv2.resize(
output, (
int(w_input * outscale),
int(h_input * outscale),
), interpolation=cv2.INTER_LANCZOS4)
return output, img_mode return output, img_mode