125
125
//!
126
126
//! # Limitations
127
127
//!
128
+ //! TODO: We only leave the case of aliasing, but only out of bounds. Can this actually happen for array views?
129
+ //!
128
130
//! Note that the current implementation of this is an over-approximation: It will consider overlapping borrows
129
131
//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base].
130
132
//! For example, creating two views of the same underlying array by slicing can yield potentially conflicting borrows
@@ -143,6 +145,7 @@ use std::collections::hash_map::{Entry, HashMap};
143
145
use std:: ops:: { Deref , Range } ;
144
146
145
147
use ndarray:: { ArrayView , ArrayViewMut , Dimension , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
148
+ use num_integer:: gcd;
146
149
use pyo3:: { FromPyObject , PyAny , PyResult } ;
147
150
148
151
use crate :: array:: PyArray ;
@@ -155,9 +158,28 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
155
158
#[ derive( PartialEq , Eq , Hash ) ]
156
159
struct BorrowKey {
157
160
range : Range < usize > ,
161
+ data_ptr : usize ,
162
+ gcd_strides : isize ,
158
163
}
159
164
160
165
impl BorrowKey {
166
+ fn from_array < T , D > ( array : & PyArray < T , D > ) -> Self
167
+ where
168
+ T : Element ,
169
+ D : Dimension ,
170
+ {
171
+ let range = data_range ( array) ;
172
+
173
+ let data_ptr = array. data ( ) as usize ;
174
+ let gcd_strides = array. strides ( ) . iter ( ) . copied ( ) . reduce ( gcd) . unwrap_or ( 1 ) ;
175
+
176
+ Self {
177
+ range,
178
+ data_ptr,
179
+ gcd_strides,
180
+ }
181
+ }
182
+
161
183
fn conflicts ( & self , other : & Self ) -> bool {
162
184
debug_assert ! ( self . range. start <= self . range. end) ;
163
185
debug_assert ! ( other. range. start <= other. range. end) ;
@@ -166,6 +188,20 @@ impl BorrowKey {
166
188
return false ;
167
189
}
168
190
191
+ // The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays
192
+ // so that they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers.
193
+ //
194
+ // That solution could be out of bounds which mean that this still an approximation,
195
+ // but it seems sufficient to handle typical cases like the color channels of an image.
196
+ //
197
+ // https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
198
+ let data_diff = abs_diff ( self . data_ptr , other. data_ptr ) as isize ;
199
+ let gcd_strides = gcd ( self . gcd_strides , other. gcd_strides ) ;
200
+
201
+ if data_diff % gcd_strides != 0 {
202
+ return false ;
203
+ }
204
+
169
205
true
170
206
}
171
207
}
@@ -192,10 +228,7 @@ impl BorrowFlags {
192
228
D : Dimension ,
193
229
{
194
230
let address = base_address ( array) ;
195
-
196
- let key = BorrowKey {
197
- range : data_range ( array) ,
198
- } ;
231
+ let key = BorrowKey :: from_array ( array) ;
199
232
200
233
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
201
234
// and we are not calling into user code which might re-enter this function.
@@ -242,10 +275,7 @@ impl BorrowFlags {
242
275
D : Dimension ,
243
276
{
244
277
let address = base_address ( array) ;
245
-
246
- let key = BorrowKey {
247
- range : data_range ( array) ,
248
- } ;
278
+ let key = BorrowKey :: from_array ( array) ;
249
279
250
280
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
251
281
// and we are not calling into user code which might re-enter this function.
@@ -272,10 +302,7 @@ impl BorrowFlags {
272
302
D : Dimension ,
273
303
{
274
304
let address = base_address ( array) ;
275
-
276
- let key = BorrowKey {
277
- range : data_range ( array) ,
278
- } ;
305
+ let key = BorrowKey :: from_array ( array) ;
279
306
280
307
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
281
308
// and we are not calling into user code which might re-enter this function.
@@ -320,10 +347,7 @@ impl BorrowFlags {
320
347
D : Dimension ,
321
348
{
322
349
let address = base_address ( array) ;
323
-
324
- let key = BorrowKey {
325
- range : data_range ( array) ,
326
- } ;
350
+ let key = BorrowKey :: from_array ( array) ;
327
351
328
352
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
329
353
// and we are not calling into user code which might re-enter this function.
@@ -628,6 +652,14 @@ where
628
652
Range { start, end }
629
653
}
630
654
655
+ fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
656
+ if lhs >= rhs {
657
+ lhs - rhs
658
+ } else {
659
+ rhs - lhs
660
+ }
661
+ }
662
+
631
663
#[ cfg( test) ]
632
664
mod tests {
633
665
use super :: * ;
@@ -650,7 +682,7 @@ mod tests {
650
682
assert_eq ! ( base_address, array as * const _ as usize ) ;
651
683
652
684
let data_range = data_range ( array) ;
653
- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
685
+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
654
686
assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
655
687
} ) ;
656
688
}
@@ -668,7 +700,7 @@ mod tests {
668
700
assert_eq ! ( base_address, base as usize ) ;
669
701
670
702
let data_range = data_range ( array) ;
671
- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
703
+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
672
704
assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
673
705
} ) ;
674
706
}
@@ -694,7 +726,7 @@ mod tests {
694
726
assert_eq ! ( base_address, base as usize ) ;
695
727
696
728
let data_range = data_range ( view) ;
697
- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
729
+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
698
730
assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
699
731
} ) ;
700
732
}
@@ -724,7 +756,7 @@ mod tests {
724
756
assert_eq ! ( base_address, base as usize ) ;
725
757
726
758
let data_range = data_range ( view) ;
727
- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
759
+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
728
760
assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
729
761
} ) ;
730
762
}
@@ -763,7 +795,7 @@ mod tests {
763
795
assert_eq ! ( base_address, base as usize ) ;
764
796
765
797
let data_range = data_range ( view2) ;
766
- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
798
+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
767
799
assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
768
800
} ) ;
769
801
}
@@ -806,7 +838,7 @@ mod tests {
806
838
assert_eq ! ( base_address, base as usize ) ;
807
839
808
840
let data_range = data_range ( view2) ;
809
- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
841
+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
810
842
assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
811
843
} ) ;
812
844
}
0 commit comments