diff --git a/waifu2x.py b/waifu2x.py index 2b3c1a7..0fa33ca 100644 --- a/waifu2x.py +++ b/waifu2x.py @@ -129,7 +129,9 @@ def load_models(cfg): p = argparse.ArgumentParser() p.add_argument('--gpu', '-g', type=int, default=-1) p.add_argument('--input', '-i', default='images/small.png') -p.add_argument('--output', '-o', default='./') +p.add_argument('--output_dir', '-o', default='./') +p.add_argument('--extension', '-e', default='png') +p.add_argument('--quality', '-q', type=int, default=100) p.add_argument('--arch', '-a', choices=['VGG7', '0', 'UpConv7', '1', 'ResNet10', '2', 'UpResNet10', '3'], @@ -157,15 +159,13 @@ def load_models(cfg): if __name__ == '__main__': models = load_models(args) - if '.png' not in args.output: - if not os.path.exists(args.output): - os.makedirs(args.output) - else: - dirname = os.path.dirname(args.output) - if len(dirname) == 0: - dirname = './' - if not os.path.exists(dirname): - os.makedirs(dirname) + extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp'] + if args.extension not in ['png', 'webp']: + raise ValueError('{} format is not supported'.format(args.extension)) + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.path.isdir(args.input): filelist = utils.load_filelist(args.input) else: @@ -173,14 +173,16 @@ def load_models(cfg): for path in filelist: src = Image.open(path) + icc_profile = src.info.get('icc_profile') w, h = src.size[:2] if args.width != 0: args.scale_ratio = args.width / w if args.height != 0: args.scale_ratio = args.height / h - basename = os.path.basename(path) - outname, ext = os.path.splitext(basename) - if ext.lower() in ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff']: + outname, ext = os.path.splitext(os.path.basename(path)) + basepath = os.path.join( + args.output_dir, '{}.{}'.format(outname, args.extension)) + if ext.lower() in extensions: outname += '_(tta{})'.format(args.tta_level) if args.tta else '_' dst = src.copy() start = time.time() @@ -198,14 +200,16 @@ def load_models(cfg): dst = upscale_image(args, dst, models['scale']) print('Elapsed time: {:.6f} sec'.format(time.time() - start)) - outname += '({}_{}).png'.format(args.arch.lower(), args.color) - if os.path.isdir(args.output): - outpath = os.path.join(args.output, outname) + outname += '({}_{}).{}'.format( + args.arch.lower(), args.color, args.extension) + outpath = os.path.join(args.output_dir, outname) + if not os.path.exists(basepath): + outpath = basepath + if icc_profile is not None: + dst.convert(src.mode).save( + outpath, quality=args.quality, lossloss=True, method=6, + icc_profile=icc_profile) else: - if os.path.exists(args.output): - outpath = os.path.join(dirname, outname) - else: - outpath = args.output - dst.convert(src.mode).save( - outpath, icc_profile=src.info.get('icc_profile')) + dst.convert(src.mode).save( + outpath, quality=args.quality, lossloss=True, method=6) six.print_('Saved as \'{}\''.format(outpath))