@@ -49,6 +49,9 @@ PYBIND11_WARNING_DISABLE_MSVC(4127)
49
49
class dtype; // Forward declaration
50
50
class array ; // Forward declaration
51
51
52
+ template <typename >
53
+ struct numpy_scalar ; // Forward declaration
54
+
52
55
PYBIND11_NAMESPACE_BEGIN (detail)
53
56
54
57
template <>
@@ -245,6 +248,21 @@ struct npy_api {
245
248
NPY_UINT64_
246
249
= platform_lookup<std::uint64_t , unsigned long , unsigned long long , unsigned int >(
247
250
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
251
+ NPY_FLOAT32_ = platform_lookup<float , double , float , long double >(
252
+ NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
253
+ NPY_FLOAT64_ = platform_lookup<double , double , float , long double >(
254
+ NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
255
+ NPY_COMPLEX64_
256
+ = platform_lookup<std::complex<float >,
257
+ std::complex<double >,
258
+ std::complex<float >,
259
+ std::complex<long double >>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
260
+ NPY_COMPLEX128_
261
+ = platform_lookup<std::complex<double >,
262
+ std::complex<double >,
263
+ std::complex<float >,
264
+ std::complex<long double >>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
265
+ NPY_CHAR_ = std::is_signed<char >::value ? NPY_BYTE_ : NPY_UBYTE_,
248
266
};
249
267
250
268
unsigned int PyArray_RUNTIME_VERSION_;
@@ -268,6 +286,7 @@ struct npy_api {
268
286
269
287
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
270
288
PyObject *(*PyArray_DescrFromType_)(int );
289
+ PyObject *(*PyArray_TypeObjectFromType_)(int );
271
290
PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
272
291
PyObject *,
273
292
int ,
@@ -284,6 +303,8 @@ struct npy_api {
284
303
PyTypeObject *PyVoidArrType_Type_;
285
304
PyTypeObject *PyArrayDescr_Type_;
286
305
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
306
+ PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
307
+ void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
287
308
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int , int , int , PyObject *);
288
309
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
289
310
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
@@ -301,7 +322,10 @@ struct npy_api {
301
322
API_PyArrayDescr_Type = 3 ,
302
323
API_PyVoidArrType_Type = 39 ,
303
324
API_PyArray_DescrFromType = 45 ,
325
+ API_PyArray_TypeObjectFromType = 46 ,
304
326
API_PyArray_DescrFromScalar = 57 ,
327
+ API_PyArray_Scalar = 60 ,
328
+ API_PyArray_ScalarAsCtype = 62 ,
305
329
API_PyArray_FromAny = 69 ,
306
330
API_PyArray_Resize = 80 ,
307
331
// CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
@@ -336,7 +360,10 @@ struct npy_api {
336
360
DECL_NPY_API (PyVoidArrType_Type);
337
361
DECL_NPY_API (PyArrayDescr_Type);
338
362
DECL_NPY_API (PyArray_DescrFromType);
363
+ DECL_NPY_API (PyArray_TypeObjectFromType);
339
364
DECL_NPY_API (PyArray_DescrFromScalar);
365
+ DECL_NPY_API (PyArray_Scalar);
366
+ DECL_NPY_API (PyArray_ScalarAsCtype);
340
367
DECL_NPY_API (PyArray_FromAny);
341
368
DECL_NPY_API (PyArray_Resize);
342
369
DECL_NPY_API (PyArray_CopyInto);
@@ -355,6 +382,83 @@ struct npy_api {
355
382
}
356
383
};
357
384
385
+ template <typename T>
386
+ struct is_complex : std::false_type {};
387
+ template <typename T>
388
+ struct is_complex <std::complex<T>> : std::true_type {};
389
+
390
+ template <typename T, typename = void >
391
+ struct npy_format_descriptor_name ;
392
+
393
+ template <typename T>
394
+ struct npy_format_descriptor_name <T, enable_if_t <std::is_integral<T>::value>> {
395
+ static constexpr auto name = const_name<std::is_same<T, bool >::value>(
396
+ const_name (" numpy.bool" ),
397
+ const_name<std::is_signed<T>::value>(" numpy.int" , " numpy.uint" )
398
+ + const_name<sizeof (T) * 8 >());
399
+ };
400
+
401
+ template <typename T>
402
+ struct npy_format_descriptor_name <T, enable_if_t <std::is_floating_point<T>::value>> {
403
+ static constexpr auto name = const_name < std::is_same<T, float >::value
404
+ || std::is_same<T, const float >::value
405
+ || std::is_same<T, double >::value
406
+ || std::is_same<T, const double >::value
407
+ > (const_name(" numpy.float" ) + const_name<sizeof (T) * 8 >(),
408
+ const_name (" numpy.longdouble" ));
409
+ };
410
+
411
+ template <typename T>
412
+ struct npy_format_descriptor_name <T, enable_if_t <is_complex<T>::value>> {
413
+ static constexpr auto name = const_name < std::is_same<typename T::value_type, float >::value
414
+ || std::is_same<typename T::value_type, const float >::value
415
+ || std::is_same<typename T::value_type, double >::value
416
+ || std::is_same<typename T::value_type, const double >::value
417
+ > (const_name(" numpy.complex" )
418
+ + const_name<sizeof (typename T::value_type) * 16 >(),
419
+ const_name (" numpy.longcomplex" ));
420
+ };
421
+
422
+ template <typename T>
423
+ struct numpy_scalar_info {};
424
+
425
+ #define PYBIND11_NUMPY_SCALAR_IMPL (ctype_, typenum_ ) \
426
+ template <> \
427
+ struct numpy_scalar_info <ctype_> { \
428
+ static constexpr auto name = npy_format_descriptor_name<ctype_>::name; \
429
+ static constexpr int typenum = npy_api::typenum_##_; \
430
+ }
431
+
432
+ // boolean type
433
+ PYBIND11_NUMPY_SCALAR_IMPL (bool , NPY_BOOL);
434
+
435
+ // character types
436
+ PYBIND11_NUMPY_SCALAR_IMPL (char , NPY_CHAR);
437
+ PYBIND11_NUMPY_SCALAR_IMPL (signed char , NPY_BYTE);
438
+ PYBIND11_NUMPY_SCALAR_IMPL (unsigned char , NPY_UBYTE);
439
+
440
+ // signed integer types
441
+ PYBIND11_NUMPY_SCALAR_IMPL (std::int16_t , NPY_INT16);
442
+ PYBIND11_NUMPY_SCALAR_IMPL (std::int32_t , NPY_INT32);
443
+ PYBIND11_NUMPY_SCALAR_IMPL (std::int64_t , NPY_INT64);
444
+
445
+ // unsigned integer types
446
+ PYBIND11_NUMPY_SCALAR_IMPL (std::uint16_t , NPY_UINT16);
447
+ PYBIND11_NUMPY_SCALAR_IMPL (std::uint32_t , NPY_UINT32);
448
+ PYBIND11_NUMPY_SCALAR_IMPL (std::uint64_t , NPY_UINT64);
449
+
450
+ // floating point types
451
+ PYBIND11_NUMPY_SCALAR_IMPL (float , NPY_FLOAT);
452
+ PYBIND11_NUMPY_SCALAR_IMPL (double , NPY_DOUBLE);
453
+ PYBIND11_NUMPY_SCALAR_IMPL (long double , NPY_LONGDOUBLE);
454
+
455
+ // complex types
456
+ PYBIND11_NUMPY_SCALAR_IMPL (std::complex<float >, NPY_CFLOAT);
457
+ PYBIND11_NUMPY_SCALAR_IMPL (std::complex<double >, NPY_CDOUBLE);
458
+ PYBIND11_NUMPY_SCALAR_IMPL (std::complex<long double >, NPY_CLONGDOUBLE);
459
+
460
+ #undef PYBIND11_NUMPY_SCALAR_IMPL
461
+
358
462
// This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ...
359
463
// This is needed to correctly handle situations where multiple typenums map to the same type,
360
464
// e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
@@ -453,10 +557,6 @@ template <typename T>
453
557
struct is_std_array : std::false_type {};
454
558
template <typename T, size_t N>
455
559
struct is_std_array <std::array<T, N>> : std::true_type {};
456
- template <typename T>
457
- struct is_complex : std::false_type {};
458
- template <typename T>
459
- struct is_complex <std::complex<T>> : std::true_type {};
460
560
461
561
template <typename T>
462
562
struct array_info_scalar {
@@ -670,8 +770,65 @@ template <typename T, ssize_t Dim>
670
770
struct type_caster <unchecked_mutable_reference<T, Dim>>
671
771
: type_caster<unchecked_reference<T, Dim>> {};
672
772
773
+ template <typename T>
774
+ struct type_caster <numpy_scalar<T>> {
775
+ using value_type = T;
776
+ using type_info = numpy_scalar_info<T>;
777
+
778
+ PYBIND11_TYPE_CASTER (numpy_scalar<T>, type_info::name);
779
+
780
+ static handle &target_type () {
781
+ static handle tp = npy_api::get ().PyArray_TypeObjectFromType_ (type_info::typenum);
782
+ return tp;
783
+ }
784
+
785
+ static handle &target_dtype () {
786
+ static handle tp = npy_api::get ().PyArray_DescrFromType_ (type_info::typenum);
787
+ return tp;
788
+ }
789
+
790
+ bool load (handle src, bool ) {
791
+ if (isinstance (src, target_type ())) {
792
+ npy_api::get ().PyArray_ScalarAsCtype_ (src.ptr (), &value.value );
793
+ return true ;
794
+ }
795
+ return false ;
796
+ }
797
+
798
+ static handle cast (numpy_scalar<T> src, return_value_policy, handle) {
799
+ return npy_api::get ().PyArray_Scalar_ (&src.value , target_dtype ().ptr (), nullptr );
800
+ }
801
+ };
802
+
673
803
PYBIND11_NAMESPACE_END (detail)
674
804
805
+ template <typename T>
806
+ struct numpy_scalar {
807
+ using value_type = T;
808
+
809
+ value_type value;
810
+
811
+ numpy_scalar () = default ;
812
+ explicit numpy_scalar (value_type value) : value (value) {}
813
+
814
+ explicit operator value_type () const { return value; }
815
+ numpy_scalar &operator =(value_type value) {
816
+ this ->value = value;
817
+ return *this ;
818
+ }
819
+
820
+ friend bool operator ==(const numpy_scalar &a, const numpy_scalar &b) {
821
+ return a.value == b.value ;
822
+ }
823
+
824
+ friend bool operator !=(const numpy_scalar &a, const numpy_scalar &b) { return !(a == b); }
825
+ };
826
+
827
+ template <typename T>
828
+ numpy_scalar<T> make_scalar (T value) {
829
+ return numpy_scalar<T>(value);
830
+ }
831
+
675
832
class dtype : public object {
676
833
public:
677
834
PYBIND11_OBJECT_DEFAULT (dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
@@ -1409,38 +1566,6 @@ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::valu
1409
1566
}
1410
1567
};
1411
1568
1412
- template <typename T, typename = void >
1413
- struct npy_format_descriptor_name ;
1414
-
1415
- template <typename T>
1416
- struct npy_format_descriptor_name <T, enable_if_t <std::is_integral<T>::value>> {
1417
- static constexpr auto name = const_name<std::is_same<T, bool >::value>(
1418
- const_name (" bool" ),
1419
- const_name<std::is_signed<T>::value>(" numpy.int" , " numpy.uint" )
1420
- + const_name<sizeof (T) * 8 >());
1421
- };
1422
-
1423
- template <typename T>
1424
- struct npy_format_descriptor_name <T, enable_if_t <std::is_floating_point<T>::value>> {
1425
- static constexpr auto name = const_name < std::is_same<T, float >::value
1426
- || std::is_same<T, const float >::value
1427
- || std::is_same<T, double >::value
1428
- || std::is_same<T, const double >::value
1429
- > (const_name(" numpy.float" ) + const_name<sizeof (T) * 8 >(),
1430
- const_name (" numpy.longdouble" ));
1431
- };
1432
-
1433
- template <typename T>
1434
- struct npy_format_descriptor_name <T, enable_if_t <is_complex<T>::value>> {
1435
- static constexpr auto name = const_name < std::is_same<typename T::value_type, float >::value
1436
- || std::is_same<typename T::value_type, const float >::value
1437
- || std::is_same<typename T::value_type, double >::value
1438
- || std::is_same<typename T::value_type, const double >::value
1439
- > (const_name(" numpy.complex" )
1440
- + const_name<sizeof (typename T::value_type) * 16 >(),
1441
- const_name (" numpy.longcomplex" ));
1442
- };
1443
-
1444
1569
template <typename T>
1445
1570
struct npy_format_descriptor <
1446
1571
T,
0 commit comments