From 5a927088355a275cad054c0505b8407d75fad2cc Mon Sep 17 00:00:00 2001 From: Caio Date: Tue, 1 Sep 2020 10:38:37 -0300 Subject: [PATCH] Support for arbitrary arrays --- Cargo.toml | 4 ++++ src/lib.rs | 2 ++ src/types/list.rs | 12 ++++++++++ src/types/mod.rs | 37 ++++++++++++++++++++++++++++++ src/types/sequence.rs | 52 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 107 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index e9e14fa6270..6383bab158d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,10 @@ rustversion = "1.0" [features] default = ["macros"] macros = ["ctor", "indoc", "inventory", "paste", "pyo3cls", "unindent"] + +# Supports arrays of arbitrary size +const-generics = [] + # Optimizes PyObject to Vec conversion and so on. nightly = [] diff --git a/src/lib.rs b/src/lib.rs index 10f3e768f8e..5bd29ae0169 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "const-generics", feature(min_const_generics))] +#![cfg_attr(feature = "nightly", allow(incomplete_features))] #![cfg_attr(feature = "nightly", feature(specialization))] #![allow(clippy::missing_safety_doc)] // FIXME (#698) diff --git a/src/types/list.rs b/src/types/list.rs index 4a0586e4fcd..a23940936b6 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -178,6 +178,17 @@ where } } +#[cfg(feature = "const-generics")] +impl IntoPy for [T; N] +where + T: ToPyObject, +{ + fn into_py(self, py: Python) -> PyObject { + self.as_ref().to_object(py) + } +} + +#[cfg(not(feature = "const-generics"))] macro_rules! array_impls { ($($N:expr),+) => { $( @@ -193,6 +204,7 @@ macro_rules! array_impls { } } +#[cfg(not(feature = "const-generics"))] array_impls!( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 diff --git a/src/types/mod.rs b/src/types/mod.rs index 32bbfe8e77e..72f2b782515 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -235,3 +235,40 @@ mod slice; mod string; mod tuple; mod typeobject; + +#[cfg(feature = "const-generics")] +struct ArrayGuard { + dst: *mut T, + initialized: usize, +} + +#[cfg(feature = "const-generics")] +impl Drop for ArrayGuard { + fn drop(&mut self) { + debug_assert!(self.initialized <= N); + let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized); + unsafe { + core::ptr::drop_in_place(initialized_part); + } + } +} + +#[cfg(feature = "const-generics")] +fn try_create_array(mut cb: F) -> Result<[T; N], E> +where + F: FnMut(usize) -> Result, +{ + let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit(); + let mut guard: ArrayGuard = ArrayGuard { + dst: array.as_mut_ptr() as _, + initialized: 0, + }; + unsafe { + for (idx, value_ptr) in (&mut *array.as_mut_ptr()).iter_mut().enumerate() { + core::ptr::write(value_ptr, cb(idx)?); + guard.initialized += 1; + } + core::mem::forget(guard); + Ok(array.assume_init()) + } +} diff --git a/src/types/sequence.rs b/src/types/sequence.rs index 1b84352a5b0..f4dc1d78062 100644 --- a/src/types/sequence.rs +++ b/src/types/sequence.rs @@ -257,6 +257,41 @@ impl PySequence { } } +#[cfg(feature = "const-generics")] +impl<'a, T, const N: usize> FromPyObject<'a> for [T; N] +where + T: FromPyObject<'a>, +{ + #[cfg(not(feature = "nightly"))] + fn extract(obj: &'a PyAny) -> PyResult { + create_array_from_obj(obj) + } + + #[cfg(feature = "nightly")] + default fn extract(obj: &'a PyAny) -> PyResult { + create_array_from_obj(obj) + } +} + +#[cfg(all(feature = "const-generics", feature = "nightly"))] +impl<'source, T, const N: usize> FromPyObject<'source> for [T; N] +where + for<'a> T: FromPyObject<'a> + crate::buffer::Element, +{ + fn extract(obj: &'source PyAny) -> PyResult { + let mut array = create_array_from_obj(obj)?; + if let Ok(buf) = crate::buffer::PyBuffer::get(obj) { + if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() { + buf.release(obj.py()); + return Ok(array); + } + buf.release(obj.py()); + } + Ok(array) + } +} + +#[cfg(not(feature = "const-generics"))] macro_rules! array_impls { ($($N:expr),+) => { $( @@ -303,6 +338,7 @@ macro_rules! array_impls { } } +#[cfg(not(feature = "const-generics"))] array_impls!( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 @@ -343,6 +379,21 @@ where } } +#[cfg(feature = "const-generics")] +fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]> +where + T: FromPyObject<'s>, +{ + let seq = ::try_from(obj)?; + crate::types::try_create_array(|idx| { + seq.get_item(idx as isize) + .map_err(|_| { + exceptions::PyBufferError::py_err("Slice length does not match buffer length.") + })? + .extract::() + }) +} + fn extract_sequence<'s, T>(obj: &'s PyAny) -> PyResult> where T: FromPyObject<'s>, @@ -355,6 +406,7 @@ where Ok(v) } +#[cfg(not(feature = "const-generics"))] fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()> where T: FromPyObject<'s>,