Skip to content

Commit

Permalink
Merge pull request #10 from randolf-scholz/main
Browse files Browse the repository at this point in the history
Added .pre-commit-hooks.yaml, fixes #8 and #9
  • Loading branch information
mlucool committed Sep 18, 2023
2 parents 913b2d9 + b3e8dab commit 2be4ee8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 35 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-hooks.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- id: nbstripout-fast
name: nbstripout-fast
entry: nbstripout-fast
types: [jupyter]
language: rust
description: "Strip output from Jupyter notebooks (modifies the files in place by default)."
54 changes: 29 additions & 25 deletions examples/comparison.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python

import json
import re
import subprocess
import tempfile
Expand Down Expand Up @@ -32,27 +31,32 @@ def run(nb):
return nb


# Benchmark driver (module-level form): generate notebooks of increasing size in a
# temporary directory, then time each strip command on every generated file.
# NOTE(review): leading indentation is not visible in this view, so block nesting
# cannot be confirmed; the timing loop presumably runs inside the `with` block so
# the generated files still exist when they are timed — confirm.
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
# (cell count, notebook path) pairs, in generation order.
count_plus_filenames = []
for num_cells in [1, 10, 100, 1_000, 10_000]:
nb = run(create(num_cells))
filename = tmpdir / f"generated-{num_cells}-cells.ipynb"
count_plus_filenames.append((num_cells, filename))
with open(filename, "w") as f:
nbformat.write(nb, f)

# Table header; every column is left-aligned with a fixed width.
print("{:<7} {:<12} {:<12}".format("Cells", "nbstripout", "nbstripout_fast"))
for num_cells, file in count_plus_filenames:
times = []
for cmd in [NBSTRIPOUT, NBSTRIPOUT_FAST]:
# emulate git filter by outputting to stdout
# The command's own stdout is discarded by the shell redirect; stderr is merged
# into the captured output, which is where the timing report is read from.
output = subprocess.check_output(
f"time {cmd} {file} -t > /dev/null",
stderr=subprocess.STDOUT,
universal_newlines=True,
shell=True,
)
# re.match anchors at the start of the stripped output, so the text must begin
# with the "real" line; otherwise match() returns None and .group(1) raises
# AttributeError.
real_time = re.match(r"real\s+(\w+\.\w+s)", output.strip()).group(1)
times.append(real_time)
print("{:<7} {:<12} {:<12}".format(num_cells, times[0], times[1]))
# Benchmark entry point: generate notebooks with 1 to 10,000 cells in a temporary
# directory, time both strip commands on each file, and print one table row per
# notebook size.
# NOTE(review): leading indentation is not visible in this view, so block nesting
# cannot be confirmed; the timing loop presumably runs inside the `with` block so
# the generated files still exist when they are timed — confirm.
def main():
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
# (cell count, notebook path) pairs, in generation order.
count_plus_filenames = []
for num_cells in [1, 10, 100, 1_000, 10_000]:
# Build a notebook with `num_cells` cells and pass it through run() before
# writing it to disk.
nb = run(create(num_cells))
filename = tmpdir / f"generated-{num_cells}-cells.ipynb"
count_plus_filenames.append((num_cells, filename))
with open(filename, "w") as f:
nbformat.write(nb, f)

# Table header; every column is left-aligned with a fixed width.
print("{:<7} {:<12} {:<12}".format("Cells", "nbstripout", "nbstripout_fast"))
for num_cells, file in count_plus_filenames:
# Collected in command order: times[0] is NBSTRIPOUT, times[1] is NBSTRIPOUT_FAST.
times = []
for cmd in [NBSTRIPOUT, NBSTRIPOUT_FAST]:
# emulate git filter by outputting to stdout
# The command's own stdout is discarded by the shell redirect; stderr is merged
# into the captured output, which is where the timing report is read from.
output = subprocess.check_output(
f"time {cmd} {file} -t > /dev/null",
stderr=subprocess.STDOUT,
universal_newlines=True,
shell=True,
)
# re.match anchors at the start of the stripped output, so the text must begin
# with the "real" line; otherwise match() returns None and .group(1) raises
# AttributeError.
# NOTE(review): this depends on the `time` output format of whatever shell
# `shell=True` invokes — verify on the target platform.
real_time = re.match(r"real\s+(\w+\.\w+s)", output.strip()).group(1)
times.append(real_time)
print("{:<7} {:<12} {:<12}".format(num_cells, times[0], times[1]))


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
main()
31 changes: 21 additions & 10 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::env;
use std::fs;
use std::io;
use std::io::BufRead;
use std::path::PathBuf;

mod stripoutlib;

Expand Down Expand Up @@ -58,7 +59,7 @@ struct Cli {
keep_output: bool,

#[clap(long, action)]
/// Remove cells where `source` is empty or contains only whitepace
/// Remove cells where `source` is empty or contains only whitespace
drop_empty_cells: bool,

#[clap(short, long, action)]
Expand All @@ -79,7 +80,7 @@ struct Cli {

#[clap(parse(from_os_str))]
/// Files to strip output from
files: Vec<std::path::PathBuf>,
files: Vec<PathBuf>,
}

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -135,7 +136,7 @@ fn process_file(
keep_count: bool,
extra_keys: &Vec<String>,
drop_empty_cells: bool,
output_file: Option<std::path::PathBuf>,
output_file: Option<PathBuf>,
) -> Result<(), String> {
let mut nb: serde_json::Value = serde_json::from_str(&contents)
.map_err(|e| format!("JSON was not well-formatted: {:?}", e))?;
Expand All @@ -155,18 +156,28 @@ fn process_file(
let mut ser = serde_json::Serializer::with_formatter(buf, formatter);
nb.serialize(&mut ser).map_err(|e| {
format!(
"Unable to serialize notebook. Likely an intenral error: {:?}",
"Unable to serialize notebook. Likely an internal error: {:?}",
e
)
})?;
let cleaned_contents = String::from_utf8(ser.into_inner()).map_err(|e| format!("{:?}", e))?;
let mut cleaned_contents = String::from_utf8(ser.into_inner()).map_err(|e| format!("{:?}", e))?;

if let Some(file) = output_file {
fs::write(&file, cleaned_contents)
.map_err(|e| format!("Could not write to {:?} due to {:?}", file, e))?;
// Check if the original content ended with a newline and the cleaned content doesn't
if contents.ends_with('\n') && !cleaned_contents.ends_with('\n') {
cleaned_contents.push('\n'); // Append a newline if necessary
}

if cleaned_contents != *contents {
if let Some(file) = output_file {
fs::write(&file, cleaned_contents)
.map_err(|e| format!("Could not write to {:?} due to {:?}", file, e))?;
} else {
println!("{}", cleaned_contents);
}
} else {
println!("{}", cleaned_contents);
log::debug!("Content unchanged. File not modified.");
}

Ok(())
}

Expand Down Expand Up @@ -197,7 +208,7 @@ fn main() -> Result<(), String> {
}
if let Some(config_keep_keys) = nbstripout_fast.keep_keys {
for key in config_keep_keys {
// Remove all occurances
// Remove all occurrences
extra_keys.retain(|x| x != &key);
}
}
Expand Down

0 comments on commit 2be4ee8

Please sign in to comment.