From 217ede86961d3ffc356556ad75cbe2b514373048 Mon Sep 17 00:00:00 2001
From: Daniel Mesejo
Date: Thu, 17 Aug 2023 19:19:22 +0200
Subject: [PATCH] feat: add register_json (#458)

---
 datafusion/tests/test_sql.py | 53 +++++++++++++++++++++++++++++++++++-
 src/context.rs               | 36 ++++++++++++++++++++++++
 2 files changed, 88 insertions(+), 1 deletion(-)

diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 608bb196..9d42a1f9 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -14,12 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import gzip
+import os
 
 import numpy as np
 import pyarrow as pa
 import pyarrow.dataset as ds
 import pytest
-import gzip
 
 from datafusion import udf
 
@@ -154,6 +155,56 @@ def test_register_dataset(ctx, tmp_path):
     assert result.to_pydict() == {"cnt": [100]}
 
 
+def test_register_json(ctx, tmp_path):
+    path = os.path.dirname(os.path.abspath(__file__))
+    test_data_path = os.path.join(path, "data_test_context", "data.json")
+    gzip_path = tmp_path / "data.json.gz"
+
+    with open(test_data_path, "rb") as json_file:
+        with gzip.open(gzip_path, "wb") as gzipped_file:
+            gzipped_file.writelines(json_file)
+
+    ctx.register_json("json", test_data_path)
+    ctx.register_json("json1", str(test_data_path))
+    ctx.register_json(
+        "json2",
+        test_data_path,
+        schema_infer_max_records=10,
+    )
+    ctx.register_json(
+        "json_gzip",
+        gzip_path,
+        file_extension="gz",
+        file_compression_type="gzip",
+    )
+
+    alternative_schema = pa.schema(
+        [
+            ("some_int", pa.int16()),
+            ("some_bytes", pa.string()),
+            ("some_floats", pa.float32()),
+        ]
+    )
+    ctx.register_json("json3", path, schema=alternative_schema)
+
+    assert ctx.tables() == {"json", "json1", "json2", "json3", "json_gzip"}
+
+    for table in ["json", "json1", "json2", "json_gzip"]:
+        result = ctx.sql(f'SELECT COUNT("B") AS cnt FROM {table}').collect()
+        result = pa.Table.from_batches(result)
+        assert result.to_pydict() == {"cnt": [3]}
+
+    result = ctx.sql("SELECT * FROM json3").collect()
+    result = pa.Table.from_batches(result)
+    assert result.schema == alternative_schema
+
+    with pytest.raises(
+        ValueError,
+        match="file_compression_type must one of: gzip, bz2, xz, zstd",
+    ):
+        ctx.register_json("json4", gzip_path, file_compression_type="rar")
+
+
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]
 
diff --git a/src/context.rs b/src/context.rs
index 1dca8a79..317ab785 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -509,6 +509,42 @@ impl PySessionContext {
         Ok(())
     }
 
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name,
+                        path,
+                        schema=None,
+                        schema_infer_max_records=1000,
+                        file_extension=".json",
+                        table_partition_cols=vec![],
+                        file_compression_type=None))]
+    fn register_json(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        schema_infer_max_records: usize,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        file_compression_type: Option<String>,
+        py: Python,
+    ) -> PyResult<()> {
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+        let mut options = NdJsonReadOptions::default()
+            .file_compression_type(parse_file_compression_type(file_compression_type)?)
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+        options.schema_infer_max_records = schema_infer_max_records;
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+
+        let result = self.ctx.register_json(name, path, options);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     // Registers a PyArrow.Dataset
     fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
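
For context, a minimal usage sketch of the new binding from Python (not part of the patch; the table name and "example.json" path are hypothetical, and the defaults assumed here follow the pyo3 signature above):

    from datafusion import SessionContext

    ctx = SessionContext()
    # Register a newline-delimited JSON file as a named table; the schema
    # is inferred from the first 1000 records unless schema= is passed.
    ctx.register_json("example", "example.json")
    ctx.sql("SELECT COUNT(*) AS cnt FROM example").show()

Gzipped input works the same way when file_extension="gz" and file_compression_type="gzip" are passed, mirroring the json_gzip case in the test above.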