From 93dedd8d58f80f93350d13cdf7943cdd3fce6257 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Wed, 17 Jul 2024 22:36:19 +1000 Subject: [PATCH] feat: Include file path option for NDJSON (#17681) --- crates/polars-core/src/utils/mod.rs | 2 ++ crates/polars-lazy/src/scan/ndjson.rs | 9 +++++- .../src/executors/scan/ndjson.rs | 23 +++++++++++++- .../src/plans/conversion/dsl_to_ir.rs | 13 ++++++++ py-polars/polars/io/ndjson.py | 10 +++++-- py-polars/src/lazyframe/mod.rs | 4 ++- py-polars/tests/unit/io/test_scan.py | 30 ++++++++++++++++++- 7 files changed, 85 insertions(+), 6 deletions(-) diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index 854e60ba4d7d..1b07c8206b99 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -731,6 +731,8 @@ where } /// This takes ownership of the DataFrame so that drop is called earlier. +/// # Panics +/// Panics if `dfs` is empty. pub fn accumulate_dataframes_vertical(dfs: I) -> PolarsResult where I: IntoIterator, diff --git a/crates/polars-lazy/src/scan/ndjson.rs b/crates/polars-lazy/src/scan/ndjson.rs index 37209ed5710c..69dbf0e6d678 100644 --- a/crates/polars-lazy/src/scan/ndjson.rs +++ b/crates/polars-lazy/src/scan/ndjson.rs @@ -21,6 +21,7 @@ pub struct LazyJsonLineReader { pub(crate) infer_schema_length: Option, pub(crate) n_rows: Option, pub(crate) ignore_errors: bool, + pub(crate) include_file_paths: Option>, } impl LazyJsonLineReader { @@ -39,6 +40,7 @@ impl LazyJsonLineReader { infer_schema_length: NonZeroUsize::new(100), ignore_errors: false, n_rows: None, + include_file_paths: None, } } /// Add a row index column. @@ -89,6 +91,11 @@ impl LazyJsonLineReader { self.batch_size = batch_size; self } + + pub fn with_include_file_paths(mut self, include_file_paths: Option>) -> Self { + self.include_file_paths = include_file_paths; + self + } } impl LazyFileListReader for LazyJsonLineReader { @@ -108,7 +115,7 @@ impl LazyFileListReader for LazyJsonLineReader { file_counter: 0, hive_options: Default::default(), glob: true, - include_file_paths: None, + include_file_paths: self.include_file_paths, }; let options = NDJsonReadOptions { diff --git a/crates/polars-mem-engine/src/executors/scan/ndjson.rs b/crates/polars-mem-engine/src/executors/scan/ndjson.rs index 4b9097d525a2..38e5c3965fd9 100644 --- a/crates/polars-mem-engine/src/executors/scan/ndjson.rs +++ b/crates/polars-mem-engine/src/executors/scan/ndjson.rs @@ -40,6 +40,18 @@ impl JsonExec { let mut n_rows = self.file_scan_options.n_rows; + // Avoid panicking + if n_rows == Some(0) { + let mut df = DataFrame::empty_with_schema(schema); + if let Some(col) = &self.file_scan_options.include_file_paths { + unsafe { df.with_column_unchecked(StringChunked::full_null(col, 0).into_series()) }; + } + if let Some(row_index) = &self.file_scan_options.row_index { + df.with_row_index_mut(row_index.name.as_ref(), Some(row_index.offset)); + } + return Ok(df); + } + let dfs = self .paths .iter() @@ -67,7 +79,7 @@ impl JsonExec { .with_ignore_errors(self.options.ignore_errors) .finish(); - let df = match df { + let mut df = match df { Ok(df) => df, Err(e) => return Some(Err(e)), }; @@ -76,6 +88,15 @@ impl JsonExec { *n_rows -= df.height(); } + if let Some(col) = &self.file_scan_options.include_file_paths { + let path = p.to_str().unwrap(); + unsafe { + df.with_column_unchecked( + StringChunked::full(col, path, df.height()).into_series(), + ) + }; + } + Some(Ok(df)) }) .collect::>>()?; diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 63399bdc5974..eed5ecb1fdc4 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -185,6 +185,19 @@ pub fn to_alp_impl( None }; + file_options.include_file_paths = + file_options.include_file_paths.filter(|_| match scan_type { + #[cfg(feature = "parquet")] + FileScan::Parquet { .. } => true, + #[cfg(feature = "ipc")] + FileScan::Ipc { .. } => true, + #[cfg(feature = "csv")] + FileScan::Csv { .. } => true, + #[cfg(feature = "json")] + FileScan::NDJson { .. } => true, + FileScan::Anonymous { .. } => false, + }); + // Only if we have a writing file handle we must resolve hive partitions // update schema's etc. if let Some(lock) = &mut _file_info_write { diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index aa142c9e77b2..3b12b263a19d 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -99,6 +99,7 @@ def scan_ndjson( row_index_name: str | None = None, row_index_offset: int = 0, ignore_errors: bool = False, + include_file_paths: str | None = None, ) -> LazyFrame: """ Lazily read from a newline delimited JSON file or multiple files via glob patterns. @@ -138,12 +139,16 @@ def scan_ndjson( Offset to start the row index column (only use if the name is set) ignore_errors Return `Null` if parsing fails because of schema mismatches. + include_file_paths + Include the path of the source file(s) as a column with this name. """ if isinstance(source, (str, Path)): - source = normalize_filepath(source) + source = normalize_filepath(source, check_not_directory=False) sources = [] else: - sources = [normalize_filepath(source) for source in source] + sources = [ + normalize_filepath(source, check_not_directory=False) for source in source + ] source = None # type: ignore[assignment] if infer_schema_length == 0: msg = "'infer_schema_length' should be positive" @@ -160,5 +165,6 @@ def scan_ndjson( rechunk, parse_row_index_args(row_index_name, row_index_offset), ignore_errors, + include_file_paths=include_file_paths, ) return wrap_ldf(pylf) diff --git a/py-polars/src/lazyframe/mod.rs b/py-polars/src/lazyframe/mod.rs index 007fdec09611..e2f1708689f7 100644 --- a/py-polars/src/lazyframe/mod.rs +++ b/py-polars/src/lazyframe/mod.rs @@ -44,7 +44,7 @@ impl PyLazyFrame { #[staticmethod] #[cfg(feature = "json")] #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (path, paths, infer_schema_length, schema, batch_size, n_rows, low_memory, rechunk, row_index, ignore_errors))] + #[pyo3(signature = (path, paths, infer_schema_length, schema, batch_size, n_rows, low_memory, rechunk, row_index, ignore_errors, include_file_paths))] fn new_from_ndjson( path: Option, paths: Vec, @@ -56,6 +56,7 @@ impl PyLazyFrame { rechunk: bool, row_index: Option<(String, IdxSize)>, ignore_errors: bool, + include_file_paths: Option, ) -> PyResult { let row_index = row_index.map(|(name, offset)| RowIndex { name: Arc::from(name.as_str()), @@ -77,6 +78,7 @@ impl PyLazyFrame { .with_schema(schema.map(|schema| Arc::new(schema.0))) .with_row_index(row_index) .with_ignore_errors(ignore_errors) + .with_include_file_paths(include_file_paths.map(Arc::from)) .finish() .map_err(PyPolarsErr::from)?; diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index 2cc9cbb025b1..637cefc66b2a 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -447,6 +447,33 @@ def test_scan_with_row_index_filter_and_limit( ) +@pytest.mark.write_disk() +@pytest.mark.parametrize( + ("scan_func", "write_func"), + [ + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_csv, pl.DataFrame.write_csv), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ], +) +@pytest.mark.parametrize( + "streaming", + [True, False], +) +def test_scan_limit_0_does_not_panic( + tmp_path: Path, + scan_func: Callable[[Any], pl.LazyFrame], + write_func: Callable[[pl.DataFrame, Path], None], + streaming: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "data.bin" + df = pl.DataFrame({"x": 1}) + write_func(df, path) + assert_frame_equal(scan_func(path).head(0).collect(streaming=streaming), df.clear()) + + @pytest.mark.write_disk() @pytest.mark.parametrize( ("scan_func", "write_func"), @@ -598,6 +625,7 @@ def test_scan_nonexistent_path(format: str) -> None: (pl.scan_parquet, pl.DataFrame.write_parquet), (pl.scan_ipc, pl.DataFrame.write_ipc), (pl.scan_csv, pl.DataFrame.write_csv), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), ], ) @pytest.mark.parametrize( @@ -639,7 +667,7 @@ def test_scan_include_file_name( assert_frame_equal(lf.collect(streaming=streaming), df) # TODO: Support this with CSV - if scan_func is not pl.scan_csv: + if scan_func not in [pl.scan_csv, pl.scan_ndjson]: # Test projecting only the path column assert_frame_equal( lf.select("path").collect(streaming=streaming),