[mlir][python] add type wrappers #71218
Conversation
✅ With the latest revision this PR passed the Python code formatter.
@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Inspired by this comment #71050 (comment), let's give people a nicer way to instantiate types.

The goal of the design here is to provide a module of Pythonic type builders but have simple types (such as `i32`, `f16`, etc.) appear as if they've already been built (so that users don't need to call a seemingly superfluous `.get()`). The problem of course is that types can't be "pre-instantiated" in an arbitrary module since they need a context (and requiring that the module itself only be imported under a `with Context()...` is weird IMHO).

The solution uses a little-known Python feature of attribute resolution on modules, namely that you can "override" `__getattr__` for a module. And so `types.i32`, `types.f64`, etc. appear as if they're already instantiated types but in actuality are just calling `IntegerType.get_signless(32)`, `F64Type.get()`, etc., at point of use.

There's some room for bikeshedding/discussion here, e.g., whether these type builders should be spelled with a `_t` suffix.

Full diff: https://github.com/llvm/llvm-project/pull/71218.diff

4 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 483db673f989e6b..56e895d3053796e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -463,7 +463,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_static("get", &PyVectorType::get, py::arg("shape"),
- py::arg("elementType"), py::kw_only(),
+ py::arg("element_type"), py::kw_only(),
py::arg("scalable") = py::none(),
py::arg("scalable_dims") = py::none(),
py::arg("loc") = py::none(), "Create a vector type")
@@ -689,13 +689,9 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get_tuple",
- [](py::list elementList, DefaultingPyMlirContext context) {
- intptr_t num = py::len(elementList);
- // Mapping py::list to SmallVector.
- SmallVector<MlirType, 4> elements;
- for (auto element : elementList)
- elements.push_back(element.cast<PyType>());
- MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
+ [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+ MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+ elements.data());
return PyTupleType(context->getRef(), t);
},
py::arg("elements"), py::arg("context") = py::none(),
@@ -727,13 +723,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::vector<PyType> inputs, std::vector<PyType> results,
+ [](std::vector<MlirType> inputs, std::vector<MlirType> results,
DefaultingPyMlirContext context) {
- SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
- SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
- MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
- inputsRaw.data(), resultsRaw.size(),
- resultsRaw.data());
+ MlirType t =
+ mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+ results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
@@ -742,7 +736,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
"inputs",
[](PyFunctionType &self) {
MlirType t = self;
- auto contextRef = self.getContext();
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
@@ -754,7 +747,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
c.def_property_readonly(
"results",
[](PyFunctionType &self) {
- auto contextRef = self.getContext();
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 971ad2dd214a15f..12e2dab60f3011b 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
+ types.py
dialects/_ods_common.py
# The main _mlir module has submodules: include stubs from each.
diff --git a/mlir/python/mlir/types.py b/mlir/python/mlir/types.py
new file mode 100644
index 000000000000000..ce8d826b40a6b1d
--- /dev/null
+++ b/mlir/python/mlir/types.py
@@ -0,0 +1,207 @@
+from functools import partial
+from typing import Optional
+
+from .ir import (
+ Attribute,
+ BF16Type,
+ ComplexType,
+ F16Type,
+ F32Type,
+ F64Type,
+ Float8E4M3B11FNUZType,
+ Float8E4M3FNType,
+ Float8E5M2Type,
+ FunctionType,
+ IndexType,
+ IntegerType,
+ MemRefType,
+ NoneType,
+ OpaqueType,
+ RankedTensorType,
+ StridedLayoutAttr,
+ StringAttr,
+ TupleType,
+ Type,
+ UnrankedMemRefType,
+ UnrankedTensorType,
+ VectorType,
+)
+
+from .dialects import transform
+from .dialects import pdl
+
+
+_index = lambda: IndexType.get()
+_bool = lambda: IntegerType.get_signless(1)
+
+_i8 = lambda: IntegerType.get_signless(8)
+_i16 = lambda: IntegerType.get_signless(16)
+_i32 = lambda: IntegerType.get_signless(32)
+_i64 = lambda: IntegerType.get_signless(64)
+
+_si8 = lambda: IntegerType.get_signed(8)
+_si16 = lambda: IntegerType.get_signed(16)
+_si32 = lambda: IntegerType.get_signed(32)
+_si64 = lambda: IntegerType.get_signed(64)
+
+_ui8 = lambda: IntegerType.get_unsigned(8)
+_ui16 = lambda: IntegerType.get_unsigned(16)
+_ui32 = lambda: IntegerType.get_unsigned(32)
+_ui64 = lambda: IntegerType.get_unsigned(64)
+
+_f16 = lambda: F16Type.get()
+_f32 = lambda: F32Type.get()
+_f64 = lambda: F64Type.get()
+_bf16 = lambda: BF16Type.get()
+
+_f8e5m2 = lambda: Float8E5M2Type.get()
+_f8e4m3 = lambda: Float8E4M3FNType.get()
+_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()
+
+_cmp16 = lambda: ComplexType.get(_f16())
+_cmp32 = lambda: ComplexType.get(_f32())
+_cmp64 = lambda: ComplexType.get(_f64())
+
+_none = lambda: NoneType.get()
+
+_pdl_operation = lambda: pdl.OperationType.get()
+
+
+def _transform_any_op():
+ return transform.AnyOpType.get()
+
+
+_name_to_type = {
+ "index": _index,
+ "bool": _bool,
+ "i8": _i8,
+ "i16": _i16,
+ "i32": _i32,
+ "i64": _i64,
+ "si8": _si8,
+ "si16": _si16,
+ "si32": _si32,
+ "si64": _si64,
+ "ui8": _ui8,
+ "ui16": _ui16,
+ "ui32": _ui32,
+ "ui64": _ui64,
+ "f16": _f16,
+ "f32": _f32,
+ "f64": _f64,
+ "bf16": _bf16,
+ "f8e5m2": _f8e5m2,
+ "f8e4m3": _f8e4m3,
+ "f8e4m3b11fnuz": _f8e4m3b11fnuz,
+ "cmp16": _cmp16,
+ "cmp32": _cmp32,
+ "cmp64": _cmp64,
+ "none": _none,
+ "pdl_operation": _pdl_operation,
+ "transform_any_op": _transform_any_op,
+}
+
+
+def __getattr__(name):
+ if name in _name_to_type:
+ return _name_to_type[name]()
+ # This delegates the lookup to default module attribute lookup
+ # (i.e., functions defined below and such).
+ return None
+
+
+def transform_op(name):
+ return transform.OperationType.get(name)
+
+
+def opaque(dialect_namespace, type_data):
+ return OpaqueType.get(dialect_namespace, type_data)
+
+
+def _shaped(*args, element_type: Type = None, type_constructor=None):
+ if type_constructor is None:
+ raise ValueError("shaped is an abstract base class - cannot be constructed.")
+ if (element_type is None and args and not isinstance(args[-1], Type)) or (
+ args and isinstance(args[-1], Type) and element_type is not None
+ ):
+ raise ValueError(
+ f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
+ )
+ if element_type is not None:
+ type = element_type
+ sizes = args
+ else:
+ type = args[-1]
+ sizes = args[:-1]
+ if sizes:
+ return type_constructor(sizes, type)
+ else:
+ return type_constructor(type)
+
+
+def vector(
+ *args,
+ element_type: Type = None,
+ scalable: Optional[list[bool]] = None,
+ scalable_dims: Optional[list[int]] = None,
+):
+ return _shaped(
+ *args,
+ element_type=element_type,
+ type_constructor=partial(
+ VectorType.get, scalable=scalable, scalable_dims=scalable_dims
+ ),
+ )
+
+
+def tensor(*args, element_type: Type = None, encoding: Optional[str] = None):
+ if encoding is not None:
+ encoding = StringAttr.get(encoding)
+ if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
+ if encoding is not None:
+ raise ValueError("UnrankedTensorType does not support encoding.")
+ return _shaped(
+ *args, element_type=element_type, type_constructor=UnrankedTensorType.get
+ )
+ else:
+ return _shaped(
+ *args,
+ element_type=element_type,
+ type_constructor=partial(RankedTensorType.get, encoding=encoding),
+ )
+
+
+def stride(strides, offset: Optional[int] = 0):
+ return StridedLayoutAttr.get(offset, strides)
+
+
+def memref(
+ *args,
+ element_type: Type = None,
+ memory_space: Optional[int] = None,
+ layout: Optional[tuple[tuple[int, ...], int]] = None,
+):
+ if memory_space is not None:
+ memory_space = Attribute.parse(str(memory_space))
+ if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
+ return _shaped(
+ *args,
+ element_type=element_type,
+ type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
+ )
+ else:
+ return _shaped(
+ *args,
+ element_type=element_type,
+ type_constructor=partial(
+ MemRefType.get, memory_space=memory_space, layout=layout
+ ),
+ )
+
+
+def tuple(*elements):
+ return TupleType.get_tuple(elements)
+
+
+def function(inputs, results):
+ return FunctionType.get(inputs, results)
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index d4fed86b4f135ee..203dcee263b958b 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -772,3 +772,95 @@ def testCustomTypeTypeCaster():
print(t)
# CHECK: OperationType(!transform.op<"foo.bar">)
print(repr(t))
+
+
+# CHECK-LABEL: TEST: testTypeWrappers
+@run
+def testTypeWrappers():
+ try:
+ from mlir.types import i32
+ except RuntimeError as e:
+ assert e.args[0].startswith(
+ "An MLIR function requires a Context but none was provided"
+ )
+
+ import mlir.types as T
+ from mlir.types import vector, tensor
+
+ with Context(), Location.unknown():
+ c1 = T.cmp16
+ c2 = T.cmp32
+ assert repr(c1) == "ComplexType(complex<f16>)"
+ assert repr(c2) == "ComplexType(complex<f32>)"
+
+ vec_1 = vector(2, 3, T.f32)
+ vec_2 = vector(2, 3, 4, T.f32)
+ assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
+ assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"
+
+ m1 = T.memref(2, 3, 4, T.f64)
+ assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"
+
+ m2 = T.memref(2, 3, 4, T.f64, memory_space=1)
+ assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"
+
+ m3 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=T.stride([5, 7, 13]))
+ assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"
+
+ m4 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=T.stride([5, 7, 13], 42))
+ assert (
+ repr(m4)
+ == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
+ )
+
+ S = ShapedType.get_dynamic_size()
+
+ t = T.tensor(S, 3, S, T.f64)
+ assert repr(t) == "RankedTensorType(tensor<?x3x?xf64>)"
+ ut = tensor(T.f64)
+ assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)"
+ t = tensor(S, 3, S, element_type=T.f64)
+ assert repr(t) == "RankedTensorType(tensor<?x3x?xf64>)"
+ ut = tensor(element_type=T.f64)
+ assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)"
+
+ v = vector(3, 3, 3, T.f64)
+ assert repr(v) == "VectorType(vector<3x3x3xf64>)"
+
+ m = T.memref(S, 3, S, T.f64)
+ assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
+ um = T.memref(T.f64)
+ assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
+ m = T.memref(S, 3, S, element_type=T.f64)
+ assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
+ um = T.memref(element_type=T.f64)
+ assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
+
+ m = T.memref(S, 3, S, T.f64)
+ assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
+ um = T.memref(T.f64)
+ assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
+
+ scalable_1 = vector(2, 3, T.f32, scalable=[False, True])
+ scalable_2 = vector(2, 3, 4, T.f32, scalable=[True, False, True])
+ assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
+ assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"
+
+ scalable_3 = vector(2, 3, T.f32, scalable_dims=[1])
+ scalable_4 = vector(2, 3, 4, T.f32, scalable_dims=[0, 2])
+ assert scalable_3 == scalable_1
+ assert scalable_4 == scalable_2
+
+ opaq = T.opaque("scf", "placeholder")
+ assert repr(opaq) == "OpaqueType(!scf.placeholder)"
+
+ transfor_op = T.transform_op("foo.bar")
+ assert repr(transfor_op) == 'OperationType(!transform.op<"foo.bar">)'
+
+ tup1 = T.tuple(T.i16, T.i32, T.i64)
+ tup2 = T.tuple(T.f16, T.f32, T.f64)
+ assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
+ assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"
+
+ func = T.function((T.i16, T.i32, T.i64), (T.f16, T.f32, T.f64))
+ assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"
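As a quick orientation for the diff above, the new test already doubles as a usage example. Below is a condensed sketch distilled from it (same names as in the test; it assumes the in-tree Python bindings are built so that `mlir.ir` and the new `mlir.types` module are importable):

```python
from mlir.ir import Context, Location, ShapedType
import mlir.types as T
from mlir.types import tensor, vector

with Context(), Location.unknown():
    S = ShapedType.get_dynamic_size()
    print(T.i32)                                      # expected: i32 (built at point of use)
    print(vector(2, 3, T.f32))                        # expected: vector<2x3xf32>
    print(tensor(S, 3, S, T.f64))                     # expected: tensor<?x3x?xf64>
    print(T.memref(2, 3, 4, T.f64, memory_space=1))   # expected: memref<2x3x4xf64, 1>
```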
This looks nice from a code readability perspective, especially if generalized to be used more pervasively.
My overall thinking here is that we want to delimit optional "sugar" from fundamental APIs, and make sure unsugared APIs are still usable. Any thoughts on how to make this clear?
I'm not sure how much of this should be treated as sugar, maybe we can just have a "types" object in dialect that contains such constructor functions?
Overall, this also needs more documentation. Both docstrings and overall design in https://mlir.llvm.org/docs/Bindings/Python/. (btw, magic value downcasting should also be described there).
Well, as the comment1 that was the impetus for this PR suggests, the fundamental APIs prioritize "authenticity" over ergonomics and that's fine. It's useful that they more closely reflect the C APIs underneath because then the bindings partially serve as a gentle on-ramp to the rest of the codebase. Regarding making things more clearly delineated - we could move these to a
Generating a whole bunch of this from the ODS is feasible; it's just tedious, corner-case-chasing implementation work.
Good point. Will add in this PR.
mlir/python/mlir/types.py (Outdated)
    # This module is NOT a package and so this must be None (rather than throw the RuntimeError below).
    return None
try:
    Context.current
I think I would prefer that this just returned `None` instead of throwing a `ValueError` - that way I could check `is None` here instead of `try -> except -> raise`. And arguably properties shouldn't raise since they're supposed to emulate fields. What say you @ftynse?
I'd second using `None`. The consensus seems to be that raising exceptions from property getters is undesirable: https://stackoverflow.com/questions/1488472/best-practices-throwing-exceptions-from-properties, https://stackoverflow.com/questions/48778945/by-design-should-a-property-getter-ever-throw-an-exception-in-python, with the latter referencing PEP 8.

That being said, having a property that returns something or `None` depending on whether it is queried within a certain context manager doesn't quite agree with the idea of fields either. Maybe you'll give the idea of consistently using functions a second thought based on that (getter functions are recommended if their body can raise errors). I understand the aesthetic appeal of fields for type annotations, but I would rather not base all decisions exclusively on that. After all, we could have several mechanisms, one built on another, that serve different purposes.

For example, we could have type constructors as functions available to all clients, and an additional sugaring with properties somewhere in `mlir.sugar.typing` that is more specific to `from_py_func`. If somebody doesn't use `from_py_func` (I don't in my downstreams, as the amount of complexity it comes with isn't compensated by the sweetness of the sugaring for my taste), they don't have to pay the complexity cost for something they don't need.
> Maybe you'll give the idea of consistently using functions a second thought based on that (getter functions are recommended if their body can raise errors)

I'm not opinionated - I just thought it was aesthetically pleasing - but "least surprising" is the practical approach so I'm happy to change it.
So the most surprising thing for me in the original version was that `from mlir.(sugar.)types import i32` could fail with something that is not an `ImportError`. :)
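To make the alternative concrete, here is a minimal sketch of the check-based guard discussed in this thread. It is illustrative only, not the code that eventually landed, and it assumes a `Context.current` accessor that returns `None` when no context is active (the very behavior under debate above); the `_name_to_type` table is abridged from the `types.py` in the diff:

```python
# Sketch of a module-level __getattr__ for mlir/python/mlir/types.py that uses
# an `is None` check instead of try/except/raise around Context.current.
from mlir.ir import Context, IntegerType

_name_to_type = {"i32": lambda: IntegerType.get_signless(32)}  # abridged

def __getattr__(name):
    if name in _name_to_type:
        if Context.current is None:
            # An AttributeError raised from a module __getattr__ surfaces as an
            # ImportError for `from mlir.types import i32`, which is the
            # "least surprising" failure mode.
            raise AttributeError(f"{name} requires an active MLIR Context")
        return _name_to_type[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```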
If this can be placed in a subdirectory, even temporarily, with a clear indication that the design may evolve, I have no objection to landing this.
How about into `runtime`, since there are various other type-y things in there already?
Those are specific to interaction with numpy. The things added here are more for IR construction, not for using objects of said type at runtime. It also doesn't convey the "this API is more likely to change than other parts" message. I'm fine with any new directory:
Nice, thanks! Let's land this, have some mileage with it and iterate if needed.
Inspired by this comment #71050 (comment), let's give people a nicer way to instantiate types.

The goal of the design here is to provide a module of Pythonic type builders but have simple types (such as `i32`, `f16`, etc.) appear as if they've already been built (so that users don't need to call a seemingly superfluous `.get()`). The problem of course is that types can't be "pre-instantiated" in an arbitrary module since they need a context (and requiring that the module itself only be imported under a `with Context()...` is weird IMHO).

The solution uses a little-known Python feature of attribute resolution on modules, namely that you can "override" `__getattr__` for a module. And so `types.i32`, `types.f64`, etc. appear as if they're already instantiated types but in actuality are just calling `IntegerType.get_signless(32)`, `F64Type.get()`, etc., at point of use.

There's some room for bikeshedding/discussion here, e.g., whether these type builders should be spelled with a `_t` suffix to distinguish, e.g., the `vector` dialect from the `types.vector` type. I didn't do this because, as I describe above, these are meant to most often be accessed through the `types` module. Also, with the goal of matching the appearance of MLIR IR source as closely as possible, all the complex type builders (`vector`, `memref`, `tensor`) take their shape args as a `*shape`. That means no keyword-only args are possible (because you can't have two `*`s in the params list of a Python function). Keyword args still work (I just had to add a little boilerplate checking in the `*shape` tuple) but keyword-only args can't be supported. The alternative would be to require the shape to be a list/tuple.
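For readers unfamiliar with the module-level `__getattr__` trick the description relies on, here is a minimal, self-contained sketch of the mechanism (PEP 562), independent of MLIR; the module and attribute names are purely illustrative:

```python
# lazy_mod.py -- illustrates PEP 562: a module-level __getattr__ is consulted
# whenever normal attribute lookup on the module fails.
_builders = {
    "answer": lambda: 42,          # "built" only when actually accessed
}

def __getattr__(name):
    if name in _builders:
        return _builders[name]()   # looks like a plain attribute to the caller
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

With this in place, `import lazy_mod; lazy_mod.answer` evaluates the lambda at point of use, which is exactly how `types.i32` can defer `IntegerType.get_signless(32)` until a `Context` is active.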