Skip to content

Commit

Permalink
Add support for WebP format
Browse files Browse the repository at this point in the history
  • Loading branch information
tsurumeso committed Nov 6, 2018
1 parent 9029a20 commit 9bb635d
Showing 1 changed file with 26 additions and 22 deletions.
48 changes: 26 additions & 22 deletions waifu2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def load_models(cfg):
p = argparse.ArgumentParser()
p.add_argument('--gpu', '-g', type=int, default=-1)
p.add_argument('--input', '-i', default='images/small.png')
p.add_argument('--output', '-o', default='./')
p.add_argument('--output_dir', '-o', default='./')
p.add_argument('--extension', '-e', default='png')
p.add_argument('--quality', '-q', type=int, default=100)
p.add_argument('--arch', '-a',
choices=['VGG7', '0', 'UpConv7', '1',
'ResNet10', '2', 'UpResNet10', '3'],
Expand Down Expand Up @@ -157,30 +159,30 @@ def load_models(cfg):
if __name__ == '__main__':
models = load_models(args)

if '.png' not in args.output:
if not os.path.exists(args.output):
os.makedirs(args.output)
else:
dirname = os.path.dirname(args.output)
if len(dirname) == 0:
dirname = './'
if not os.path.exists(dirname):
os.makedirs(dirname)
extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp']
if args.extension not in ['png', 'webp']:
raise ValueError('{} format is not supported'.format(args.extension))

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)

if os.path.isdir(args.input):
filelist = utils.load_filelist(args.input)
else:
filelist = [args.input]

for path in filelist:
src = Image.open(path)
icc_profile = src.info.get('icc_profile')
w, h = src.size[:2]
if args.width != 0:
args.scale_ratio = args.width / w
if args.height != 0:
args.scale_ratio = args.height / h
basename = os.path.basename(path)
outname, ext = os.path.splitext(basename)
if ext.lower() in ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff']:
outname, ext = os.path.splitext(os.path.basename(path))
basepath = os.path.join(
args.output_dir, '{}.{}'.format(outname, args.extension))
if ext.lower() in extensions:
outname += '_(tta{})'.format(args.tta_level) if args.tta else '_'
dst = src.copy()
start = time.time()
Expand All @@ -198,14 +200,16 @@ def load_models(cfg):
dst = upscale_image(args, dst, models['scale'])
print('Elapsed time: {:.6f} sec'.format(time.time() - start))

outname += '({}_{}).png'.format(args.arch.lower(), args.color)
if os.path.isdir(args.output):
outpath = os.path.join(args.output, outname)
outname += '({}_{}).{}'.format(
args.arch.lower(), args.color, args.extension)
outpath = os.path.join(args.output_dir, outname)
if not os.path.exists(basepath):
outpath = basepath
if icc_profile is not None:
dst.convert(src.mode).save(
outpath, quality=args.quality, lossloss=True, method=6,
icc_profile=icc_profile)
else:
if os.path.exists(args.output):
outpath = os.path.join(dirname, outname)
else:
outpath = args.output
dst.convert(src.mode).save(
outpath, icc_profile=src.info.get('icc_profile'))
dst.convert(src.mode).save(
outpath, quality=args.quality, lossloss=True, method=6)
six.print_('Saved as \'{}\''.format(outpath))

0 comments on commit 9bb635d

Please sign in to comment.