From d55b1af7b78ed2210ad9705d7484b38f1744f37b Mon Sep 17 00:00:00 2001
From: Alessandro Molina
Date: Wed, 16 Oct 2024 23:46:20 +0200
Subject: [PATCH] GH-44125: [Python] Add concat_batches function (#44126)

### Rationale for this change

Allows concatenating record batches in Python.

### What changes are included in this PR?

Adds the `concat_batches` function and tests.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

Yes, a new public function, `pyarrow.concat_batches`, has been added.
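As a quick usage sketch for reviewers (not part of the diff below; it assumes a pyarrow build that includes this patch, and the names `b1`, `b2`, and `combined` are purely illustrative):

```python
import pyarrow as pa

# Two record batches with an identical schema (illustrative data).
b1 = pa.record_batch([pa.array([1, 2]), pa.array(["a", "b"])],
                     names=["x", "y"])
b2 = pa.record_batch([pa.array([3, 4]), pa.array(["c", "d"])],
                     names=["x", "y"])

# concat_batches copies the inputs' data into a single new RecordBatch.
combined = pa.concat_batches([b1, b2])
assert combined.num_rows == 4
assert combined.schema == b1.schema
```

Note that, unlike `concat_tables` (which can often just collect the input chunks), producing one contiguous `RecordBatch` implies copying the underlying data, as the docstring below spells out.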
* GitHub Issue: #44125

---------

Co-authored-by: Joris Van den Bossche
---
 docs/source/python/api/tables.rst    |  1 +
 python/pyarrow/__init__.py           |  2 +-
 python/pyarrow/includes/libarrow.pxd |  4 +++
 python/pyarrow/table.pxi             | 51 ++++++++++++++++++++++++++++
 python/pyarrow/tests/test_table.py   | 43 +++++++++++++++++++++++
 5 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/docs/source/python/api/tables.rst b/docs/source/python/api/tables.rst
index ae9f5de127dfd..48cc67eb66720 100644
--- a/docs/source/python/api/tables.rst
+++ b/docs/source/python/api/tables.rst
@@ -32,6 +32,7 @@ Factory Functions
    concat_arrays
    concat_tables
    record_batch
+   concat_batches
    table
 
 Classes
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index d31c93119b73a..fb7c242187cb0 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -267,7 +267,7 @@ def print_entry(label, value):
 
 from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table,
                          concat_arrays, concat_tables, TableGroupBy,
-                         RecordBatchReader)
+                         RecordBatchReader, concat_batches)
 
 # Exceptions
 from pyarrow.lib import (ArrowCancelled,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 8e6922a912a32..d304641e0f4d1 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1356,6 +1356,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
             CConcatenateTablesOptions options,
             CMemoryPool* memory_pool)
 
+    CResult[shared_ptr[CRecordBatch]] ConcatenateRecordBatches(
+        const vector[shared_ptr[CRecordBatch]]& batches,
+        CMemoryPool* memory_pool)
+
     cdef cppclass CDictionaryUnifier" arrow::DictionaryUnifier":
         @staticmethod
         CResult[shared_ptr[CChunkedArray]] UnifyChunkedArray(
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 819bbc34c66b9..af241e4be07d9 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -6259,6 +6259,57 @@ def concat_tables(tables, MemoryPool memory_pool=None, str promote_options="none
     return pyarrow_wrap_table(c_result_table)
 
 
+def concat_batches(recordbatches, MemoryPool memory_pool=None):
+    """
+    Concatenate pyarrow.RecordBatch objects.
+
+    All record batches must share the same schema.
+    The operation implies a copy of the data to merge
+    the arrays of the different RecordBatches.
+
+    Parameters
+    ----------
+    recordbatches : iterable of pyarrow.RecordBatch objects
+        Pyarrow record batches to concatenate into a single RecordBatch.
+    memory_pool : MemoryPool, default None
+        For memory allocations, if required, otherwise use default pool.
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> t1 = pa.record_batch([
+    ...     pa.array([2, 4, 5, 100]),
+    ...     pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"])
+    ...     ], names=['n_legs', 'animals'])
+    >>> t2 = pa.record_batch([
+    ...     pa.array([2, 4]),
+    ...     pa.array(["Parrot", "Dog"])
+    ...     ], names=['n_legs', 'animals'])
+    >>> pa.concat_batches([t1, t2])
+    pyarrow.RecordBatch
+    n_legs: int64
+    animals: string
+    ----
+    n_legs: [2,4,5,100,2,4]
+    animals: ["Flamingo","Horse","Brittle stars","Centipede","Parrot","Dog"]
+
+    """
+    cdef:
+        vector[shared_ptr[CRecordBatch]] c_recordbatches
+        shared_ptr[CRecordBatch] c_result_recordbatch
+        RecordBatch recordbatch
+        CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+    for recordbatch in recordbatches:
+        c_recordbatches.push_back(recordbatch.sp_batch)
+
+    with nogil:
+        c_result_recordbatch = GetResultValue(
+            ConcatenateRecordBatches(c_recordbatches, pool))
+
+    return pyarrow_wrap_batch(c_result_recordbatch)
+
+
 def _from_pydict(cls, mapping, schema, metadata):
     """
     Construct a Table/RecordBatch from Arrow arrays or columns.
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index b66a5eb083cc5..4c058ccecda5e 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -2037,6 +2037,49 @@ def test_table_negative_indexing():
         table[4]
 
 
+def test_concat_batches():
+    data = [
+        list(range(5)),
+        [-10., -5., 0., 5., 10.]
+    ]
+    data2 = [
+        list(range(5, 10)),
+        [1., 2., 3., 4., 5.]
+    ]
+
+    t1 = pa.RecordBatch.from_arrays([pa.array(x) for x in data],
+                                    names=('a', 'b'))
+    t2 = pa.RecordBatch.from_arrays([pa.array(x) for x in data2],
+                                    names=('a', 'b'))
+
+    result = pa.concat_batches([t1, t2])
+    result.validate()
+    assert len(result) == 10
+
+    expected = pa.RecordBatch.from_arrays([pa.array(x + y)
+                                           for x, y in zip(data, data2)],
+                                          names=('a', 'b'))
+
+    assert result.equals(expected)
+
+
+def test_concat_batches_different_schema():
+    t1 = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2], type=pa.int64())], ["f"])
+    t2 = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2], type=pa.float32())], ["f"])
+
+    with pytest.raises(pa.ArrowInvalid,
+                       match="not match index 0 recordbatch schema"):
+        pa.concat_batches([t1, t2])
+
+
+def test_concat_batches_none_batches():
+    # ARROW-11997
+    with pytest.raises(AttributeError):
+        pa.concat_batches([None])
+
+
 @pytest.mark.parametrize(
     ('cls'),
     [