Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom drop on PyBufferWrapper #231

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
305 changes: 172 additions & 133 deletions pyo3-arrow/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use arrow_array::{
};
use arrow_buffer::{Buffer, ScalarBuffer};
use arrow_schema::Field;
use pyo3::buffer::{ElementType, PyBuffer};
use pyo3::buffer::{Element, ElementType, PyBuffer};
use pyo3::exceptions::PyValueError;
use pyo3::ffi;
use pyo3::prelude::*;
Expand Down Expand Up @@ -73,95 +73,128 @@ impl PyArrowBuffer {
}
}

/// A wrapper around a [`PyBuffer`] that applies a custom destructor that checks if the Python
/// interpreter is still initialized before freeing the buffer memory.
///
/// The buffer is stored in an `Option` so that the `Drop` implementation can move it out of the
/// wrapper with `take()` and then either drop it normally or deliberately leak it.
#[derive(Debug)]
pub struct PyBufferWrapper<T: Element>(Option<PyBuffer<T>>);

impl<T: Element> PyBufferWrapper<T> {
    /// Borrow the wrapped [`PyBuffer`].
    ///
    /// # Errors
    ///
    /// Returns a `PyValueError` with the message "Buffer already disposed" if the buffer has
    /// already been taken out of the wrapper.
    fn inner(&self) -> PyResult<&PyBuffer<T>> {
        // Build the error lazily: this accessor is called for every buffer query, and the
        // error value is only needed on the (rare) disposed path.
        self.0
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Buffer already disposed"))
    }
}

impl<T: Element> Drop for PyBufferWrapper<T> {
    /// Release the wrapped buffer, unless the Python interpreter is no longer initialized.
    ///
    /// The drop of the wrapper can sometimes be attempted after the Python interpreter was
    /// already finalized; in that case the inner `PyBuffer` is leaked instead of dropped.
    /// See https://github.com/kylebarron/arro3/issues/230
    fn drop(&mut self) {
        // SAFETY: `Py_IsInitialized` takes no arguments and only reports interpreter state.
        // The CPython documentation lists it among the calls that are valid regardless of
        // whether the interpreter is initialized.
        let interpreter_alive = unsafe { ffi::Py_IsInitialized() } != 0;

        match self.0.take() {
            // Interpreter is up: let `PyBuffer`'s own destructor run.
            Some(buffer) if interpreter_alive => drop(buffer),
            // Interpreter is gone: skip the destructor entirely by leaking the value.
            Some(buffer) => std::mem::forget(buffer),
            // Nothing left to release.
            None => {}
        }
    }
}

/// An enum over buffer protocol input types.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AnyBufferProtocol {
UInt8(PyBuffer<u8>),
UInt16(PyBuffer<u16>),
UInt32(PyBuffer<u32>),
UInt64(PyBuffer<u64>),
Int8(PyBuffer<i8>),
Int16(PyBuffer<i16>),
Int32(PyBuffer<i32>),
Int64(PyBuffer<i64>),
Float32(PyBuffer<f32>),
Float64(PyBuffer<f64>),
UInt8(PyBufferWrapper<u8>),
UInt16(PyBufferWrapper<u16>),
UInt32(PyBufferWrapper<u32>),
UInt64(PyBufferWrapper<u64>),
Int8(PyBufferWrapper<i8>),
Int16(PyBufferWrapper<i16>),
Int32(PyBufferWrapper<i32>),
Int64(PyBufferWrapper<i64>),
Float32(PyBufferWrapper<f32>),
Float64(PyBufferWrapper<f64>),
}

impl<'py> FromPyObject<'py> for AnyBufferProtocol {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
if let Ok(buf) = ob.extract::<PyBuffer<u8>>() {
Ok(Self::UInt8(buf))
Ok(Self::UInt8(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<u16>>() {
Ok(Self::UInt16(buf))
Ok(Self::UInt16(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<u32>>() {
Ok(Self::UInt32(buf))
Ok(Self::UInt32(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<u64>>() {
Ok(Self::UInt64(buf))
Ok(Self::UInt64(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<i8>>() {
Ok(Self::Int8(buf))
Ok(Self::Int8(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<i16>>() {
Ok(Self::Int16(buf))
Ok(Self::Int16(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<i32>>() {
Ok(Self::Int32(buf))
Ok(Self::Int32(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<i64>>() {
Ok(Self::Int64(buf))
Ok(Self::Int64(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<f32>>() {
Ok(Self::Float32(buf))
Ok(Self::Float32(PyBufferWrapper(Some(buf))))
} else if let Ok(buf) = ob.extract::<PyBuffer<f64>>() {
Ok(Self::Float64(buf))
Ok(Self::Float64(PyBufferWrapper(Some(buf))))
} else {
Err(PyValueError::new_err("Not a buffer protocol object"))
}
}
}

impl AnyBufferProtocol {
fn buf_ptr(&self) -> *mut raw::c_void {
match self {
Self::UInt8(buf) => buf.buf_ptr(),
Self::UInt16(buf) => buf.buf_ptr(),
Self::UInt32(buf) => buf.buf_ptr(),
Self::UInt64(buf) => buf.buf_ptr(),
Self::Int8(buf) => buf.buf_ptr(),
Self::Int16(buf) => buf.buf_ptr(),
Self::Int32(buf) => buf.buf_ptr(),
Self::Int64(buf) => buf.buf_ptr(),
Self::Float32(buf) => buf.buf_ptr(),
Self::Float64(buf) => buf.buf_ptr(),
}
/// Forward to `PyBuffer::buf_ptr` on whichever typed buffer this variant holds.
///
/// Returns an error if the wrapped buffer has already been disposed.
fn buf_ptr(&self) -> PyResult<*mut raw::c_void> {
    match self {
        Self::UInt8(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::UInt16(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::UInt32(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::UInt64(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::Int8(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::Int16(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::Int32(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::Int64(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::Float32(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
        Self::Float64(wrapper) => wrapper.inner().map(|pybuf| pybuf.buf_ptr()),
    }
}

#[allow(dead_code)]
fn dimensions(&self) -> usize {
match self {
Self::UInt8(buf) => buf.dimensions(),
Self::UInt16(buf) => buf.dimensions(),
Self::UInt32(buf) => buf.dimensions(),
Self::UInt64(buf) => buf.dimensions(),
Self::Int8(buf) => buf.dimensions(),
Self::Int16(buf) => buf.dimensions(),
Self::Int32(buf) => buf.dimensions(),
Self::Int64(buf) => buf.dimensions(),
Self::Float32(buf) => buf.dimensions(),
Self::Float64(buf) => buf.dimensions(),
}
/// Forward to `PyBuffer::dimensions` on whichever typed buffer this variant holds.
///
/// Returns an error if the wrapped buffer has already been disposed.
fn dimensions(&self) -> PyResult<usize> {
    match self {
        Self::UInt8(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::UInt16(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::UInt32(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::UInt64(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::Int8(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::Int16(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::Int32(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::Int64(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::Float32(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
        Self::Float64(wrapper) => wrapper.inner().map(|pybuf| pybuf.dimensions()),
    }
}

fn format(&self) -> &CStr {
match self {
Self::UInt8(buf) => buf.format(),
Self::UInt16(buf) => buf.format(),
Self::UInt32(buf) => buf.format(),
Self::UInt64(buf) => buf.format(),
Self::Int8(buf) => buf.format(),
Self::Int16(buf) => buf.format(),
Self::Int32(buf) => buf.format(),
Self::Int64(buf) => buf.format(),
Self::Float32(buf) => buf.format(),
Self::Float64(buf) => buf.format(),
}
/// Forward to `PyBuffer::format` on whichever typed buffer this variant holds.
///
/// The returned string borrows from `self`. Returns an error if the wrapped buffer has
/// already been disposed.
fn format(&self) -> PyResult<&CStr> {
    match self {
        Self::UInt8(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::UInt16(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::UInt32(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::UInt64(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::Int8(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::Int16(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::Int32(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::Int64(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::Float32(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
        Self::Float64(wrapper) => wrapper.inner().map(|pybuf| pybuf.format()),
    }
}

/// Consume this and convert to an Arrow [`ArrayRef`].
Expand All @@ -187,7 +220,7 @@ impl AnyBufferProtocol {
pub fn into_arrow_array(self) -> PyArrowResult<ArrayRef> {
self.validate_buffer()?;

let shape = self.shape().to_vec();
let shape = self.shape()?.to_vec();

// Handle multi dimensional arrays by wrapping in FixedSizeLists
if shape.len() == 1 {
Expand All @@ -212,10 +245,11 @@ impl AnyBufferProtocol {
/// In `into_arrow_array` the values will be wrapped in FixedSizeLists if needed for multi
/// dimensional input.
fn into_arrow_values(self) -> PyArrowResult<ArrayRef> {
let len = self.item_count();
let len_bytes = self.len_bytes();
let ptr = NonNull::new(self.buf_ptr() as _).unwrap();
let element_type = ElementType::from_format(self.format());
let len = self.item_count()?;
let len_bytes = self.len_bytes()?;
let ptr = NonNull::new(self.buf_ptr()? as _)
.ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
let element_type = ElementType::from_format(self.format()?);

// TODO: couldn't get this macro to work with error
// cannot find value `buf` in this scope
Expand Down Expand Up @@ -340,93 +374,98 @@ impl AnyBufferProtocol {
}
}

fn item_count(&self) -> usize {
match self {
Self::UInt8(buf) => buf.item_count(),
Self::UInt16(buf) => buf.item_count(),
Self::UInt32(buf) => buf.item_count(),
Self::UInt64(buf) => buf.item_count(),
Self::Int8(buf) => buf.item_count(),
Self::Int16(buf) => buf.item_count(),
Self::Int32(buf) => buf.item_count(),
Self::Int64(buf) => buf.item_count(),
Self::Float32(buf) => buf.item_count(),
Self::Float64(buf) => buf.item_count(),
}
/// Forward to `PyBuffer::item_count` on whichever typed buffer this variant holds.
///
/// Returns an error if the wrapped buffer has already been disposed.
fn item_count(&self) -> PyResult<usize> {
    match self {
        Self::UInt8(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::UInt16(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::UInt32(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::UInt64(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::Int8(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::Int16(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::Int32(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::Int64(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::Float32(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
        Self::Float64(wrapper) => wrapper.inner().map(|pybuf| pybuf.item_count()),
    }
}

fn is_c_contiguous(&self) -> bool {
match self {
Self::UInt8(buf) => buf.is_c_contiguous(),
Self::UInt16(buf) => buf.is_c_contiguous(),
Self::UInt32(buf) => buf.is_c_contiguous(),
Self::UInt64(buf) => buf.is_c_contiguous(),
Self::Int8(buf) => buf.is_c_contiguous(),
Self::Int16(buf) => buf.is_c_contiguous(),
Self::Int32(buf) => buf.is_c_contiguous(),
Self::Int64(buf) => buf.is_c_contiguous(),
Self::Float32(buf) => buf.is_c_contiguous(),
Self::Float64(buf) => buf.is_c_contiguous(),
}
/// Forward to `PyBuffer::is_c_contiguous` on whichever typed buffer this variant holds.
///
/// Returns an error if the wrapped buffer has already been disposed.
fn is_c_contiguous(&self) -> PyResult<bool> {
    match self {
        Self::UInt8(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::UInt16(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::UInt32(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::UInt64(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::Int8(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::Int16(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::Int32(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::Int64(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::Float32(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
        Self::Float64(wrapper) => wrapper.inner().map(|pybuf| pybuf.is_c_contiguous()),
    }
}

fn len_bytes(&self) -> usize {
match self {
Self::UInt8(buf) => buf.len_bytes(),
Self::UInt16(buf) => buf.len_bytes(),
Self::UInt32(buf) => buf.len_bytes(),
Self::UInt64(buf) => buf.len_bytes(),
Self::Int8(buf) => buf.len_bytes(),
Self::Int16(buf) => buf.len_bytes(),
Self::Int32(buf) => buf.len_bytes(),
Self::Int64(buf) => buf.len_bytes(),
Self::Float32(buf) => buf.len_bytes(),
Self::Float64(buf) => buf.len_bytes(),
}
/// Forward to `PyBuffer::len_bytes` on whichever typed buffer this variant holds.
///
/// Returns an error if the wrapped buffer has already been disposed.
fn len_bytes(&self) -> PyResult<usize> {
    match self {
        Self::UInt8(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::UInt16(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::UInt32(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::UInt64(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::Int8(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::Int16(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::Int32(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::Int64(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::Float32(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
        Self::Float64(wrapper) => wrapper.inner().map(|pybuf| pybuf.len_bytes()),
    }
}

fn shape(&self) -> &[usize] {
match self {
Self::UInt8(buf) => buf.shape(),
Self::UInt16(buf) => buf.shape(),
Self::UInt32(buf) => buf.shape(),
Self::UInt64(buf) => buf.shape(),
Self::Int8(buf) => buf.shape(),
Self::Int16(buf) => buf.shape(),
Self::Int32(buf) => buf.shape(),
Self::Int64(buf) => buf.shape(),
Self::Float32(buf) => buf.shape(),
Self::Float64(buf) => buf.shape(),
}
/// Forward to `PyBuffer::shape` on whichever typed buffer this variant holds.
///
/// The returned slice borrows from `self`. Returns an error if the wrapped buffer has
/// already been disposed.
fn shape(&self) -> PyResult<&[usize]> {
    match self {
        Self::UInt8(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::UInt16(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::UInt32(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::UInt64(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::Int8(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::Int16(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::Int32(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::Int64(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::Float32(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
        Self::Float64(wrapper) => wrapper.inner().map(|pybuf| pybuf.shape()),
    }
}

fn strides(&self) -> &[isize] {
match self {
Self::UInt8(buf) => buf.strides(),
Self::UInt16(buf) => buf.strides(),
Self::UInt32(buf) => buf.strides(),
Self::UInt64(buf) => buf.strides(),
Self::Int8(buf) => buf.strides(),
Self::Int16(buf) => buf.strides(),
Self::Int32(buf) => buf.strides(),
Self::Int64(buf) => buf.strides(),
Self::Float32(buf) => buf.strides(),
Self::Float64(buf) => buf.strides(),
}
/// Forward to `PyBuffer::strides` on whichever typed buffer this variant holds.
///
/// The returned slice borrows from `self`. Returns an error if the wrapped buffer has
/// already been disposed.
fn strides(&self) -> PyResult<&[isize]> {
    match self {
        Self::UInt8(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::UInt16(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::UInt32(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::UInt64(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::Int8(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::Int16(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::Int32(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::Int64(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::Float32(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
        Self::Float64(wrapper) => wrapper.inner().map(|pybuf| pybuf.strides()),
    }
}

fn validate_buffer(&self) -> PyArrowResult<()> {
if !self.is_c_contiguous() {
if !self.is_c_contiguous()? {
return Err(PyValueError::new_err("Buffer is not C contiguous").into());
}

if self.shape().iter().any(|s| *s == 0) {
if self.shape()?.iter().any(|s| *s == 0) {
return Err(
PyValueError::new_err("0-length dimension not currently supported.").into(),
);
}

if self.strides().iter().any(|s| *s == 0) {
if self.strides()?.iter().any(|s| *s == 0) {
return Err(PyValueError::new_err("Non-zero strides not currently supported.").into());
}

Expand Down