Skip to content

Commit 8e9c17d

Browse files
committed
WIP: Use the GCD of the strides to check if two array views could alias.
1 parent 49b5c09 commit 8e9c17d

File tree

4 files changed

+77
-33
lines changed

4 files changed

+77
-33
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ license = "BSD-2-Clause"
1717
[dependencies]
1818
libc = "0.2"
1919
num-complex = ">= 0.2, < 0.5"
20+
num-integer = "0.1"
2021
num-traits = "0.2"
2122
ndarray = ">= 0.13, < 0.16"
2223
pyo3 = { version = "0.16", default-features = false, features = ["macros"] }

src/array.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,9 @@ impl<T, D> PyArray<T, D> {
340340
}
341341

342342
/// Returns the pointer to the first element of the inner array.
343-
pub(crate) unsafe fn data(&self) -> *mut T {
343+
pub(crate) fn data(&self) -> *mut T {
344344
let ptr = self.as_array_ptr();
345-
(*ptr).data as *mut _
345+
unsafe { (*ptr).data as *mut _ }
346346
}
347347
}
348348

@@ -381,7 +381,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
381381
let strides = self.strides();
382382

383383
let mut new_strides = D::zeros(strides.len());
384-
let mut data_ptr = unsafe { self.data() };
384+
let mut data_ptr = self.data();
385385
let mut inverted_axes = InvertedAxes::new(strides.len());
386386

387387
for i in 0..strides.len() {

src/borrow.rs

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@
125125
//!
126126
//! # Limitations
127127
//!
128+
//! TODO: We only leave the case of aliasing, but only out of bounds. Can this actually happen for array views?
129+
//!
128130
//! Note that the current implementation of this is an over-approximation: It will consider overlapping borrows
129131
//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base].
130132
//! For example, creating two views of the same underlying array by slicing can yield potentially conflicting borrows
@@ -143,6 +145,7 @@ use std::collections::hash_map::{Entry, HashMap};
143145
use std::ops::{Deref, Range};
144146

145147
use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
148+
use num_integer::gcd;
146149
use pyo3::{FromPyObject, PyAny, PyResult};
147150

148151
use crate::array::PyArray;
@@ -155,9 +158,28 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
155158
#[derive(PartialEq, Eq, Hash)]
156159
struct BorrowKey {
157160
range: Range<usize>,
161+
data_ptr: usize,
162+
gcd_strides: isize,
158163
}
159164

160165
impl BorrowKey {
166+
fn from_array<T, D>(array: &PyArray<T, D>) -> Self
167+
where
168+
T: Element,
169+
D: Dimension,
170+
{
171+
let range = data_range(array);
172+
173+
let data_ptr = array.data() as usize;
174+
let gcd_strides = array.strides().iter().copied().reduce(gcd).unwrap_or(1);
175+
176+
Self {
177+
range,
178+
data_ptr,
179+
gcd_strides,
180+
}
181+
}
182+
161183
fn conflicts(&self, other: &Self) -> bool {
162184
debug_assert!(self.range.start <= self.range.end);
163185
debug_assert!(other.range.start <= other.range.end);
@@ -166,6 +188,20 @@ impl BorrowKey {
166188
return false;
167189
}
168190

191+
// The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays
192+
// so that they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers.
193+
//
194+
// That solution could be out of bounds which mean that this still an approximation,
195+
// but it seems sufficient to handle typical cases like the color channels of an image.
196+
//
197+
// https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
198+
let data_diff = abs_diff(self.data_ptr, other.data_ptr) as isize;
199+
let gcd_strides = gcd(self.gcd_strides, other.gcd_strides);
200+
201+
if data_diff % gcd_strides != 0 {
202+
return false;
203+
}
204+
169205
true
170206
}
171207
}
@@ -192,10 +228,7 @@ impl BorrowFlags {
192228
D: Dimension,
193229
{
194230
let address = base_address(array);
195-
196-
let key = BorrowKey {
197-
range: data_range(array),
198-
};
231+
let key = BorrowKey::from_array(array);
199232

200233
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
201234
// and we are not calling into user code which might re-enter this function.
@@ -242,10 +275,7 @@ impl BorrowFlags {
242275
D: Dimension,
243276
{
244277
let address = base_address(array);
245-
246-
let key = BorrowKey {
247-
range: data_range(array),
248-
};
278+
let key = BorrowKey::from_array(array);
249279

250280
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
251281
// and we are not calling into user code which might re-enter this function.
@@ -272,10 +302,7 @@ impl BorrowFlags {
272302
D: Dimension,
273303
{
274304
let address = base_address(array);
275-
276-
let key = BorrowKey {
277-
range: data_range(array),
278-
};
305+
let key = BorrowKey::from_array(array);
279306

280307
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
281308
// and we are not calling into user code which might re-enter this function.
@@ -320,10 +347,7 @@ impl BorrowFlags {
320347
D: Dimension,
321348
{
322349
let address = base_address(array);
323-
324-
let key = BorrowKey {
325-
range: data_range(array),
326-
};
350+
let key = BorrowKey::from_array(array);
327351

328352
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
329353
// and we are not calling into user code which might re-enter this function.
@@ -628,6 +652,14 @@ where
628652
Range { start, end }
629653
}
630654

655+
fn abs_diff(lhs: usize, rhs: usize) -> usize {
656+
if lhs >= rhs {
657+
lhs - rhs
658+
} else {
659+
rhs - lhs
660+
}
661+
}
662+
631663
#[cfg(test)]
632664
mod tests {
633665
use super::*;
@@ -650,7 +682,7 @@ mod tests {
650682
assert_eq!(base_address, array as *const _ as usize);
651683

652684
let data_range = data_range(array);
653-
assert_eq!(data_range.start, unsafe { array.data() } as usize);
685+
assert_eq!(data_range.start, array.data() as usize);
654686
assert_eq!(data_range.end, unsafe { array.data().add(15) } as usize);
655687
});
656688
}
@@ -668,7 +700,7 @@ mod tests {
668700
assert_eq!(base_address, base as usize);
669701

670702
let data_range = data_range(array);
671-
assert_eq!(data_range.start, unsafe { array.data() } as usize);
703+
assert_eq!(data_range.start, array.data() as usize);
672704
assert_eq!(data_range.end, unsafe { array.data().add(15) } as usize);
673705
});
674706
}
@@ -694,7 +726,7 @@ mod tests {
694726
assert_eq!(base_address, base as usize);
695727

696728
let data_range = data_range(view);
697-
assert_eq!(data_range.start, unsafe { view.data() } as usize);
729+
assert_eq!(data_range.start, view.data() as usize);
698730
assert_eq!(data_range.end, unsafe { view.data().add(12) } as usize);
699731
});
700732
}
@@ -724,7 +756,7 @@ mod tests {
724756
assert_eq!(base_address, base as usize);
725757

726758
let data_range = data_range(view);
727-
assert_eq!(data_range.start, unsafe { view.data() } as usize);
759+
assert_eq!(data_range.start, view.data() as usize);
728760
assert_eq!(data_range.end, unsafe { view.data().add(12) } as usize);
729761
});
730762
}
@@ -763,7 +795,7 @@ mod tests {
763795
assert_eq!(base_address, base as usize);
764796

765797
let data_range = data_range(view2);
766-
assert_eq!(data_range.start, unsafe { view2.data() } as usize);
798+
assert_eq!(data_range.start, view2.data() as usize);
767799
assert_eq!(data_range.end, unsafe { view2.data().add(6) } as usize);
768800
});
769801
}
@@ -806,7 +838,7 @@ mod tests {
806838
assert_eq!(base_address, base as usize);
807839

808840
let data_range = data_range(view2);
809-
assert_eq!(data_range.start, unsafe { view2.data() } as usize);
841+
assert_eq!(data_range.start, view2.data() as usize);
810842
assert_eq!(data_range.end, unsafe { view2.data().add(6) } as usize);
811843
});
812844
}

tests/borrow.rs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,28 +212,39 @@ fn conflict_due_reborrow_of_overlapping_views() {
212212
}
213213

214214
#[test]
215-
#[should_panic(expected = "AlreadyBorrowed")]
216-
fn interleaved_views_conflict() {
215+
fn interleaved_views_do_not_conflict() {
217216
Python::with_gil(|py| {
218-
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);
217+
let array = PyArray::<f64, _>::zeros(py, (23, 42, 3), false);
219218
let locals = [("array", array)].into_py_dict(py);
220219

221220
let view1 = py
222-
.eval("array[:,:,1]", None, Some(locals))
221+
.eval("array[:,:,0]", None, Some(locals))
223222
.unwrap()
224223
.downcast::<PyArray2<f64>>()
225224
.unwrap();
226-
assert_eq!(view1.shape(), [1, 2]);
225+
assert_eq!(view1.shape(), [23, 42]);
227226

228227
let view2 = py
228+
.eval("array[:,:,1]", None, Some(locals))
229+
.unwrap()
230+
.downcast::<PyArray2<f64>>()
231+
.unwrap();
232+
assert_eq!(view2.shape(), [23, 42]);
233+
234+
let view3 = py
229235
.eval("array[:,:,2]", None, Some(locals))
230236
.unwrap()
231237
.downcast::<PyArray2<f64>>()
232238
.unwrap();
233-
assert_eq!(view2.shape(), [1, 2]);
239+
assert_eq!(view2.shape(), [23, 42]);
234240

235-
let _exclusive1 = view1.readwrite();
236-
let _exclusive2 = view2.readwrite();
241+
let exclusive1 = view1.readwrite();
242+
let exclusive2 = view2.readwrite();
243+
let exclusive3 = view3.readwrite();
244+
245+
assert_eq!(exclusive3.len(), 23 * 42);
246+
assert_eq!(exclusive2.len(), 23 * 42);
247+
assert_eq!(exclusive1.len(), 23 * 42);
237248
});
238249
}
239250

0 commit comments

Comments
 (0)