Skip to content

Commit a7d81b5

Browse files
wrengrtensorflower-gardener
authored andcommitted
[XLA:Python] Adding xla::PrimitiveType <-> numpy.dtype conversions to the library for internal debugging tools.
PiperOrigin-RevId: 620068167
1 parent 9926a40 commit a7d81b5

File tree

5 files changed

+44
-10
lines changed

5 files changed

+44
-10
lines changed

third_party/xla/xla/python/tools/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ tsl_pybind_extension(
6363
"//xla:literal",
6464
"//xla:xla_data_proto_cc",
6565
"//xla/python:logging",
66+
"//xla/python:nb_numpy",
6667
"//xla/python:types",
6768
"@com_google_absl//absl/status:statusor",
6869
"@com_google_absl//absl/strings",

third_party/xla/xla/python/tools/_types.cc

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
2626
#include "xla/literal.h"
2727
#include "xla/python/logging.h"
28+
#include "xla/python/nb_numpy.h"
2829
#include "xla/python/types.h"
2930
#include "xla/xla_data.pb.h"
3031
// NOTE: The tsl-numpy header forbids importing the actual NumPy arrayobject.h
@@ -59,6 +60,19 @@ absl::StatusOr<py::object> MakeNdarray(const xla::LiteralProto& proto) {
5960
// Convert `nb::object` into `py::object`.
6061
return py::reinterpret_steal<py::object>(nbobj.release().ptr());
6162
}
63+
64+
// Partial reversion of cl/617156835, until we can get the proto-casters
65+
// (and hence the extension) switched over to nanobind.
66+
// TODO(wrengr): Or can we mix `{py,nb}::module_::def` calls??
67+
absl::StatusOr<xla::PrimitiveType> DtypeToEtype(const py::dtype& py_d) {
68+
auto nb_d = nb::borrow<xla::nb_dtype>(py_d.ptr());
69+
return xla::DtypeToPrimitiveType(nb_d);
70+
}
71+
72+
absl::StatusOr<py::dtype> EtypeToDtype(xla::PrimitiveType p) {
73+
TF_ASSIGN_OR_RETURN(xla::nb_dtype nb_d, xla::PrimitiveTypeToNbDtype(p));
74+
return py::reinterpret_steal<py::dtype>(nb_d.release().ptr());
75+
}
6276
} // namespace
6377

6478
// NOTE: It seems insurmountable to get "native_proto_caster.h" to work
@@ -98,7 +112,8 @@ PYBIND11_MODULE(_types, py_m) {
98112
py::module_::import("ml_dtypes");
99113

100114
// Ensure that tsl-numpy initializes datastructures of the actual-NumPy
101-
// implementation, and does whatever else tsl-numpy needs.
115+
// implementation, and does whatever else tsl-numpy needs. This is
116+
// also necessary for using the `xla::nb_dtype` type.
102117
tsl::ImportNumpy();
103118

104119
// Declare that C++ can `nb::cast` from `std::shared_ptr<xla::Literal>`
@@ -124,5 +139,21 @@ PYBIND11_MODULE(_types, py_m) {
124139
of tuples with leaves being `numpy.ndarray` views of array-shaped
125140
sub-literals.
126141
)pbdoc");
142+
143+
// This method name is based on `xla_client.dtype_to_etype`.
144+
// NOTE: `xla_client` uses a Python class wrapping the protobuf-enum,
145+
// rather than using the protobuf-enum directly. See the module docstring
146+
// in "types.py" for more explanation on why.
147+
py_m.def("dtype_to_etype", &DtypeToEtype, py::arg("dtype").none(false),
148+
py::pos_only(), R"pbdoc(
149+
Converts `numpy.dtype` into
150+
`tensorflow.compiler.xla.xla_data_pb2.PrimitiveType`.
151+
)pbdoc");
152+
153+
py_m.def("etype_to_dtype", &EtypeToDtype, py::arg("ptype").none(false),
154+
py::pos_only(), R"pbdoc(
155+
Converts `tensorflow.compiler.xla.xla_data_pb2.PrimitiveType` into
156+
`numpy.dtype`.
157+
)pbdoc");
127158
// LINT.ThenChange(_types.pyi)
128159
}

third_party/xla/xla/python/tools/_types.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@ from xla import xla_data_pb2
2020
# LINT.IfChange
2121
NdarrayTree = Union[np.ndarray, tuple['NdarrayTree', ...]]
2222
def make_ndarray(proto: xla_data_pb2.LiteralProto, /) -> NdarrayTree: ...
23+
def dtype_to_etype(dtype: np.dtype, /) -> xla_data_pb2.PrimitiveType: ...
24+
def etype_to_dtype(ptype: xla_data_pb2.PrimitiveType, /) -> np.dtype: ...
2325
# LINT.ThenChange(types.py, _types.cc)

third_party/xla/xla/python/tools/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@
4040

4141
# NOTE: `import <name> as <name>` is required for names to be exported.
4242
# See PEP 484 & <https://github.com/google/jax/issues/7570>
43-
# pylint: disable=g-importing-member,useless-import-alias,unused-import
43+
# pylint: disable=g-importing-member,useless-import-alias,unused-import,g-multiple-import
4444
# LINT.IfChange
4545
from ._types import (
4646
make_ndarray as make_ndarray,
47+
dtype_to_etype as dtype_to_etype,
48+
etype_to_dtype as etype_to_dtype,
4749
)
4850
# TODO(wrengr): We can't import the `NdarrayTree` defined in the pyi file.
4951
# So re-defining it here for now.

third_party/xla/xla/python/tools/types_test.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,12 @@ class MakeNdarrayValidTest(parameterized.TestCase):
148148

149149
def testHasCorrectDtype(self, proto, arr):
150150
"""Test that the result has the right dtype."""
151-
# Silence [unused-argument] warning.
152-
del proto
153-
# TODO(wrengr): Add pybind for `xla::PrimitiveTypeToDtype`,
154-
# so that we can avoid hard-coding the expected np.dtype.
155-
# Alternatively, we could use `xla_client.dtype_to_etype` (ideally
156-
# after refactoring that into a small library, so we need not pull in
157-
# all the rest of xla_client).
158-
self.assertEqual(np.float64, arr.dtype)
151+
e = proto.shape.element_type
152+
d = arr.dtype
153+
with self.subTest(msg='etype_to_dtype'):
154+
self.assertEqual(types.etype_to_dtype(e), d)
155+
with self.subTest(msg='dtype_to_etype'):
156+
self.assertEqual(e, types.dtype_to_etype(d))
159157

160158
def testHasCorrectRank(self, proto, arr):
161159
"""Test that the result has the right rank."""

0 commit comments

Comments
 (0)