Skip to content

Commit d85228f

Browse files
committed
Be more careful about reference counts in zero-copy handoff, add pyarrow.Array.to_pandas method
Change-Id: Ic66c86f6900ff95463228667305760f44d71185c
1 parent cc7a6b3 commit d85228f

File tree

6 files changed

+72
-10
lines changed

6 files changed

+72
-10
lines changed

python/pyarrow/array.pyx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323

2424
from pyarrow.includes.libarrow cimport *
25+
from pyarrow.includes.common cimport PyObject_to_object
2526
cimport pyarrow.includes.pyarrow as pyarrow
2627

2728
import pyarrow.config
@@ -35,6 +36,8 @@ from pyarrow.scalar import NA
3536
from pyarrow.schema cimport Schema
3637
import pyarrow.schema as schema
3738

39+
cimport cpython
40+
3841

3942
def total_allocated_bytes():
4043
cdef MemoryPool* pool = pyarrow.get_memory_pool()
@@ -111,6 +114,24 @@ cdef class Array:
111114
def slice(self, start, end):
112115
pass
113116

117+
def to_pandas(self):
118+
"""
119+
Convert to an array object suitable for use in pandas
120+
121+
See also
122+
--------
123+
Column.to_pandas
124+
Table.to_pandas
125+
RecordBatch.to_pandas
126+
"""
127+
cdef:
128+
# Raw output slot filled in by ConvertArrayToPandas below.
PyObject* np_arr
129+
130+
# `self` is passed as a raw PyObject* rather than as a Python object, so
# no automatic reference counting is applied to the argument here.
# NOTE(review): presumably the C++ side keeps `self` alive as the base of
# the result in the zero-copy case -- confirm against the adapter code.
# check_status is given the returned status; its failure behaviour is not
# visible in this chunk.
check_status(pyarrow.ConvertArrayToPandas(
131+
self.sp_array, <PyObject*> self, &np_arr))
132+
133+
# Turn the raw pointer into a Python object via the shared helper, which
# cancels the reference-count increment of the cast to `object`.
return PyObject_to_object(np_arr)
134+
114135

115136
cdef class NullArray(Array):
116137
pass

python/pyarrow/includes/common.pxd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
4747
c_bool IsKeyError()
4848
c_bool IsNotImplemented()
4949
c_bool IsInvalid()
50+
51+
52+
# Convert a raw PyObject* into a Python object with no net change to the
# object's reference count: the cast below adds one reference and the
# Py_DECREF immediately removes it again.
# NOTE(review): this assumes `o` already carries a reference owned by the
# caller (e.g. a freshly created result) -- confirm at each call site.
cdef inline object PyObject_to_object(PyObject* o):
53+
# Cast to "object" increments reference count
54+
cdef object result = <object> o
55+
# Drop the extra reference introduced by the cast above.
cpython.Py_DECREF(result)
56+
return result

python/pyarrow/includes/pyarrow.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ cdef extern from "pyarrow/api.h" namespace "pyarrow" nogil:
3434
shared_ptr[CArray]* out)
3535

3636
CStatus ConvertArrayToPandas(const shared_ptr[CArray]& arr,
37-
object py_ref, PyObject** out)
37+
PyObject* py_ref, PyObject** out)
3838

3939
CStatus ConvertColumnToPandas(const shared_ptr[CColumn]& arr,
40-
object py_ref, PyObject** out)
40+
PyObject* py_ref, PyObject** out)
4141

4242
MemoryPool* get_memory_pool()
4343

python/pyarrow/table.pyx

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from cython.operator cimport dereference as deref
2323

2424
from pyarrow.includes.libarrow cimport *
25+
from pyarrow.includes.common cimport PyObject_to_object
2526
cimport pyarrow.includes.pyarrow as pyarrow
2627

2728
import pyarrow.config
@@ -32,6 +33,7 @@ from pyarrow.schema cimport box_data_type, box_schema
3233

3334
from pyarrow.compat import frombytes, tobytes
3435

36+
cimport cpython
3537

3638
cdef class ChunkedArray:
3739
'''
@@ -100,8 +102,10 @@ cdef class Column:
100102

101103
import pandas as pd
102104

103-
check_status(pyarrow.ConvertColumnToPandas(self.sp_column, self, &arr))
104-
return pd.Series(<object>arr, name=self.name)
105+
check_status(pyarrow.ConvertColumnToPandas(self.sp_column,
106+
<PyObject*> self, &arr))
107+
108+
return pd.Series(PyObject_to_object(arr), name=self.name)
105109

106110
cdef _check_nullptr(self):
107111
if self.column == NULL:
@@ -248,9 +252,10 @@ cdef class RecordBatch:
248252
data = []
249253
for i in range(self.batch.num_columns()):
250254
arr = self.batch.column(i)
251-
check_status(pyarrow.ConvertArrayToPandas(arr, self, &np_arr))
255+
check_status(pyarrow.ConvertArrayToPandas(arr, <PyObject*> self,
256+
&np_arr))
252257
names.append(frombytes(self.batch.column_name(i)))
253-
data.append(<object> np_arr)
258+
data.append(PyObject_to_object(np_arr))
254259

255260
return pd.DataFrame(dict(zip(names, data)), columns=names)
256261

@@ -375,9 +380,10 @@ cdef class Table:
375380
for i in range(self.table.num_columns()):
376381
col = self.table.column(i)
377382
column = self.column(i)
378-
check_status(pyarrow.ConvertColumnToPandas(col, column, &arr))
383+
check_status(pyarrow.ConvertColumnToPandas(
384+
col, <PyObject*> column, &arr))
379385
names.append(frombytes(col.get().name()))
380-
data.append(<object> arr)
386+
data.append(PyObject_to_object(arr))
381387

382388
return pd.DataFrame(dict(zip(names, data)), columns=names)
383389

python/pyarrow/tests/test_array.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import sys
19+
1820
import pyarrow
1921
import pyarrow.formatting as fmt
2022

@@ -71,3 +73,30 @@ def test_long_array_format():
7173
99
7274
]"""
7375
assert result == expected
76+
77+
78+
# Checks reference counts around Array.to_pandas: the converted result and
# the source array must not pick up stray references, and the result's
# `base` must stay usable after the source array name is dropped.
def test_to_pandas_zero_copy():
79+
import gc
80+
81+
arr = pyarrow.from_pylist(range(10))
82+
83+
# Convert repeatedly. sys.getrefcount reports one more than the number of
# live references because its own argument counts, so 2 means the local
# name `np_arr` is the only reference to the result.
for i in range(10):
84+
np_arr = arr.to_pandas()
85+
assert sys.getrefcount(np_arr) == 2
86+
# Release the result before the next conversion.
np_arr = None # noqa
87+
88+
# After all results are released, only the local name `arr` should still
# refer to the source array.
assert sys.getrefcount(arr) == 2
89+
90+
# Now drop the source array name first and keep only the converted result.
for i in range(10):
91+
arr = pyarrow.from_pylist(range(10))
92+
np_arr = arr.to_pandas()
93+
arr = None
94+
# Force a collection so an object that is no longer referenced is gone
# before the checks below.
gc.collect()
95+
96+
# Ensure base is still valid
97+
98+
# Because of py.test's assert inspection magic, if you put getrefcount
99+
# on the line being examined, it will be 1 higher than you expect
100+
base_refcount = sys.getrefcount(np_arr.base)
101+
assert base_refcount == 2
102+
# Read through the result once more.
# NOTE(review): presumably this confirms the underlying buffer is still
# readable after `arr` was released -- confirm intent.
np_arr.sum()

python/src/pyarrow/adapters/pandas.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,6 @@ class ArrowDeserializer {
628628
PyAcquireGIL lock;
629629

630630
// Zero-Copy. We can pass the data pointer directly to NumPy.
631-
OwnedRef py_ref(py_ref_);
632631
npy_intp dims[1] = {col_->length()};
633632
out_ = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(1, dims,
634633
type, data));
@@ -645,7 +644,7 @@ class ArrowDeserializer {
645644
return Status::OK();
646645
} else {
647646
// PyArray_SetBaseObject steals our reference to py_ref_
648-
py_ref.release();
647+
Py_INCREF(py_ref_);
649648
}
650649

651650
// Arrow data is immutable.

0 commit comments

Comments
 (0)