Skip to content

Commit

Permalink
Custom drop on PyBufferWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Oct 14, 2024
1 parent 523585b commit 2c6f95a
Showing 1 changed file with 171 additions and 133 deletions.
304 changes: 171 additions & 133 deletions pyo3-arrow/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use arrow_array::{
};
use arrow_buffer::{Buffer, ScalarBuffer};
use arrow_schema::Field;
use pyo3::buffer::{ElementType, PyBuffer};
use pyo3::buffer::{Element, ElementType, PyBuffer};
use pyo3::exceptions::PyValueError;
use pyo3::ffi;
use pyo3::prelude::*;
Expand Down Expand Up @@ -73,95 +73,128 @@ impl PyArrowBuffer {
}
}

/// Owns a [`PyBuffer`] and decides how it is released when the wrapper is dropped.
///
/// The buffer is stored in an `Option` so that the `Drop` implementation can move it out and
/// then either release it normally or leak it when the Python interpreter is no longer
/// initialized.
#[derive(Debug)]
pub struct PyBufferWrapper<T>(Option<PyBuffer<T>>)
where
    T: Element;

impl<T: Element> PyBufferWrapper<T> {
    /// Borrow the wrapped [`PyBuffer`].
    ///
    /// # Errors
    ///
    /// Returns a `PyValueError` with the message "Buffer already disposed" when the wrapper no
    /// longer holds a buffer.
    fn inner(&self) -> PyResult<&PyBuffer<T>> {
        // Build the error lazily: `ok_or` would construct a `PyErr` on every call, including the
        // success path, and this accessor sits behind every `AnyBufferProtocol` getter.
        self.0
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Buffer already disposed"))
    }
}

impl<T: Element> Drop for PyBufferWrapper<T> {
    /// Release the wrapped buffer only while the Python interpreter is still initialized.
    ///
    /// The drop can run after the interpreter was already finalized; in that case the buffer is
    /// leaked instead of running `PyBuffer`'s own `Drop`.
    /// See https://github.com/kylebarron/arro3/issues/230
    fn drop(&mut self) {
        // SAFETY: `Py_IsInitialized` takes no arguments; CPython documents it as callable both
        // before initialization and after finalization.
        let interpreter_alive = unsafe { ffi::Py_IsInitialized() } != 0;

        match self.0.take() {
            // Interpreter is up: let `PyBuffer` release itself as usual.
            Some(buffer) if interpreter_alive => std::mem::drop(buffer),
            // Interpreter is gone: skip the destructor entirely.
            Some(buffer) => std::mem::forget(buffer),
            None => {}
        }
    }
}

/// An enum over buffer protocol input types.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AnyBufferProtocol {
UInt8(PyBuffer<u8>),
UInt16(PyBuffer<u16>),
UInt32(PyBuffer<u32>),
UInt64(PyBuffer<u64>),
Int8(PyBuffer<i8>),
Int16(PyBuffer<i16>),
Int32(PyBuffer<i32>),
Int64(PyBuffer<i64>),
Float32(PyBuffer<f32>),
Float64(PyBuffer<f64>),
UInt8(PyBufferWrapper<u8>),
UInt16(PyBufferWrapper<u16>),
UInt32(PyBufferWrapper<u32>),
UInt64(PyBufferWrapper<u64>),
Int8(PyBufferWrapper<i8>),
Int16(PyBufferWrapper<i16>),
Int32(PyBufferWrapper<i32>),
Int64(PyBufferWrapper<i64>),
Float32(PyBufferWrapper<f32>),
Float64(PyBufferWrapper<f64>),
}

impl<'py> FromPyObject<'py> for AnyBufferProtocol {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
if let Ok(buf) = ob.extract::<PyBuffer<u8>>() {
Ok(Self::UInt8(buf))
Ok(Self::UInt8(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<u16>>() {
Ok(Self::UInt16(buf))
Ok(Self::UInt16(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<u32>>() {
Ok(Self::UInt32(buf))
Ok(Self::UInt32(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<u64>>() {
Ok(Self::UInt64(buf))
Ok(Self::UInt64(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<i8>>() {
Ok(Self::Int8(buf))
Ok(Self::Int8(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<i16>>() {
Ok(Self::Int16(buf))
Ok(Self::Int16(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<i32>>() {
Ok(Self::Int32(buf))
Ok(Self::Int32(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<i64>>() {
Ok(Self::Int64(buf))
Ok(Self::Int64(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<f32>>() {
Ok(Self::Float32(buf))
Ok(Self::Float32(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<f64>>() {
Ok(Self::Float64(buf))
Ok(Self::Float64(PyBufferWrapper(Some(buf))))
} else {
Err(PyValueError::new_err("Not a buffer protocol object"))
}
}
}

impl AnyBufferProtocol {
fn buf_ptr(&self) -> *mut raw::c_void {
match self {
Self::UInt8(buf) => buf.buf_ptr(),
Self::UInt16(buf) => buf.buf_ptr(),
Self::UInt32(buf) => buf.buf_ptr(),
Self::UInt64(buf) => buf.buf_ptr(),
Self::Int8(buf) => buf.buf_ptr(),
Self::Int16(buf) => buf.buf_ptr(),
Self::Int32(buf) => buf.buf_ptr(),
Self::Int64(buf) => buf.buf_ptr(),
Self::Float32(buf) => buf.buf_ptr(),
Self::Float64(buf) => buf.buf_ptr(),
}
/// Forward to `PyBuffer::buf_ptr` on the wrapped buffer of whichever variant is held.
///
/// Propagates the error from `inner()` when the wrapper no longer holds a buffer.
fn buf_ptr(&self) -> PyResult<*mut raw::c_void> {
    Ok(match self {
        Self::UInt8(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::UInt16(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::UInt32(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::UInt64(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::Int8(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::Int16(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::Int32(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::Int64(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::Float32(wrapper) => wrapper.inner()?.buf_ptr(),
        Self::Float64(wrapper) => wrapper.inner()?.buf_ptr(),
    })
}

#[allow(dead_code)]
fn dimensions(&self) -> usize {
match self {
Self::UInt8(buf) => buf.dimensions(),
Self::UInt16(buf) => buf.dimensions(),
Self::UInt32(buf) => buf.dimensions(),
Self::UInt64(buf) => buf.dimensions(),
Self::Int8(buf) => buf.dimensions(),
Self::Int16(buf) => buf.dimensions(),
Self::Int32(buf) => buf.dimensions(),
Self::Int64(buf) => buf.dimensions(),
Self::Float32(buf) => buf.dimensions(),
Self::Float64(buf) => buf.dimensions(),
}
/// Forward to `PyBuffer::dimensions` on the wrapped buffer of whichever variant is held.
///
/// Propagates the error from `inner()` when the wrapper no longer holds a buffer.
fn dimensions(&self) -> PyResult<usize> {
    Ok(match self {
        Self::UInt8(wrapper) => wrapper.inner()?.dimensions(),
        Self::UInt16(wrapper) => wrapper.inner()?.dimensions(),
        Self::UInt32(wrapper) => wrapper.inner()?.dimensions(),
        Self::UInt64(wrapper) => wrapper.inner()?.dimensions(),
        Self::Int8(wrapper) => wrapper.inner()?.dimensions(),
        Self::Int16(wrapper) => wrapper.inner()?.dimensions(),
        Self::Int32(wrapper) => wrapper.inner()?.dimensions(),
        Self::Int64(wrapper) => wrapper.inner()?.dimensions(),
        Self::Float32(wrapper) => wrapper.inner()?.dimensions(),
        Self::Float64(wrapper) => wrapper.inner()?.dimensions(),
    })
}

fn format(&self) -> &CStr {
match self {
Self::UInt8(buf) => buf.format(),
Self::UInt16(buf) => buf.format(),
Self::UInt32(buf) => buf.format(),
Self::UInt64(buf) => buf.format(),
Self::Int8(buf) => buf.format(),
Self::Int16(buf) => buf.format(),
Self::Int32(buf) => buf.format(),
Self::Int64(buf) => buf.format(),
Self::Float32(buf) => buf.format(),
Self::Float64(buf) => buf.format(),
}
/// Forward to `PyBuffer::format` on the wrapped buffer of whichever variant is held.
///
/// Propagates the error from `inner()` when the wrapper no longer holds a buffer.
fn format(&self) -> PyResult<&CStr> {
    Ok(match self {
        Self::UInt8(wrapper) => wrapper.inner()?.format(),
        Self::UInt16(wrapper) => wrapper.inner()?.format(),
        Self::UInt32(wrapper) => wrapper.inner()?.format(),
        Self::UInt64(wrapper) => wrapper.inner()?.format(),
        Self::Int8(wrapper) => wrapper.inner()?.format(),
        Self::Int16(wrapper) => wrapper.inner()?.format(),
        Self::Int32(wrapper) => wrapper.inner()?.format(),
        Self::Int64(wrapper) => wrapper.inner()?.format(),
        Self::Float32(wrapper) => wrapper.inner()?.format(),
        Self::Float64(wrapper) => wrapper.inner()?.format(),
    })
}

/// Consume this and convert to an Arrow [`ArrayRef`].
Expand All @@ -187,7 +220,7 @@ impl AnyBufferProtocol {
pub fn into_arrow_array(self) -> PyArrowResult<ArrayRef> {
self.validate_buffer()?;

let shape = self.shape().to_vec();
let shape = self.shape()?.to_vec();

// Handle multi dimensional arrays by wrapping in FixedSizeLists
if shape.len() == 1 {
Expand All @@ -212,10 +245,10 @@ impl AnyBufferProtocol {
/// In `into_arrow_array` the values will be wrapped in FixedSizeLists if needed for multi
/// dimensional input.
fn into_arrow_values(self) -> PyArrowResult<ArrayRef> {
let len = self.item_count();
let len_bytes = self.len_bytes();
let ptr = NonNull::new(self.buf_ptr() as _).unwrap();
let element_type = ElementType::from_format(self.format());
let len = self.item_count()?;
let len_bytes = self.len_bytes()?;
let ptr = NonNull::new(self.buf_ptr()? as _).unwrap();
let element_type = ElementType::from_format(self.format()?);

// TODO: couldn't get this macro to work with error
// cannot find value `buf` in this scope
Expand Down Expand Up @@ -340,93 +373,98 @@ impl AnyBufferProtocol {
}
}

fn item_count(&self) -> usize {
match self {
Self::UInt8(buf) => buf.item_count(),
Self::UInt16(buf) => buf.item_count(),
Self::UInt32(buf) => buf.item_count(),
Self::UInt64(buf) => buf.item_count(),
Self::Int8(buf) => buf.item_count(),
Self::Int16(buf) => buf.item_count(),
Self::Int32(buf) => buf.item_count(),
Self::Int64(buf) => buf.item_count(),
Self::Float32(buf) => buf.item_count(),
Self::Float64(buf) => buf.item_count(),
}
/// Forward to `PyBuffer::item_count` on the wrapped buffer of whichever variant is held.
///
/// Propagates the error from `inner()` when the wrapper no longer holds a buffer.
fn item_count(&self) -> PyResult<usize> {
    Ok(match self {
        Self::UInt8(wrapper) => wrapper.inner()?.item_count(),
        Self::UInt16(wrapper) => wrapper.inner()?.item_count(),
        Self::UInt32(wrapper) => wrapper.inner()?.item_count(),
        Self::UInt64(wrapper) => wrapper.inner()?.item_count(),
        Self::Int8(wrapper) => wrapper.inner()?.item_count(),
        Self::Int16(wrapper) => wrapper.inner()?.item_count(),
        Self::Int32(wrapper) => wrapper.inner()?.item_count(),
        Self::Int64(wrapper) => wrapper.inner()?.item_count(),
        Self::Float32(wrapper) => wrapper.inner()?.item_count(),
        Self::Float64(wrapper) => wrapper.inner()?.item_count(),
    })
}

fn is_c_contiguous(&self) -> bool {
match self {
Self::UInt8(buf) => buf.is_c_contiguous(),
Self::UInt16(buf) => buf.is_c_contiguous(),
Self::UInt32(buf) => buf.is_c_contiguous(),
Self::UInt64(buf) => buf.is_c_contiguous(),
Self::Int8(buf) => buf.is_c_contiguous(),
Self::Int16(buf) => buf.is_c_contiguous(),
Self::Int32(buf) => buf.is_c_contiguous(),
Self::Int64(buf) => buf.is_c_contiguous(),
Self::Float32(buf) => buf.is_c_contiguous(),
Self::Float64(buf) => buf.is_c_contiguous(),
}
/// Forward to `PyBuffer::is_c_contiguous` on the wrapped buffer of whichever variant is held.
///
/// Propagates the error from `inner()` when the wrapper no longer holds a buffer.
fn is_c_contiguous(&self) -> PyResult<bool> {
    Ok(match self {
        Self::UInt8(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::UInt16(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::UInt32(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::UInt64(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::Int8(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::Int16(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::Int32(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::Int64(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::Float32(wrapper) => wrapper.inner()?.is_c_contiguous(),
        Self::Float64(wrapper) => wrapper.inner()?.is_c_contiguous(),
    })
}

fn len_bytes(&self) -> usize {
match self {
Self::UInt8(buf) => buf.len_bytes(),
Self::UInt16(buf) => buf.len_bytes(),
Self::UInt32(buf) => buf.len_bytes(),
Self::UInt64(buf) => buf.len_bytes(),
Self::Int8(buf) => buf.len_bytes(),
Self::Int16(buf) => buf.len_bytes(),
Self::Int32(buf) => buf.len_bytes(),
Self::Int64(buf) => buf.len_bytes(),
Self::Float32(buf) => buf.len_bytes(),
Self::Float64(buf) => buf.len_bytes(),
}
/// Forward to `PyBuffer::len_bytes` on the wrapped buffer of whichever variant is held.
///
/// Propagates the error from `inner()` when the wrapper no longer holds a buffer.
fn len_bytes(&self) -> PyResult<usize> {
    Ok(match self {
        Self::UInt8(wrapper) => wrapper.inner()?.len_bytes(),
        Self::UInt16(wrapper) => wrapper.inner()?.len_bytes(),
        Self::UInt32(wrapper) => wrapper.inner()?.len_bytes(),
        Self::UInt64(wrapper) => wrapper.inner()?.len_bytes(),
        Self::Int8(wrapper) => wrapper.inner()?.len_bytes(),
        Self::Int16(wrapper) => wrapper.inner()?.len_bytes(),
        Self::Int32(wrapper) => wrapper.inner()?.len_bytes(),
        Self::Int64(wrapper) => wrapper.inner()?.len_bytes(),
        Self::Float32(wrapper) => wrapper.inner()?.len_bytes(),
        Self::Float64(wrapper) => wrapper.inner()?.len_bytes(),
    })
}

fn shape(&self) -> &[usize] {
match self {
Self::UInt8(buf) => buf.shape(),
Self::UInt16(buf) => buf.shape(),
Self::UInt32(buf) => buf.shape(),
Self::UInt64(buf) => buf.shape(),
Self::Int8(buf) => buf.shape(),
Self::Int16(buf) => buf.shape(),
Self::Int32(buf) => buf.shape(),
Self::Int64(buf) => buf.shape(),
Self::Float32(buf) => buf.shape(),
Self::Float64(buf) => buf.shape(),
}
/// Forward to `PyBuffer::shape` on the wrapped buffer of whichever variant is held.
///
/// Propagates the error from `inner()` when the wrapper no longer holds a buffer.
fn shape(&self) -> PyResult<&[usize]> {
    Ok(match self {
        Self::UInt8(wrapper) => wrapper.inner()?.shape(),
        Self::UInt16(wrapper) => wrapper.inner()?.shape(),
        Self::UInt32(wrapper) => wrapper.inner()?.shape(),
        Self::UInt64(wrapper) => wrapper.inner()?.shape(),
        Self::Int8(wrapper) => wrapper.inner()?.shape(),
        Self::Int16(wrapper) => wrapper.inner()?.shape(),
        Self::Int32(wrapper) => wrapper.inner()?.shape(),
        Self::Int64(wrapper) => wrapper.inner()?.shape(),
        Self::Float32(wrapper) => wrapper.inner()?.shape(),
        Self::Float64(wrapper) => wrapper.inner()?.shape(),
    })
}

fn strides(&self) -> &[isize] {
match self {
Self::UInt8(buf) => buf.strides(),
Self::UInt16(buf) => buf.strides(),
Self::UInt32(buf) => buf.strides(),
Self::UInt64(buf) => buf.strides(),
Self::Int8(buf) => buf.strides(),
Self::Int16(buf) => buf.strides(),
Self::Int32(buf) => buf.strides(),
Self::Int64(buf) => buf.strides(),
Self::Float32(buf) => buf.strides(),
Self::Float64(buf) => buf.strides(),
}
/// Forward to `PyBuffer::strides` on the wrapped buffer of whichever variant is held.
///
/// Propagates the error from `inner()` when the wrapper no longer holds a buffer.
fn strides(&self) -> PyResult<&[isize]> {
    Ok(match self {
        Self::UInt8(wrapper) => wrapper.inner()?.strides(),
        Self::UInt16(wrapper) => wrapper.inner()?.strides(),
        Self::UInt32(wrapper) => wrapper.inner()?.strides(),
        Self::UInt64(wrapper) => wrapper.inner()?.strides(),
        Self::Int8(wrapper) => wrapper.inner()?.strides(),
        Self::Int16(wrapper) => wrapper.inner()?.strides(),
        Self::Int32(wrapper) => wrapper.inner()?.strides(),
        Self::Int64(wrapper) => wrapper.inner()?.strides(),
        Self::Float32(wrapper) => wrapper.inner()?.strides(),
        Self::Float64(wrapper) => wrapper.inner()?.strides(),
    })
}

fn validate_buffer(&self) -> PyArrowResult<()> {
if !self.is_c_contiguous() {
if !self.is_c_contiguous()? {
return Err(PyValueError::new_err("Buffer is not C contiguous").into());
}

if self.shape().iter().any(|s| *s == 0) {
if self.shape()?.iter().any(|s| *s == 0) {
return Err(
PyValueError::new_err("0-length dimension not currently supported.").into(),
);
}

if self.strides().iter().any(|s| *s == 0) {
if self.strides()?.iter().any(|s| *s == 0) {
return Err(PyValueError::new_err("Non-zero strides not currently supported.").into());
}

Expand Down

0 comments on commit 2c6f95a

Please sign in to comment.