@@ -25,6 +25,7 @@ limitations under the License.
25
25
#include " pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
26
26
#include " xla/literal.h"
27
27
#include " xla/python/logging.h"
28
+ #include " xla/python/nb_numpy.h"
28
29
#include " xla/python/types.h"
29
30
#include " xla/xla_data.pb.h"
30
31
// NOTE: The tsl-numpy header forbids importing the actual NumPy arrayobject.h
@@ -59,6 +60,19 @@ absl::StatusOr<py::object> MakeNdarray(const xla::LiteralProto& proto) {
59
60
// Convert `nb::object` into `py::object`.
60
61
return py::reinterpret_steal<py::object>(nbobj.release ().ptr ());
61
62
}
63
+
64
+ // Partial reversion of cl/617156835, until we can get the proto-casters
65
+ // (and hence the extension) switched over to nanobind.
66
+ // TODO(wrengr): Or can we mix `{py,nb}::module_::def` calls??
67
+ absl::StatusOr<xla::PrimitiveType> DtypeToEtype (const py::dtype& py_d) {
68
+ auto nb_d = nb::borrow<xla::nb_dtype>(py_d.ptr ());
69
+ return xla::DtypeToPrimitiveType (nb_d);
70
+ }
71
+
72
+ absl::StatusOr<py::dtype> EtypeToDtype (xla::PrimitiveType p) {
73
+ TF_ASSIGN_OR_RETURN (xla::nb_dtype nb_d, xla::PrimitiveTypeToNbDtype (p));
74
+ return py::reinterpret_steal<py::dtype>(nb_d.release ().ptr ());
75
+ }
62
76
} // namespace
63
77
64
78
// NOTE: It seems insurmountable to get "native_proto_caster.h" to work
@@ -98,7 +112,8 @@ PYBIND11_MODULE(_types, py_m) {
98
112
py::module_::import (" ml_dtypes" );
99
113
100
114
// Ensure that tsl-numpy initializes datastructures of the actual-NumPy
101
- // implementation, and does whatever else tsl-numpy needs.
115
+ // implementation, and does whatever else tsl-numpy needs. This is
116
+ // also necessary for using the `xla::nb_dtype` type.
102
117
tsl::ImportNumpy ();
103
118
104
119
// Declare that C++ can `nb::cast` from `std::shared_ptr<xla::Literal>`
@@ -124,5 +139,21 @@ PYBIND11_MODULE(_types, py_m) {
124
139
of tuples with leaves being `numpy.ndarray` views of array-shaped
125
140
sub-literals.
126
141
)pbdoc" );
142
+
143
+ // This method name is based on `xla_client.dtype_to_etype`.
144
+ // NOTE: `xla_client` uses a Python class wrapping the protobuf-enum,
145
+ // rather than using the protobuf-enum directly. See the module docstring
146
+ // in "types.py" for more explanation on why.
147
+ py_m.def (" dtype_to_etype" , &DtypeToEtype, py::arg (" dtype" ).none (false ),
148
+ py::pos_only (), R"pbdoc(
149
+ Converts `numpy.dtype` into
150
+ `tensorflow.compiler.xla.xla_data_pb2.PrimitiveType`.
151
+ )pbdoc" );
152
+
153
+ py_m.def (" etype_to_dtype" , &EtypeToDtype, py::arg (" ptype" ).none (false ),
154
+ py::pos_only (), R"pbdoc(
155
+ Converts `tensorflow.compiler.xla.xla_data_pb2.PrimitiveType` into
156
+ `numpy.dtype`.
157
+ )pbdoc" );
127
158
// LINT.ThenChange(_types.pyi)
128
159
}
0 commit comments