@@ -12,6 +12,7 @@ use numpy::PyArrayMethods;
1212use numpy:: PyUntypedArrayMethods ;
1313use numpy:: ToPyArray ;
1414use numpy:: { PyArray2 , PyReadonlyArray2 } ;
15+ use numpy:: IntoPyArray ;
1516
1617#[ pyfunction]
1718fn first_true_1d_a ( array : PyReadonlyArray1 < bool > ) -> isize {
@@ -346,49 +347,127 @@ fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> is
346347// axis == 0: transpose, copy to C
347348// axis == 1: copy to C
348349
349- fn prepare_array_for_axis < ' py > (
350+ // fn prepare_array_for_axis<'py>(
351+ // py: Python<'py>,
352+ // array: PyReadonlyArray2<'py, bool>,
353+ // axis: isize,
354+ // ) -> PyResult<Bound<'py, PyArray2<bool>>> {
355+ // if axis != 0 && axis != 1 {
356+ // return Err(PyValueError::new_err("axis must be 0 or 1"));
357+ // }
358+
359+ // let is_c = array.is_c_contiguous();
360+ // let is_f = array.is_fortran_contiguous();
361+ // let array_view = array.as_array();
362+
363+ // match (is_c, is_f, axis) {
364+ // (true, _, 1) => {
365+ // // Already C-contiguous, no copy needed
366+ // Ok(array_view.to_pyarray(py).to_owned())
367+ // }
368+ // (_, true, 0) => {
369+ // // F-contiguous original -> transposed will be C-contiguous, no copy needed
370+ // Ok(array_view.reversed_axes().to_pyarray(py).to_owned())
371+ // }
372+ // (_, true, 1) => {
373+ // // F-contiguous, need to copy to C-contiguous
374+ // let contiguous = array_view.as_standard_layout();
375+ // Ok(contiguous.to_pyarray(py).to_owned())
376+ // }
377+ // (_, _, 1) => {
378+ // // Neither C nor F contiguous, need to copy
379+ // let contiguous = array_view.as_standard_layout();
380+ // Ok(contiguous.to_pyarray(py).to_owned())
381+ // }
382+
383+ // (true, _, 0) | (_, _, 0) => {
384+ // // C-contiguous or neither -> transposed won't be C-contiguous, need copy
385+ // let transposed = array_view.reversed_axes();
386+ // let contiguous = transposed.as_standard_layout();
387+ // Ok(contiguous.to_pyarray(py).to_owned())
388+ // }
389+ // _ => unreachable!(),
390+ // }
391+ // }
392+
393+
394+ // use numpy::{PyReadonlyArray2, IntoPyArray, PyArray2};
395+ // use pyo3::prelude::*;
396+
397+ pub struct PreparedBool2D < ' py > {
398+ pub data : & ' py [ u8 ] , // flat contiguous buffer
399+ pub nrows : usize , // number of logical rows
400+ pub ncols : usize ,
401+ _keepalive : Option < Bound < ' py , PyAny > > , // holds any copied/transposed buffer
402+ }
403+
404+ pub fn prepare_array_for_axis < ' py > (
350405 py : Python < ' py > ,
351406 array : PyReadonlyArray2 < ' py , bool > ,
352407 axis : isize ,
353- ) -> PyResult < Bound < ' py , PyArray2 < bool > > > {
408+ ) -> PyResult < PreparedBool2D < ' py > > {
354409 if axis != 0 && axis != 1 {
355410 return Err ( PyValueError :: new_err ( "axis must be 0 or 1" ) ) ;
356411 }
357412
413+ let shape = array. shape ( ) ;
414+ let ( nrows, ncols) = if axis == 0 {
415+ ( shape[ 1 ] , shape[ 0 ] ) // transposed
416+ } else {
417+ ( shape[ 0 ] , shape[ 1 ] ) // as-is
418+ } ;
419+
358420 let is_c = array. is_c_contiguous ( ) ;
359421 let is_f = array. is_fortran_contiguous ( ) ;
360422 let array_view = array. as_array ( ) ;
361423
362- match ( is_c, is_f, axis) {
363- ( true , _, 1 ) => {
364- // Already C-contiguous, no copy needed
365- Ok ( array_view. to_pyarray ( py) . to_owned ( ) )
366- }
367- ( _, true , 0 ) => {
368- // F-contiguous original -> transposed will be C-contiguous, no copy needed
369- Ok ( array_view. reversed_axes ( ) . to_pyarray ( py) . to_owned ( ) )
370- }
371- ( _, true , 1 ) => {
372- // F-contiguous, need to copy to C-contiguous
373- let contiguous = array_view. as_standard_layout ( ) ;
374- Ok ( contiguous. to_pyarray ( py) . to_owned ( ) )
375- }
376- ( _, _, 1 ) => {
377- // Neither C nor F contiguous, need to copy
378- let contiguous = array_view. as_standard_layout ( ) ;
379- Ok ( contiguous. to_pyarray ( py) . to_owned ( ) )
424+ // Case 1: C-contiguous + axis=1 → zero-copy slice
425+ if is_c && axis == 1 {
426+ if let Ok ( slice) = array. as_slice ( ) {
427+ return Ok ( PreparedBool2D {
428+ data : unsafe { std:: mem:: transmute ( slice) } , // &[bool] → &[u8]
429+ nrows,
430+ ncols,
431+ _keepalive : None ,
432+ } ) ;
380433 }
434+ }
381435
382- ( true , _, 0 ) | ( _, _, 0 ) => {
383- // C-contiguous or neither -> transposed won't be C-contiguous, need copy
384- let transposed = array_view. reversed_axes ( ) ;
385- let contiguous = transposed. as_standard_layout ( ) ;
386- Ok ( contiguous. to_pyarray ( py) . to_owned ( ) )
436+ // Case 2: F-contiguous + axis=0 → transpose, check if sliceable
437+ if is_f && axis == 0 {
438+ let transposed = array_view. reversed_axes ( ) ;
439+ if let Some ( slice) = transposed. as_standard_layout ( ) . as_slice_memory_order ( ) {
440+ return Ok ( PreparedBool2D {
441+ data : unsafe { std:: mem:: transmute ( slice) } ,
442+ nrows,
443+ ncols,
444+ _keepalive : None ,
445+ } ) ;
387446 }
388- _ => unreachable ! ( ) ,
389447 }
448+
449+ // Case 3: fallback — create a new C-contiguous owned array
450+ let prepared_array: Bound < ' py , PyArray2 < bool > > = if axis == 0 {
451+ array_view. reversed_axes ( ) . as_standard_layout ( ) . to_owned ( ) . to_pyarray ( py)
452+ } else {
453+ array_view. as_standard_layout ( ) . to_owned ( ) . to_pyarray ( py)
454+ } ;
455+
456+ let array_view = unsafe { prepared_array. as_array ( ) } ;
457+ let prepared_slice = array_view
458+ . as_slice_memory_order ( )
459+ . expect ( "Newly allocated array must be contiguous" ) ;
460+
461+ Ok ( PreparedBool2D {
462+ data : unsafe { std:: mem:: transmute ( prepared_slice) } ,
463+ nrows,
464+ ncols,
465+ _keepalive : Some ( prepared_array. into_any ( ) ) ,
466+ } )
390467}
391468
469+
470+
392471#[ pyfunction]
393472#[ pyo3( signature = ( array, * , forward=true , axis) ) ]
394473pub fn first_true_2d < ' py > (
@@ -397,47 +476,58 @@ pub fn first_true_2d<'py>(
397476 forward : bool ,
398477 axis : isize ,
399478) -> PyResult < Bound < ' py , PyArray1 < isize > > > {
400- let prepped = prepare_array_for_axis ( py, array, axis) ?;
401- let view = unsafe { prepped. as_array ( ) } ;
402479
403- // let view = array.as_array();
404- // NOTE: these are rows in the view, not always the same as rows
405- let rows = view. nrows ( ) ;
406- let mut result = Vec :: with_capacity ( rows) ;
480+ // let prepped = prepare_array_for_axis(py, array, axis)?;
481+ // let view = unsafe { prepped.as_array() };
482+ // // NOTE: these are rows in the view, not always the same as rows
483+ // let rows = view.nrows();
484+
485+ let prepared = prepare_array_for_axis ( py, array, axis) ?;
486+ let data = prepared. data ;
487+ let rows = prepared. nrows ;
488+ let row_len = prepared. ncols ;
489+
490+ let mut result = vec ! [ -1isize ; rows] ;
407491
408492 py. allow_threads ( || {
409493 const LANES : usize = 32 ;
410494 let ones = u8x32:: splat ( 1 ) ;
411495
496+ let base_ptr = data. as_ptr ( ) ;
497+
412498 for row in 0 ..rows {
413- let mut found = -1 ;
414- let row_slice = & view. row ( row) ;
415- let ptr = row_slice. as_ptr ( ) as * const u8 ;
416- let len = row_slice. len ( ) ;
499+
500+ let ptr = unsafe { base_ptr. add ( row * row_len) } ;
501+
502+ // let mut found = -1;
503+ // let row_slice = &view.row(row);
504+ // let ptr = row_slice.as_ptr() as *const u8;
505+ // let len = row_slice.len();
417506
418507 if forward {
419508 // Forward search
420509 let mut i = 0 ;
421510 unsafe {
422- while i + LANES <= len {
511+ while i + LANES <= row_len {
423512 let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
424513 let vec = u8x32:: from ( * chunk) ;
425514 if vec. cmp_eq ( ones) . any ( ) {
426515 break ;
427516 }
428517 i += LANES ;
429518 }
430- while i < len {
519+ while i < row_len {
431520 if * ptr. add ( i) != 0 {
432- found = i as isize ;
521+ // found = i as isize;
522+ result[ row] = i as isize ;
433523 break ;
434524 }
435525 i += 1 ;
436526 }
437527 }
438528 } else {
439529 // Backward search
440- let mut i = len ;
530+ let mut i = row_len ;
441531 unsafe {
442532 // Process LANES bytes at a time with SIMD (backwards)
443533 while i >= LANES {
@@ -448,25 +538,26 @@ pub fn first_true_2d<'py>(
448538 // Found a true in this chunk, search backwards within it
449539 for j in ( i..i + LANES ) . rev ( ) {
450540 if * ptr. add ( j) != 0 {
451- found = j as isize ;
541+ // found = j as isize;
542+ result[ row] = j as isize ;
452543 break ;
453544 }
454545 }
455546 break ;
456547 }
457548 }
458549 // Handle remaining bytes at the beginning
459- if found == - 1 && i > 0 {
550+ if i > 0 && i < LANES {
460551 for j in ( 0 ..i) . rev ( ) {
461552 if * ptr. add ( j) != 0 {
462- found = j as isize ;
553+ // found = j as isize;
554+ result[ row] = j as isize ;
463555 break ;
464556 }
465557 }
466558 }
467559 }
468560 }
469- result. push ( found) ;
470561 }
471562 } ) ;
472563
0 commit comments