Skip to content

Commit 6f413da

Browse files
committed
clean up array conversions
1 parent 459e84d commit 6f413da

File tree

1 file changed

+47
-89
lines changed

1 file changed

+47
-89
lines changed

xc3_model_py/src/map_py.rs

Lines changed: 47 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ map_py_into_impl!(Vec3, [f32; 3]);
7272
macro_rules! map_py_pyobject_ndarray_impl {
7373
($($t:ty),*) => {
7474
$(
75+
// 1D arrays
7576
impl MapPy<Py<PyArray1<$t>>> for Vec<$t> {
7677
fn map_py(&self, py: Python) -> PyResult<Py<PyArray1<$t>>> {
7778
Ok(self.to_pyarray(py).into())
@@ -85,6 +86,7 @@ macro_rules! map_py_pyobject_ndarray_impl {
8586
}
8687
}
8788

89+
// 1D untyped arrays
8890
impl MapPy<Py<PyUntypedArray>> for Vec<$t> {
8991
fn map_py(&self, py: Python) -> PyResult<Py<PyUntypedArray>> {
9092
let arr: Py<PyArray1<$t>> = self.map_py(py)?;
@@ -98,56 +100,57 @@ macro_rules! map_py_pyobject_ndarray_impl {
98100
arr.as_unbound().map_py(py)
99101
}
100102
}
101-
)*
102-
}
103-
}
104103

105-
map_py_pyobject_ndarray_impl!(u16, u32, f32);
104+
// 2D arrays
105+
impl<const N: usize> MapPy<Py<PyArray2<$t>>> for Vec<[$t; N]> {
106+
fn map_py(&self, py: Python) -> PyResult<Py<PyArray2<$t>>> {
107+
// This flatten will be optimized in Release mode.
108+
// This avoids needing unsafe code.
109+
let count = self.len();
110+
Ok(self
111+
.iter()
112+
.flatten()
113+
.copied()
114+
.collect::<Vec<$t>>()
115+
.into_pyarray(py)
116+
.reshape((count, N))
117+
.unwrap()
118+
.into())
119+
}
120+
}
106121

107-
// TODO: Share code with other primitive arrays?
108-
impl MapPy<Py<PyArray2<u8>>> for Vec<[u8; 4]> {
109-
fn map_py(&self, py: Python) -> PyResult<Py<PyArray2<u8>>> {
110-
// This flatten will be optimized in Release mode.
111-
// This avoids needing unsafe code.
112-
let count = self.len();
113-
Ok(self
114-
.iter()
115-
.flatten()
116-
.copied()
117-
.collect::<Vec<u8>>()
118-
.into_pyarray(py)
119-
.reshape((count, 4))
120-
.unwrap()
121-
.into())
122-
}
123-
}
122+
impl<const N: usize> MapPy<Vec<[$t; N]>> for Py<PyArray2<$t>> {
123+
fn map_py(&self, py: Python) -> PyResult<Vec<[$t; N]>> {
124+
let array = self.as_any().downcast_bound::<PyArray2<$t>>(py)?;
125+
Ok(array
126+
.readonly()
127+
.as_array()
128+
.rows()
129+
.into_iter()
130+
.map(|r| r.as_slice().unwrap().try_into().unwrap())
131+
.collect())
132+
}
133+
}
124134

125-
impl MapPy<Vec<[u8; 4]>> for Py<PyArray2<u8>> {
126-
fn map_py(&self, py: Python) -> PyResult<Vec<[u8; 4]>> {
127-
let array = self.as_any().downcast_bound::<PyArray2<u8>>(py)?;
128-
Ok(array
129-
.readonly()
130-
.as_array()
131-
.rows()
132-
.into_iter()
133-
.map(|r| r.as_slice().unwrap().try_into().unwrap())
134-
.collect())
135-
}
136-
}
135+
// 2D untyped arrrays
136+
impl<const N: usize> MapPy<Py<PyUntypedArray>> for Vec<[$t; N]> {
137+
fn map_py(&self, py: Python) -> PyResult<Py<PyUntypedArray>> {
138+
let arr: Py<PyArray2<$t>> = self.map_py(py)?;
139+
Ok(arr.bind(py).as_untyped().clone().unbind())
140+
}
141+
}
137142

138-
impl MapPy<Py<PyUntypedArray>> for Vec<[u8; 4]> {
139-
fn map_py(&self, py: Python) -> PyResult<Py<PyUntypedArray>> {
140-
let arr: Py<PyArray2<u8>> = self.map_py(py)?;
141-
Ok(arr.bind(py).as_untyped().clone().unbind())
143+
impl<const N: usize> MapPy<Vec<[$t; N]>> for Py<PyUntypedArray> {
144+
fn map_py(&self, py: Python) -> PyResult<Vec<[$t; N]>> {
145+
let arr = self.bind(py).downcast::<PyArray2<$t>>()?;
146+
arr.as_unbound().map_py(py)
147+
}
148+
}
149+
)*
142150
}
143151
}
144152

145-
impl MapPy<Vec<[u8; 4]>> for Py<PyUntypedArray> {
146-
fn map_py(&self, py: Python) -> PyResult<Vec<[u8; 4]>> {
147-
let arr = self.bind(py).downcast::<PyArray2<u8>>()?;
148-
arr.as_unbound().map_py(py)
149-
}
150-
}
153+
map_py_pyobject_ndarray_impl!(u8, u16, u32, f32);
151154

152155
impl<T, U> MapPy<Option<U>> for Option<T>
153156
where
@@ -160,7 +163,7 @@ where
160163

161164
// TODO: how to implement for Py<T>?
162165

163-
// TODO: Derive for each type to avoid overlapping definitions with numpy?
166+
// TODO: Blanket impl without overlap?
164167
impl MapPy<Vec<String>> for Py<PyList> {
165168
fn map_py(&self, py: Python) -> PyResult<Vec<String>> {
166169
self.extract(py)
@@ -225,50 +228,6 @@ map_py_vecn_ndarray_impl!(Vec2, 2);
225228
map_py_vecn_ndarray_impl!(Vec3, 3);
226229
map_py_vecn_ndarray_impl!(Vec4, 4);
227230

228-
impl MapPy<Py<PyArray2<u16>>> for Vec<[u16; 2]> {
229-
fn map_py(&self, py: Python) -> PyResult<Py<PyArray2<u16>>> {
230-
// This flatten will be optimized in Release mode.
231-
// This avoids needing unsafe code.
232-
let count = self.len();
233-
Ok(self
234-
.iter()
235-
.flatten()
236-
.copied()
237-
.collect::<Vec<u16>>()
238-
.into_pyarray(py)
239-
.reshape((count, 2))
240-
.unwrap()
241-
.into())
242-
}
243-
}
244-
245-
impl MapPy<Vec<[u16; 2]>> for Py<PyArray2<u16>> {
246-
fn map_py(&self, py: Python) -> PyResult<Vec<[u16; 2]>> {
247-
let array = self.as_any().downcast_bound::<PyArray2<u16>>(py)?;
248-
Ok(array
249-
.readonly()
250-
.as_array()
251-
.rows()
252-
.into_iter()
253-
.map(|r| r.as_slice().unwrap().try_into().unwrap())
254-
.collect())
255-
}
256-
}
257-
258-
impl MapPy<Py<PyUntypedArray>> for Vec<[u16; 2]> {
259-
fn map_py(&self, py: Python) -> PyResult<Py<PyUntypedArray>> {
260-
let arr: Py<PyArray2<u16>> = self.map_py(py)?;
261-
Ok(arr.bind(py).as_untyped().clone().unbind())
262-
}
263-
}
264-
265-
impl MapPy<Vec<[u16; 2]>> for Py<PyUntypedArray> {
266-
fn map_py(&self, py: Python) -> PyResult<Vec<[u16; 2]>> {
267-
let arr = self.bind(py).downcast::<PyArray2<u16>>()?;
268-
arr.as_unbound().map_py(py)
269-
}
270-
}
271-
272231
impl MapPy<Py<PyArray2<f32>>> for Mat4 {
273232
fn map_py(&self, py: Python) -> PyResult<Py<PyArray2<f32>>> {
274233
// TODO: Should this be transposed since numpy is row-major?
@@ -336,7 +295,6 @@ impl MapPy<String> for SmolStr {
336295
}
337296
}
338297

339-
// TODO: const generics?
340298
impl<T, U, const N: usize> MapPy<[U; N]> for [T; N]
341299
where
342300
T: MapPy<U>,

0 commit comments

Comments
 (0)