Skip to content

Commit

Permalink
Fix regression where input file was truncated before reading
Browse files (browse the repository at this point in the history)
Closes #190, reported by @Koncopd.
  • Loading branch information
kynan committed Feb 4, 2024
1 parent 1b4b956 commit cb98ba9
Showing 1 changed file with 40 additions and 34 deletions.
74 changes: 40 additions & 34 deletions nbstripout/_nbstripout.py
Original file line number | Diff line number | Diff line change
@@ -331,36 +331,38 @@ def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False):
return 1

def process_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'):
try:
if args.mode == 'zeppelin':
nb = json.load(input_stream, object_pairs_hook=collections.OrderedDict)
nb_stripped = strip_zeppelin_output(nb)
if args.dry_run:
output_stream.write(f'Dry run: would have stripped {filename}\n')
return
json.dump(nb_stripped, output_stream, indent=2)
output_stream.write('\n')
output_stream.flush()
if args.mode == 'zeppelin':
nb = json.load(input_stream, object_pairs_hook=collections.OrderedDict)
nb_stripped = strip_zeppelin_output(nb)
if args.dry_run:
output_stream.write(f'Dry run: would have stripped {filename}\n')
return
if output_stream.seekable():
output_stream.seek(0)
output_stream.truncate()
json.dump(nb_stripped, output_stream, indent=2)
output_stream.write('\n')
output_stream.flush()
return
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
nb = nbformat.read(input_stream, as_version=nbformat.NO_CONVERT)

nb = strip_output(nb, args.keep_output, args.keep_count, args.keep_id,
extra_keys, args.drop_empty_cells,
args.drop_tagged_cells.split(), args.strip_init_cells,
_parse_size(args.max_size))

if args.dry_run:
output_stream.write(f'Dry run: would have stripped {filename}\n')
else:
if output_stream.seekable():
output_stream.seek(0)
output_stream.truncate()
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
nb = nbformat.read(input_stream, as_version=nbformat.NO_CONVERT)

nb = strip_output(nb, args.keep_output, args.keep_count, args.keep_id,
extra_keys, args.drop_empty_cells,
args.drop_tagged_cells.split(), args.strip_init_cells,
_parse_size(args.max_size))

if args.dry_run:
output_stream.write(f'Dry run: would have stripped {filename}\n')
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
nbformat.write(nb, output_stream)
output_stream.flush()
except nbformat.reader.NotJSONError:
print('No valid notebook detected', file=sys.stderr)
raise SystemExit(1)
nbformat.write(nb, output_stream)
output_stream.flush()


def main():
@@ -486,12 +488,12 @@ def main():
continue

try:
with io.open(filename, 'r', encoding='utf8') as fin:
if args.textconv or args.dry_run:
process_notebook(fin, output_stream, args, extra_keys, filename)
else:
with io.open(filename, 'w', encoding='utf8', newline='') as fout:
process_notebook(fin, fout, args, extra_keys, filename)
with io.open(filename, 'r+', encoding='utf8', newline='') as f:
out = output_stream if args.textconv or args.dry_run else f
process_notebook(f, out, args, extra_keys, filename)
except nbformat.reader.NotJSONError:
print(f"No valid notebook detected in '{filename}'", file=sys.stderr)
raise SystemExit(1)
except FileNotFoundError:
print(f"Could not strip '{filename}': file not found", file=sys.stderr)
raise SystemExit(1)
@@ -501,4 +503,8 @@ def main():
raise

if not args.files and input_stream:
process_notebook(input_stream, output_stream, args, extra_keys)
try:
process_notebook(input_stream, output_stream, args, extra_keys)
except nbformat.reader.NotJSONError:
print('No valid notebook detected on stdin', file=sys.stderr)
raise SystemExit(1)

0 comments on commit cb98ba9

Please sign in to comment.