@@ -441,7 +441,12 @@ unsafe impl Element for PyObject {
441
441
442
442
#[ cfg( test) ]
443
443
mod tests {
444
- use super :: { dtype, Complex32 , Complex64 , Element } ;
444
+ use std:: os:: raw:: c_int;
445
+
446
+ use pyo3:: { py_run, types:: PyDict , PyObject } ;
447
+
448
+ use super :: { dtype, Complex32 , Complex64 , Element , PyArrayDescr } ;
449
+ use crate :: npyffi:: NPY_TYPES ;
445
450
446
451
#[ test]
447
452
fn test_dtype_names ( ) {
@@ -474,4 +479,105 @@ mod tests {
474
479
}
475
480
} ) ;
476
481
}
482
+
483
+ #[ test]
484
+ fn test_dtype_methods_scalar ( ) {
485
+ pyo3:: Python :: with_gil ( |py| {
486
+ let dt = dtype :: < f64 > ( py) ;
487
+
488
+ assert_eq ! ( dt. num( ) , NPY_TYPES :: NPY_DOUBLE as c_int) ;
489
+ assert_eq ! ( dt. typeobj( ) . name( ) . unwrap( ) , "float64" ) ;
490
+ assert_eq ! ( dt. char ( ) , b'd' ) ;
491
+ assert_eq ! ( dt. kind( ) , b'f' ) ;
492
+ assert_eq ! ( dt. byteorder( ) , b'=' ) ;
493
+ assert_eq ! ( dt. is_native_byteorder( ) , Some ( true ) ) ;
494
+ assert_eq ! ( dt. itemsize( ) , 8 ) ;
495
+ assert_eq ! ( dt. alignment( ) , 8 ) ;
496
+ assert_eq ! ( dt. has_object( ) , false ) ;
497
+ assert_eq ! ( dt. names( ) , None ) ;
498
+ assert_eq ! ( dt. has_fields( ) , false ) ;
499
+ assert_eq ! ( dt. is_aligned_struct( ) , false ) ;
500
+ assert_eq ! ( dt. has_subarray( ) , false ) ;
501
+ assert ! ( dt. base( ) . is_equiv_to( & dt) ) ;
502
+ assert_eq ! ( dt. ndim( ) , 0 ) ;
503
+ assert_eq ! ( dt. shape( ) , None ) ;
504
+ } ) ;
505
+ }
506
+
507
+ #[ test]
508
+ fn test_dtype_methods_subarray ( ) {
509
+ pyo3:: Python :: with_gil ( |py| {
510
+ let locals = PyDict :: new ( py) ;
511
+ py_run ! (
512
+ py,
513
+ * locals,
514
+ "dtype = __import__('numpy').dtype(('f8', (2, 3)))"
515
+ ) ;
516
+ let dt = locals
517
+ . get_item ( "dtype" )
518
+ . unwrap ( )
519
+ . downcast :: < PyArrayDescr > ( )
520
+ . unwrap ( ) ;
521
+
522
+ assert_eq ! ( dt. num( ) , NPY_TYPES :: NPY_VOID as c_int) ;
523
+ assert_eq ! ( dt. typeobj( ) . name( ) . unwrap( ) , "void" ) ;
524
+ assert_eq ! ( dt. char ( ) , b'V' ) ;
525
+ assert_eq ! ( dt. kind( ) , b'V' ) ;
526
+ assert_eq ! ( dt. byteorder( ) , b'|' ) ;
527
+ assert_eq ! ( dt. is_native_byteorder( ) , None ) ;
528
+ assert_eq ! ( dt. itemsize( ) , 48 ) ;
529
+ assert_eq ! ( dt. alignment( ) , 8 ) ;
530
+ assert_eq ! ( dt. has_object( ) , false ) ;
531
+ assert_eq ! ( dt. names( ) , None ) ;
532
+ assert_eq ! ( dt. has_fields( ) , false ) ;
533
+ assert_eq ! ( dt. is_aligned_struct( ) , false ) ;
534
+ assert_eq ! ( dt. has_subarray( ) , true ) ;
535
+ assert_eq ! ( dt. ndim( ) , 2 ) ;
536
+ assert_eq ! ( dt. shape( ) . unwrap( ) , vec![ 2 , 3 ] ) ;
537
+ assert ! ( dt. base( ) . is_equiv_to( dtype:: <f64 >( py) ) ) ;
538
+ } ) ;
539
+ }
540
+
541
+ #[ test]
542
+ fn test_dtype_methods_record ( ) {
543
+ pyo3:: Python :: with_gil ( |py| {
544
+ let locals = PyDict :: new ( py) ;
545
+ py_run ! (
546
+ py,
547
+ * locals,
548
+ "dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
549
+ ) ;
550
+ let dt = locals
551
+ . get_item ( "dtype" )
552
+ . unwrap ( )
553
+ . downcast :: < PyArrayDescr > ( )
554
+ . unwrap ( ) ;
555
+
556
+ assert_eq ! ( dt. num( ) , NPY_TYPES :: NPY_VOID as c_int) ;
557
+ assert_eq ! ( dt. typeobj( ) . name( ) . unwrap( ) , "void" ) ;
558
+ assert_eq ! ( dt. char ( ) , b'V' ) ;
559
+ assert_eq ! ( dt. kind( ) , b'V' ) ;
560
+ assert_eq ! ( dt. byteorder( ) , b'|' ) ;
561
+ assert_eq ! ( dt. is_native_byteorder( ) , None ) ;
562
+ assert_eq ! ( dt. itemsize( ) , 24 ) ;
563
+ assert_eq ! ( dt. alignment( ) , 8 ) ;
564
+ assert_eq ! ( dt. has_object( ) , true ) ;
565
+ assert_eq ! ( dt. names( ) , Some ( vec![ "x" , "y" , "z" ] ) ) ;
566
+ assert_eq ! ( dt. has_fields( ) , true ) ;
567
+ assert_eq ! ( dt. is_aligned_struct( ) , true ) ;
568
+ assert_eq ! ( dt. has_subarray( ) , false ) ;
569
+ assert_eq ! ( dt. ndim( ) , 0 ) ;
570
+ assert_eq ! ( dt. shape( ) , None ) ;
571
+ assert ! ( dt. base( ) . is_equiv_to( & dt) ) ;
572
+ let x = dt. get_field ( "x" ) . unwrap ( ) ;
573
+ assert ! ( x. 0 . is_equiv_to( dtype:: <u8 >( py) ) ) ;
574
+ assert_eq ! ( x. 1 , 0 ) ;
575
+ let y = dt. get_field ( "y" ) . unwrap ( ) ;
576
+ assert ! ( y. 0 . is_equiv_to( dtype:: <f64 >( py) ) ) ;
577
+ assert_eq ! ( y. 1 , 8 ) ;
578
+ let z = dt. get_field ( "z" ) . unwrap ( ) ;
579
+ assert ! ( z. 0 . is_equiv_to( dtype:: <PyObject >( py) ) ) ;
580
+ assert_eq ! ( z. 1 , 16 ) ;
581
+ } ) ;
582
+ }
477
583
}
0 commit comments