Skip to content

Commit cf3d1a7

Browse files
authored
feat: numpy scalars (#5726)
1 parent c60c149 commit cf3d1a7

File tree

5 files changed

+319
-36
lines changed

5 files changed

+319
-36
lines changed

docs/advanced/pycpp/numpy.rst

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,46 @@ prevent many types of unsupported structures, it is still the user's
232232
responsibility to use only "plain" structures that can be safely manipulated as
233233
raw memory without violating invariants.
234234

235+
Scalar types
236+
============
237+
238+
In some cases we may want to accept or return NumPy scalar values such as
239+
``np.float32`` or ``np.float64``. We hope to be able to handle single-precision
240+
and double-precision on the C-side. However, both are bound to Python's
241+
double-precision builtin float by default, so they cannot be processed separately.
242+
We used the ``py::buffer`` trick to implement the previous approach, which
243+
will cause the readability of the code to drop significantly.
244+
245+
Luckily, there's a helper type for this occasion - ``py::numpy_scalar``:
246+
247+
.. code-block:: cpp
248+
249+
m.def("add", [](py::numpy_scalar<float> a, py::numpy_scalar<float> b) {
250+
return py::make_scalar(a + b);
251+
});
252+
m.def("add", [](py::numpy_scalar<double> a, py::numpy_scalar<double> b) {
253+
return py::make_scalar(a + b);
254+
});
255+
256+
This type is trivially convertible to and from the type it wraps; currently
257+
supported scalar types are NumPy arithmetic types: ``bool_``, ``int8``,
258+
``int16``, ``int32``, ``int64``, ``uint8``, ``uint16``, ``uint32``,
259+
``uint64``, ``float32``, ``float64``, ``complex64``, ``complex128``, all of
260+
them mapping to respective C++ counterparts.
261+
262+
.. note::
263+
264+
``py::numpy_scalar<T>`` strictly matches NumPy scalar types. For example,
265+
``py::numpy_scalar<int64_t>`` will accept ``np.int64(123)``,
266+
but **not** a regular Python ``int`` like ``123``.
267+
268+
.. note::
269+
270+
Native C types are mapped to NumPy types in a platform specific way: for
271+
instance, ``char`` may be mapped to either ``np.int8`` or ``np.uint8``
272+
and ``long`` may use 4 or 8 bytes depending on the platform. Unless you
273+
clearly understand the difference and your needs, please use ``<cstdint>``.
274+
235275
Vectorizing functions
236276
=====================
237277

include/pybind11/numpy.h

Lines changed: 161 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ PYBIND11_WARNING_DISABLE_MSVC(4127)
4949
class dtype; // Forward declaration
5050
class array; // Forward declaration
5151

52+
template <typename>
53+
struct numpy_scalar; // Forward declaration
54+
5255
PYBIND11_NAMESPACE_BEGIN(detail)
5356

5457
template <>
@@ -245,6 +248,21 @@ struct npy_api {
245248
NPY_UINT64_
246249
= platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
247250
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
251+
NPY_FLOAT32_ = platform_lookup<float, double, float, long double>(
252+
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
253+
NPY_FLOAT64_ = platform_lookup<double, double, float, long double>(
254+
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
255+
NPY_COMPLEX64_
256+
= platform_lookup<std::complex<float>,
257+
std::complex<double>,
258+
std::complex<float>,
259+
std::complex<long double>>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
260+
NPY_COMPLEX128_
261+
= platform_lookup<std::complex<double>,
262+
std::complex<double>,
263+
std::complex<float>,
264+
std::complex<long double>>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
265+
NPY_CHAR_ = std::is_signed<char>::value ? NPY_BYTE_ : NPY_UBYTE_,
248266
};
249267

250268
unsigned int PyArray_RUNTIME_VERSION_;
@@ -268,6 +286,7 @@ struct npy_api {
268286

269287
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
270288
PyObject *(*PyArray_DescrFromType_)(int);
289+
PyObject *(*PyArray_TypeObjectFromType_)(int);
271290
PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
272291
PyObject *,
273292
int,
@@ -284,6 +303,8 @@ struct npy_api {
284303
PyTypeObject *PyVoidArrType_Type_;
285304
PyTypeObject *PyArrayDescr_Type_;
286305
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
306+
PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
307+
void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
287308
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
288309
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
289310
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
@@ -301,7 +322,10 @@ struct npy_api {
301322
API_PyArrayDescr_Type = 3,
302323
API_PyVoidArrType_Type = 39,
303324
API_PyArray_DescrFromType = 45,
325+
API_PyArray_TypeObjectFromType = 46,
304326
API_PyArray_DescrFromScalar = 57,
327+
API_PyArray_Scalar = 60,
328+
API_PyArray_ScalarAsCtype = 62,
305329
API_PyArray_FromAny = 69,
306330
API_PyArray_Resize = 80,
307331
// CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
@@ -336,7 +360,10 @@ struct npy_api {
336360
DECL_NPY_API(PyVoidArrType_Type);
337361
DECL_NPY_API(PyArrayDescr_Type);
338362
DECL_NPY_API(PyArray_DescrFromType);
363+
DECL_NPY_API(PyArray_TypeObjectFromType);
339364
DECL_NPY_API(PyArray_DescrFromScalar);
365+
DECL_NPY_API(PyArray_Scalar);
366+
DECL_NPY_API(PyArray_ScalarAsCtype);
340367
DECL_NPY_API(PyArray_FromAny);
341368
DECL_NPY_API(PyArray_Resize);
342369
DECL_NPY_API(PyArray_CopyInto);
@@ -355,6 +382,83 @@ struct npy_api {
355382
}
356383
};
357384

385+
template <typename T>
386+
struct is_complex : std::false_type {};
387+
template <typename T>
388+
struct is_complex<std::complex<T>> : std::true_type {};
389+
390+
template <typename T, typename = void>
391+
struct npy_format_descriptor_name;
392+
393+
template <typename T>
394+
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
395+
static constexpr auto name = const_name<std::is_same<T, bool>::value>(
396+
const_name("numpy.bool"),
397+
const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
398+
+ const_name<sizeof(T) * 8>());
399+
};
400+
401+
template <typename T>
402+
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
403+
static constexpr auto name = const_name < std::is_same<T, float>::value
404+
|| std::is_same<T, const float>::value
405+
|| std::is_same<T, double>::value
406+
|| std::is_same<T, const double>::value
407+
> (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
408+
const_name("numpy.longdouble"));
409+
};
410+
411+
template <typename T>
412+
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
413+
static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
414+
|| std::is_same<typename T::value_type, const float>::value
415+
|| std::is_same<typename T::value_type, double>::value
416+
|| std::is_same<typename T::value_type, const double>::value
417+
> (const_name("numpy.complex")
418+
+ const_name<sizeof(typename T::value_type) * 16>(),
419+
const_name("numpy.longcomplex"));
420+
};
421+
422+
template <typename T>
423+
struct numpy_scalar_info {};
424+
425+
#define PYBIND11_NUMPY_SCALAR_IMPL(ctype_, typenum_) \
426+
template <> \
427+
struct numpy_scalar_info<ctype_> { \
428+
static constexpr auto name = npy_format_descriptor_name<ctype_>::name; \
429+
static constexpr int typenum = npy_api::typenum_##_; \
430+
}
431+
432+
// boolean type
433+
PYBIND11_NUMPY_SCALAR_IMPL(bool, NPY_BOOL);
434+
435+
// character types
436+
PYBIND11_NUMPY_SCALAR_IMPL(char, NPY_CHAR);
437+
PYBIND11_NUMPY_SCALAR_IMPL(signed char, NPY_BYTE);
438+
PYBIND11_NUMPY_SCALAR_IMPL(unsigned char, NPY_UBYTE);
439+
440+
// signed integer types
441+
PYBIND11_NUMPY_SCALAR_IMPL(std::int16_t, NPY_INT16);
442+
PYBIND11_NUMPY_SCALAR_IMPL(std::int32_t, NPY_INT32);
443+
PYBIND11_NUMPY_SCALAR_IMPL(std::int64_t, NPY_INT64);
444+
445+
// unsigned integer types
446+
PYBIND11_NUMPY_SCALAR_IMPL(std::uint16_t, NPY_UINT16);
447+
PYBIND11_NUMPY_SCALAR_IMPL(std::uint32_t, NPY_UINT32);
448+
PYBIND11_NUMPY_SCALAR_IMPL(std::uint64_t, NPY_UINT64);
449+
450+
// floating point types
451+
PYBIND11_NUMPY_SCALAR_IMPL(float, NPY_FLOAT);
452+
PYBIND11_NUMPY_SCALAR_IMPL(double, NPY_DOUBLE);
453+
PYBIND11_NUMPY_SCALAR_IMPL(long double, NPY_LONGDOUBLE);
454+
455+
// complex types
456+
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<float>, NPY_CFLOAT);
457+
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<double>, NPY_CDOUBLE);
458+
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<long double>, NPY_CLONGDOUBLE);
459+
460+
#undef PYBIND11_NUMPY_SCALAR_IMPL
461+
358462
// This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ...
359463
// This is needed to correctly handle situations where multiple typenums map to the same type,
360464
// e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
@@ -453,10 +557,6 @@ template <typename T>
453557
struct is_std_array : std::false_type {};
454558
template <typename T, size_t N>
455559
struct is_std_array<std::array<T, N>> : std::true_type {};
456-
template <typename T>
457-
struct is_complex : std::false_type {};
458-
template <typename T>
459-
struct is_complex<std::complex<T>> : std::true_type {};
460560

461561
template <typename T>
462562
struct array_info_scalar {
@@ -670,8 +770,65 @@ template <typename T, ssize_t Dim>
670770
struct type_caster<unchecked_mutable_reference<T, Dim>>
671771
: type_caster<unchecked_reference<T, Dim>> {};
672772

773+
template <typename T>
774+
struct type_caster<numpy_scalar<T>> {
775+
using value_type = T;
776+
using type_info = numpy_scalar_info<T>;
777+
778+
PYBIND11_TYPE_CASTER(numpy_scalar<T>, type_info::name);
779+
780+
static handle &target_type() {
781+
static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum);
782+
return tp;
783+
}
784+
785+
static handle &target_dtype() {
786+
static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum);
787+
return tp;
788+
}
789+
790+
bool load(handle src, bool) {
791+
if (isinstance(src, target_type())) {
792+
npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value);
793+
return true;
794+
}
795+
return false;
796+
}
797+
798+
static handle cast(numpy_scalar<T> src, return_value_policy, handle) {
799+
return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr);
800+
}
801+
};
802+
673803
PYBIND11_NAMESPACE_END(detail)
674804

805+
template <typename T>
806+
struct numpy_scalar {
807+
using value_type = T;
808+
809+
value_type value;
810+
811+
numpy_scalar() = default;
812+
explicit numpy_scalar(value_type value) : value(value) {}
813+
814+
explicit operator value_type() const { return value; }
815+
numpy_scalar &operator=(value_type value) {
816+
this->value = value;
817+
return *this;
818+
}
819+
820+
friend bool operator==(const numpy_scalar &a, const numpy_scalar &b) {
821+
return a.value == b.value;
822+
}
823+
824+
friend bool operator!=(const numpy_scalar &a, const numpy_scalar &b) { return !(a == b); }
825+
};
826+
827+
template <typename T>
828+
numpy_scalar<T> make_scalar(T value) {
829+
return numpy_scalar<T>(value);
830+
}
831+
675832
class dtype : public object {
676833
public:
677834
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
@@ -1409,38 +1566,6 @@ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::valu
14091566
}
14101567
};
14111568

1412-
template <typename T, typename = void>
1413-
struct npy_format_descriptor_name;
1414-
1415-
template <typename T>
1416-
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
1417-
static constexpr auto name = const_name<std::is_same<T, bool>::value>(
1418-
const_name("bool"),
1419-
const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
1420-
+ const_name<sizeof(T) * 8>());
1421-
};
1422-
1423-
template <typename T>
1424-
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
1425-
static constexpr auto name = const_name < std::is_same<T, float>::value
1426-
|| std::is_same<T, const float>::value
1427-
|| std::is_same<T, double>::value
1428-
|| std::is_same<T, const double>::value
1429-
> (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
1430-
const_name("numpy.longdouble"));
1431-
};
1432-
1433-
template <typename T>
1434-
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
1435-
static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
1436-
|| std::is_same<typename T::value_type, const float>::value
1437-
|| std::is_same<typename T::value_type, double>::value
1438-
|| std::is_same<typename T::value_type, const double>::value
1439-
> (const_name("numpy.complex")
1440-
+ const_name<sizeof(typename T::value_type) * 16>(),
1441-
const_name("numpy.longcomplex"));
1442-
};
1443-
14441569
template <typename T>
14451570
struct npy_format_descriptor<
14461571
T,

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ set(PYBIND11_TEST_FILES
159159
test_native_enum
160160
test_numpy_array
161161
test_numpy_dtypes
162+
test_numpy_scalars
162163
test_numpy_vectorize
163164
test_opaque_types
164165
test_operator_overloading

tests/test_numpy_scalars.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
tests/test_numpy_scalars.cpp -- strict NumPy scalars
3+
4+
Copyright (c) 2021 Steve R. Sun
5+
6+
All rights reserved. Use of this source code is governed by a
7+
BSD-style license that can be found in the LICENSE file.
8+
*/
9+
10+
#include <pybind11/numpy.h>
11+
12+
#include "pybind11_tests.h"
13+
14+
#include <complex>
15+
#include <cstdint>
16+
17+
namespace py = pybind11;
18+
19+
namespace pybind11_test_numpy_scalars {
20+
21+
template <typename T>
22+
struct add {
23+
T x;
24+
explicit add(T x) : x(x) {}
25+
T operator()(T y) const { return static_cast<T>(x + y); }
26+
};
27+
28+
template <typename T, typename F>
29+
void register_test(py::module &m, const char *name, F &&func) {
30+
m.def((std::string("test_") + name).c_str(),
31+
[=](py::numpy_scalar<T> v) {
32+
return std::make_tuple(name, py::make_scalar(static_cast<T>(func(v.value))));
33+
},
34+
py::arg("x"));
35+
}
36+
37+
} // namespace pybind11_test_numpy_scalars
38+
39+
using namespace pybind11_test_numpy_scalars;
40+
41+
TEST_SUBMODULE(numpy_scalars, m) {
42+
using cfloat = std::complex<float>;
43+
using cdouble = std::complex<double>;
44+
45+
register_test<bool>(m, "bool", [](bool x) { return !x; });
46+
register_test<int8_t>(m, "int8", add<int8_t>(-8));
47+
register_test<int16_t>(m, "int16", add<int16_t>(-16));
48+
register_test<int32_t>(m, "int32", add<int32_t>(-32));
49+
register_test<int64_t>(m, "int64", add<int64_t>(-64));
50+
register_test<uint8_t>(m, "uint8", add<uint8_t>(8));
51+
register_test<uint16_t>(m, "uint16", add<uint16_t>(16));
52+
register_test<uint32_t>(m, "uint32", add<uint32_t>(32));
53+
register_test<uint64_t>(m, "uint64", add<uint64_t>(64));
54+
register_test<float>(m, "float32", add<float>(0.125f));
55+
register_test<double>(m, "float64", add<double>(0.25f));
56+
register_test<cfloat>(m, "complex64", add<cfloat>({0, -0.125f}));
57+
register_test<cdouble>(m, "complex128", add<cdouble>({0, -0.25f}));
58+
59+
m.def("test_eq",
60+
[](py::numpy_scalar<int32_t> a, py::numpy_scalar<int32_t> b) { return a == b; });
61+
m.def("test_ne",
62+
[](py::numpy_scalar<int32_t> a, py::numpy_scalar<int32_t> b) { return a != b; });
63+
}

0 commit comments

Comments
 (0)