GH-44125: [Python] Add concat_batches function (#44126)
### Rationale for this change

Allows concatenating RecordBatches in Python.

### What changes are included in this PR?

Adds the `concat_batches` function and tests.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

A new public function, `pyarrow.concat_batches`, has been added.
* GitHub Issue: #44125
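
For illustration, a minimal usage sketch of the new function (not part of the diff below; it mirrors the docstring example added in this PR):

import pyarrow as pa

t1 = pa.record_batch([pa.array([1, 2]), pa.array(["a", "b"])],
                     names=["n", "s"])
t2 = pa.record_batch([pa.array([3]), pa.array(["c"])],
                     names=["n", "s"])

# all batches must share the same schema; the data is copied into one batch
combined = pa.concat_batches([t1, t2])
assert combined.num_rows == 3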

---------

Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
amol- and jorisvandenbossche authored Oct 16, 2024
1 parent f8ae352 commit d55b1af
Showing 5 changed files with 100 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/python/api/tables.rst
@@ -32,6 +32,7 @@ Factory Functions
    concat_arrays
    concat_tables
    record_batch
+   concat_batches
    table
 
 Classes
2 changes: 1 addition & 1 deletion python/pyarrow/__init__.py
@@ -267,7 +267,7 @@ def print_entry(label, value):

 from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table,
                          concat_arrays, concat_tables, TableGroupBy,
-                         RecordBatchReader)
+                         RecordBatchReader, concat_batches)
 
 # Exceptions
 from pyarrow.lib import (ArrowCancelled,
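The net effect of this re-export is that the new function is reachable from the top-level pyarrow namespace; a trivial illustrative check (not part of this diff):

import pyarrow as pa

# concat_batches now sits alongside concat_tables and concat_arrays
assert callable(pa.concat_batches)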
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -1356,6 +1356,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         CConcatenateTablesOptions options,
         CMemoryPool* memory_pool)
 
+    CResult[shared_ptr[CRecordBatch]] ConcatenateRecordBatches(
+        const vector[shared_ptr[CRecordBatch]]& batches,
+        CMemoryPool* memory_pool)
+
     cdef cppclass CDictionaryUnifier" arrow::DictionaryUnifier":
         @staticmethod
         CResult[shared_ptr[CChunkedArray]] UnifyChunkedArray(
51 changes: 51 additions & 0 deletions python/pyarrow/table.pxi
@@ -6259,6 +6259,57 @@ def concat_tables(tables, MemoryPool memory_pool=None, str promote_options="none
     return pyarrow_wrap_table(c_result_table)
 
 
+def concat_batches(recordbatches, MemoryPool memory_pool=None):
+    """
+    Concatenate pyarrow.RecordBatch objects.
+
+    All recordbatches must share the same schema; the operation implies a
+    copy of the data to merge the arrays of the different RecordBatches.
+
+    Parameters
+    ----------
+    recordbatches : iterable of pyarrow.RecordBatch objects
+        Pyarrow record batches to concatenate into a single RecordBatch.
+    memory_pool : MemoryPool, default None
+        For memory allocations, if required, otherwise use default pool.
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> t1 = pa.record_batch([
+    ...     pa.array([2, 4, 5, 100]),
+    ...     pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"])
+    ... ], names=['n_legs', 'animals'])
+    >>> t2 = pa.record_batch([
+    ...     pa.array([2, 4]),
+    ...     pa.array(["Parrot", "Dog"])
+    ... ], names=['n_legs', 'animals'])
+    >>> pa.concat_batches([t1, t2])
+    pyarrow.RecordBatch
+    n_legs: int64
+    animals: string
+    ----
+    n_legs: [2,4,5,100,2,4]
+    animals: ["Flamingo","Horse","Brittle stars","Centipede","Parrot","Dog"]
+    """
+    cdef:
+        vector[shared_ptr[CRecordBatch]] c_recordbatches
+        shared_ptr[CRecordBatch] c_result_recordbatch
+        RecordBatch recordbatch
+        CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+    for recordbatch in recordbatches:
+        c_recordbatches.push_back(recordbatch.sp_batch)
+
+    with nogil:
+        c_result_recordbatch = GetResultValue(
+            ConcatenateRecordBatches(c_recordbatches, pool))
+
+    return pyarrow_wrap_batch(c_result_recordbatch)
+
+
 def _from_pydict(cls, mapping, schema, metadata):
     """
     Construct a Table/RecordBatch from Arrow arrays or columns.
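Worth noting, as illustration only (not part of this diff): unlike `concat_tables`, which with the default promote_options keeps the inputs as separate chunks rather than copying array data, `concat_batches` materializes a single contiguous RecordBatch, which is why the docstring above mentions a copy. A minimal sketch:

import pyarrow as pa

b1 = pa.record_batch([pa.array([1, 2])], names=["x"])
b2 = pa.record_batch([pa.array([3, 4])], names=["x"])

# concat_tables keeps the two inputs as separate chunks of a Table ...
t = pa.concat_tables([pa.Table.from_batches([b1]), pa.Table.from_batches([b2])])
assert t.column("x").num_chunks == 2

# ... while concat_batches merges the data into one contiguous RecordBatch.
b = pa.concat_batches([b1, b2])
assert b.num_rows == 4 and b.schema == b1.schema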
43 changes: 43 additions & 0 deletions python/pyarrow/tests/test_table.py
@@ -2037,6 +2037,49 @@ def test_table_negative_indexing():
         table[4]
 
 
+def test_concat_batches():
+    data = [
+        list(range(5)),
+        [-10., -5., 0., 5., 10.]
+    ]
+    data2 = [
+        list(range(5, 10)),
+        [1., 2., 3., 4., 5.]
+    ]
+
+    t1 = pa.RecordBatch.from_arrays([pa.array(x) for x in data],
+                                    names=('a', 'b'))
+    t2 = pa.RecordBatch.from_arrays([pa.array(x) for x in data2],
+                                    names=('a', 'b'))
+
+    result = pa.concat_batches([t1, t2])
+    result.validate()
+    assert len(result) == 10
+
+    expected = pa.RecordBatch.from_arrays([pa.array(x + y)
+                                           for x, y in zip(data, data2)],
+                                          names=('a', 'b'))
+
+    assert result.equals(expected)
+
+
+def test_concat_batches_different_schema():
+    t1 = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2], type=pa.int64())], ["f"])
+    t2 = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2], type=pa.float32())], ["f"])
+
+    with pytest.raises(pa.ArrowInvalid,
+                       match="not match index 0 recordbatch schema"):
+        pa.concat_batches([t1, t2])
+
+
+def test_concat_batches_none_batches():
+    # ARROW-11997
+    with pytest.raises(AttributeError):
+        pa.concat_batches([None])
+
+
 @pytest.mark.parametrize(
     ('cls'),
     [
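A final illustrative sketch (not part of this diff): the memory_pool parameter routes the allocations for the merged arrays through a caller-chosen pool; default_memory_pool is used here purely as an example.

import pyarrow as pa

b1 = pa.record_batch([pa.array([1, 2])], names=["x"])
b2 = pa.record_batch([pa.array([3, 4])], names=["x"])

# allocations for the concatenated arrays come from the given pool
pool = pa.default_memory_pool()
merged = pa.concat_batches([b1, b2], memory_pool=pool)
assert merged.num_rows == 4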
