diff --git a/nbstripout/_nbstripout.py b/nbstripout/_nbstripout.py index 07ff69a..d9d1b9a 100644 --- a/nbstripout/_nbstripout.py +++ b/nbstripout/_nbstripout.py @@ -331,36 +331,38 @@ def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False): return 1 def process_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'): - try: - if args.mode == 'zeppelin': - nb = json.load(input_stream, object_pairs_hook=collections.OrderedDict) - nb_stripped = strip_zeppelin_output(nb) - if args.dry_run: - output_stream.write(f'Dry run: would have stripped {filename}\n') - return - json.dump(nb_stripped, output_stream, indent=2) - output_stream.write('\n') - output_stream.flush() + if args.mode == 'zeppelin': + nb = json.load(input_stream, object_pairs_hook=collections.OrderedDict) + nb_stripped = strip_zeppelin_output(nb) + if args.dry_run: + output_stream.write(f'Dry run: would have stripped {filename}\n') return + if output_stream.seekable(): + output_stream.seek(0) + output_stream.truncate() + json.dump(nb_stripped, output_stream, indent=2) + output_stream.write('\n') + output_stream.flush() + return + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + nb = nbformat.read(input_stream, as_version=nbformat.NO_CONVERT) + + nb = strip_output(nb, args.keep_output, args.keep_count, args.keep_id, + extra_keys, args.drop_empty_cells, + args.drop_tagged_cells.split(), args.strip_init_cells, + _parse_size(args.max_size)) + + if args.dry_run: + output_stream.write(f'Dry run: would have stripped {filename}\n') + else: + if output_stream.seekable(): + output_stream.seek(0) + output_stream.truncate() with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UserWarning) - nb = nbformat.read(input_stream, as_version=nbformat.NO_CONVERT) - - nb = strip_output(nb, args.keep_output, args.keep_count, args.keep_id, - extra_keys, args.drop_empty_cells, - args.drop_tagged_cells.split(), args.strip_init_cells, - _parse_size(args.max_size)) - - if args.dry_run: - output_stream.write(f'Dry run: would have stripped {filename}\n') - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) - nbformat.write(nb, output_stream) - output_stream.flush() - except nbformat.reader.NotJSONError: - print('No valid notebook detected', file=sys.stderr) - raise SystemExit(1) + nbformat.write(nb, output_stream) + output_stream.flush() def main(): @@ -486,12 +488,12 @@ def main(): continue try: - with io.open(filename, 'r', encoding='utf8') as fin: - if args.textconv or args.dry_run: - process_notebook(fin, output_stream, args, extra_keys, filename) - else: - with io.open(filename, 'w', encoding='utf8', newline='') as fout: - process_notebook(fin, fout, args, extra_keys, filename) + with io.open(filename, 'r+', encoding='utf8', newline='') as f: + out = output_stream if args.textconv or args.dry_run else f + process_notebook(f, out, args, extra_keys, filename) + except nbformat.reader.NotJSONError: + print(f"No valid notebook detected in '{filename}'", file=sys.stderr) + raise SystemExit(1) except FileNotFoundError: print(f"Could not strip '{filename}': file not found", file=sys.stderr) raise SystemExit(1) @@ -501,4 +503,8 @@ def main(): raise if not args.files and input_stream: - process_notebook(input_stream, output_stream, args, extra_keys) + try: + process_notebook(input_stream, output_stream, args, extra_keys) + except nbformat.reader.NotJSONError: + print('No valid notebook detected on stdin', file=sys.stderr) + raise SystemExit(1)