[mlir][python] add type wrappers #71218

Merged: 2 commits, Nov 27, 2023

Contributor

@makslevental makslevental commented Nov 3, 2023

Inspired by this comment #71050 (comment), let's give people a nicer way to instantiate types.

The goal of the design here is to provide a module of Pythonic type builders but have simple types (such as i32, f16, etc.) appear as if they've already been built (so that users don't need to call a seemingly superfluous .get()). The problem, of course, is that types can't be "pre-instantiated" in an arbitrary module since they need a context (and requiring that the module itself only be imported under a with Context()... is weird IMHO).

The solution uses a little-known Python feature of attribute resolution on modules, namely that you can "override" __getattr__ for a module. And so types.i32, types.f64, etc. appear as if they're already-instantiated types but in actuality are just calling IntegerType.get_signless(32), F64Type.get(), etc., at point of use.
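
The hook described above is the module-level `__getattr__` from PEP 562 (Python 3.7+). A minimal sketch of the idea, assuming a throwaway module built at runtime and a hypothetical `_make_i32` builder standing in for `IntegerType.get_signless(32)`; it is an illustration, not the code from this PR:

```python
import types

# Records each builder call so deferred construction is observable.
build_log = []


def _make_i32():
    # Hypothetical stand-in for IntegerType.get_signless(32).
    build_log.append("i32")
    return "i32"


_name_to_builder = {"i32": _make_i32}


def _module_getattr(name):
    # Invoked only after normal module attribute lookup has failed.
    if name in _name_to_builder:
        return _name_to_builder[name]()
    raise AttributeError(f"module 'fake_types' has no attribute {name!r}")


# Build a throwaway module and install the hook in its namespace, which has
# the same effect as defining __getattr__ at the top level of a .py file.
fake_types = types.ModuleType("fake_types")
fake_types.__getattr__ = _module_getattr

first = fake_types.i32   # the builder runs here, at point of use
second = fake_types.i32  # and runs again here; nothing is cached
```

In this sketch unknown names raise `AttributeError`, so `hasattr` and ordinary attribute errors keep their usual behavior.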

There's some room for bikeshedding/discussion here, e.g., whether these type builders should be spelled with a _t suffix to distinguish, e.g., the vector dialect from the types.vector type. I didn't do this because, as I describe above, these are meant to be accessed most often through the types module. Also, with the goal of matching the appearance of MLIR IR source as closely as possible, all the complex type builders (vector, memref, tensor) take their shape args as a *shape. That means no keyword-only args are possible (because you can't have two *s in the params list of a Python function). Keyword args still work (I just had to add a little boilerplate checking in the *shape tuple) but keyword-only args can't be supported. The alternative is to require the shape to be a list/tuple.
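
The boilerplate check on the `*shape` tuple mentioned above can be sketched roughly as follows. This is a simplified illustration under stated assumptions: `ElementType` and `split_shape` are hypothetical names, not part of the MLIR bindings, and the real helper in the diff dispatches to a type constructor instead of returning a pair.

```python
class ElementType:
    """Hypothetical stand-in for an MLIR element type such as f32."""

    def __init__(self, name):
        self.name = name


def split_shape(*shape, element_type=None):
    """Return (sizes, element_type) from a variadic shape.

    The element type may be the last positional argument or the
    element_type keyword, but exactly one of the two must be used.
    """
    trailing = bool(shape) and isinstance(shape[-1], ElementType)
    # Reject both "given twice" and "not given at all".
    if trailing == (element_type is not None):
        raise ValueError(
            "pass the element type either last positionally or as element_type="
        )
    if trailing:
        return shape[:-1], shape[-1]
    return shape, element_type


f32 = ElementType("f32")
positional = split_shape(2, 3, f32)            # mirrors vector(2, 3, T.f32)
keyword = split_shape(2, 3, element_type=f32)  # mirrors element_type=T.f32
```

Both spellings yield the same sizes and element type; omitting the element type or passing it twice raises `ValueError`.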

github-actions bot commented Nov 13, 2023

✅ With the latest revision this PR passed the Python code formatter.

@makslevental makslevental force-pushed the type_wrappers branch 3 times, most recently from 83be73b to 3237d3a Compare November 13, 2023 18:33
@makslevental makslevental marked this pull request as ready for review November 13, 2023 18:34
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Nov 13, 2023
Member

llvmbot commented Nov 13, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/71218.diff

4 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+8-16)
  • (modified) mlir/python/CMakeLists.txt (+1)
  • (added) mlir/python/mlir/types.py (+207)
  • (modified) mlir/test/python/ir/builtin_types.py (+92)
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 483db673f989e6b..56e895d3053796e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -463,7 +463,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static("get", &PyVectorType::get, py::arg("shape"),
-                 py::arg("elementType"), py::kw_only(),
+                 py::arg("element_type"), py::kw_only(),
                  py::arg("scalable") = py::none(),
                  py::arg("scalable_dims") = py::none(),
                  py::arg("loc") = py::none(), "Create a vector type")
@@ -689,13 +689,9 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get_tuple",
-        [](py::list elementList, DefaultingPyMlirContext context) {
-          intptr_t num = py::len(elementList);
-          // Mapping py::list to SmallVector.
-          SmallVector<MlirType, 4> elements;
-          for (auto element : elementList)
-            elements.push_back(element.cast<PyType>());
-          MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
+        [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+                                        elements.data());
           return PyTupleType(context->getRef(), t);
         },
         py::arg("elements"), py::arg("context") = py::none(),
@@ -727,13 +723,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](std::vector<PyType> inputs, std::vector<PyType> results,
+        [](std::vector<MlirType> inputs, std::vector<MlirType> results,
            DefaultingPyMlirContext context) {
-          SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
-          SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
-          MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
-                                           inputsRaw.data(), resultsRaw.size(),
-                                           resultsRaw.data());
+          MlirType t =
+              mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+                                  results.size(), results.data());
           return PyFunctionType(context->getRef(), t);
         },
         py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
@@ -742,7 +736,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
         "inputs",
         [](PyFunctionType &self) {
           MlirType t = self;
-          auto contextRef = self.getContext();
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
                ++i) {
@@ -754,7 +747,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
     c.def_property_readonly(
         "results",
         [](PyFunctionType &self) {
-          auto contextRef = self.getContext();
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
                ++i) {
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 971ad2dd214a15f..12e2dab60f3011b 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
+    types.py
     dialects/_ods_common.py
 
     # The main _mlir module has submodules: include stubs from each.
diff --git a/mlir/python/mlir/types.py b/mlir/python/mlir/types.py
new file mode 100644
index 000000000000000..ce8d826b40a6b1d
--- /dev/null
+++ b/mlir/python/mlir/types.py
@@ -0,0 +1,207 @@
+from functools import partial
+from typing import Optional
+
+from .ir import (
+    Attribute,
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    Float8E4M3B11FNUZType,
+    Float8E4M3FNType,
+    Float8E5M2Type,
+    FunctionType,
+    IndexType,
+    IntegerType,
+    MemRefType,
+    NoneType,
+    OpaqueType,
+    RankedTensorType,
+    StridedLayoutAttr,
+    StringAttr,
+    TupleType,
+    Type,
+    UnrankedMemRefType,
+    UnrankedTensorType,
+    VectorType,
+)
+
+from .dialects import transform
+from .dialects import pdl
+
+
+_index = lambda: IndexType.get()
+_bool = lambda: IntegerType.get_signless(1)
+
+_i8 = lambda: IntegerType.get_signless(8)
+_i16 = lambda: IntegerType.get_signless(16)
+_i32 = lambda: IntegerType.get_signless(32)
+_i64 = lambda: IntegerType.get_signless(64)
+
+_si8 = lambda: IntegerType.get_signed(8)
+_si16 = lambda: IntegerType.get_signed(16)
+_si32 = lambda: IntegerType.get_signed(32)
+_si64 = lambda: IntegerType.get_signed(64)
+
+_ui8 = lambda: IntegerType.get_unsigned(8)
+_ui16 = lambda: IntegerType.get_unsigned(16)
+_ui32 = lambda: IntegerType.get_unsigned(32)
+_ui64 = lambda: IntegerType.get_unsigned(64)
+
+_f16 = lambda: F16Type.get()
+_f32 = lambda: F32Type.get()
+_f64 = lambda: F64Type.get()
+_bf16 = lambda: BF16Type.get()
+
+_f8e5m2 = lambda: Float8E5M2Type.get()
+_f8e4m3 = lambda: Float8E4M3FNType.get()
+_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()
+
+_cmp16 = lambda: ComplexType.get(_f16())
+_cmp32 = lambda: ComplexType.get(_f32())
+_cmp64 = lambda: ComplexType.get(_f64())
+
+_none = lambda: NoneType.get()
+
+_pdl_operation = lambda: pdl.OperationType.get()
+
+
+def _transform_any_op():
+    return transform.AnyOpType.get()
+
+
+_name_to_type = {
+    "index": _index,
+    "bool": _bool,
+    "i8": _i8,
+    "i16": _i16,
+    "i32": _i32,
+    "i64": _i64,
+    "si8": _si8,
+    "si16": _si16,
+    "si32": _si32,
+    "si64": _si64,
+    "ui8": _ui8,
+    "ui16": _ui16,
+    "ui32": _ui32,
+    "ui64": _ui64,
+    "f16": _f16,
+    "f32": _f32,
+    "f64": _f64,
+    "bf16": _bf16,
+    "f8e5m2": _f8e5m2,
+    "f8e4m3": _f8e4m3,
+    "f8e4m3b11fnuz": _f8e4m3b11fnuz,
+    "cmp16": _cmp16,
+    "cmp32": _cmp32,
+    "cmp64": _cmp64,
+    "none": _none,
+    "pdl_operation": _pdl_operation,
+    "transform_any_op": _transform_any_op,
+}
+
+
+def __getattr__(name):
+    if name in _name_to_type:
+        return _name_to_type[name]()
+    # This delegates the lookup to default module attribute lookup
+    # (i.e., functions defined below and such).
+    return None
+
+
+def transform_op(name):
+    return transform.OperationType.get(name)
+
+
+def opaque(dialect_namespace, type_data):
+    return OpaqueType.get(dialect_namespace, type_data)
+
+
+def _shaped(*args, element_type: Type = None, type_constructor=None):
+    if type_constructor is None:
+        raise ValueError("shaped is an abstract base class - cannot be constructed.")
+    if (element_type is None and args and not isinstance(args[-1], Type)) or (
+        args and isinstance(args[-1], Type) and element_type is not None
+    ):
+        raise ValueError(
+            f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
+        )
+    if element_type is not None:
+        type = element_type
+        sizes = args
+    else:
+        type = args[-1]
+        sizes = args[:-1]
+    if sizes:
+        return type_constructor(sizes, type)
+    else:
+        return type_constructor(type)
+
+
+def vector(
+    *args,
+    element_type: Type = None,
+    scalable: Optional[list[bool]] = None,
+    scalable_dims: Optional[list[int]] = None,
+):
+    return _shaped(
+        *args,
+        element_type=element_type,
+        type_constructor=partial(
+            VectorType.get, scalable=scalable, scalable_dims=scalable_dims
+        ),
+    )
+
+
+def tensor(*args, element_type: Type = None, encoding: Optional[str] = None):
+    if encoding is not None:
+        encoding = StringAttr.get(encoding)
+    if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
+        if encoding is not None:
+            raise ValueError("UnrankedTensorType does not support encoding.")
+        return _shaped(
+            *args, element_type=element_type, type_constructor=UnrankedTensorType.get
+        )
+    else:
+        return _shaped(
+            *args,
+            element_type=element_type,
+            type_constructor=partial(RankedTensorType.get, encoding=encoding),
+        )
+
+
+def stride(strides, offset: Optional[int] = 0):
+    return StridedLayoutAttr.get(offset, strides)
+
+
+def memref(
+    *args,
+    element_type: Type = None,
+    memory_space: Optional[int] = None,
+    layout: Optional[tuple[tuple[int, ...], int]] = None,
+):
+    if memory_space is not None:
+        memory_space = Attribute.parse(str(memory_space))
+    if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
+        return _shaped(
+            *args,
+            element_type=element_type,
+            type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
+        )
+    else:
+        return _shaped(
+            *args,
+            element_type=element_type,
+            type_constructor=partial(
+                MemRefType.get, memory_space=memory_space, layout=layout
+            ),
+        )
+
+
+def tuple(*elements):
+    return TupleType.get_tuple(elements)
+
+
+def function(inputs, results):
+    return FunctionType.get(inputs, results)
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index d4fed86b4f135ee..203dcee263b958b 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -772,3 +772,95 @@ def testCustomTypeTypeCaster():
         print(t)
         # CHECK: OperationType(!transform.op<"foo.bar">)
         print(repr(t))
+
+
+# CHECK-LABEL: TEST: testTypeWrappers
+@run
+def testTypeWrappers():
+    try:
+        from mlir.types import i32
+    except RuntimeError as e:
+        assert e.args[0].startswith(
+            "An MLIR function requires a Context but none was provided"
+        )
+
+    import mlir.types as T
+    from mlir.types import vector, tensor
+
+    with Context(), Location.unknown():
+        c1 = T.cmp16
+        c2 = T.cmp32
+        assert repr(c1) == "ComplexType(complex<f16>)"
+        assert repr(c2) == "ComplexType(complex<f32>)"
+
+        vec_1 = vector(2, 3, T.f32)
+        vec_2 = vector(2, 3, 4, T.f32)
+        assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
+        assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"
+
+        m1 = T.memref(2, 3, 4, T.f64)
+        assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"
+
+        m2 = T.memref(2, 3, 4, T.f64, memory_space=1)
+        assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"
+
+        m3 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=T.stride([5, 7, 13]))
+        assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"
+
+        m4 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=T.stride([5, 7, 13], 42))
+        assert (
+            repr(m4)
+            == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
+        )
+
+        S = ShapedType.get_dynamic_size()
+
+        t = T.tensor(S, 3, S, T.f64)
+        assert repr(t) == "RankedTensorType(tensor<?x3x?xf64>)"
+        ut = tensor(T.f64)
+        assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)"
+        t = tensor(S, 3, S, element_type=T.f64)
+        assert repr(t) == "RankedTensorType(tensor<?x3x?xf64>)"
+        ut = tensor(element_type=T.f64)
+        assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)"
+
+        v = vector(3, 3, 3, T.f64)
+        assert repr(v) == "VectorType(vector<3x3x3xf64>)"
+
+        m = T.memref(S, 3, S, T.f64)
+        assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
+        um = T.memref(T.f64)
+        assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
+        m = T.memref(S, 3, S, element_type=T.f64)
+        assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
+        um = T.memref(element_type=T.f64)
+        assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
+
+        m = T.memref(S, 3, S, T.f64)
+        assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
+        um = T.memref(T.f64)
+        assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
+
+        scalable_1 = vector(2, 3, T.f32, scalable=[False, True])
+        scalable_2 = vector(2, 3, 4, T.f32, scalable=[True, False, True])
+        assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
+        assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"
+
+        scalable_3 = vector(2, 3, T.f32, scalable_dims=[1])
+        scalable_4 = vector(2, 3, 4, T.f32, scalable_dims=[0, 2])
+        assert scalable_3 == scalable_1
+        assert scalable_4 == scalable_2
+
+        opaq = T.opaque("scf", "placeholder")
+        assert repr(opaq) == "OpaqueType(!scf.placeholder)"
+
+        transfor_op = T.transform_op("foo.bar")
+        assert repr(transfor_op) == 'OperationType(!transform.op<"foo.bar">)'
+
+        tup1 = T.tuple(T.i16, T.i32, T.i64)
+        tup2 = T.tuple(T.f16, T.f32, T.f64)
+        assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
+        assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"
+
+        func = T.function((T.i16, T.i32, T.i64), (T.f16, T.f32, T.f64))
+        assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"

Member

@ftynse ftynse left a comment

This looks nice from a code readability perspective, especially if generalized to be used more pervasively.

My overall thinking here is that we want to delimit optional "sugar" from fundamental APIs, and make sure unsugared APIs are still usable. Any thoughts on how to make this clear?

I'm not sure how much of this should be treated as sugar, maybe we can just have a "types" object in dialect that contains such constructor functions?

Overall, this also needs more documentation. Both docstrings and overall design in https://mlir.llvm.org/docs/Bindings/Python/. (btw, magic value downcasting should also be described there).

@makslevental
Contributor Author

> My overall thinking here is that we want to delimit optional "sugar" from fundamental APIs, and make sure unsugared APIs are still usable. Any thoughts on how to make this clear?

Well, as the comment[1] that was the impetus for this PR suggests, the fundamental APIs prioritize "authenticity" over ergonomics and that's fine. It's useful that they more closely reflect the C APIs underneath because then the bindings partially serve as a gentle on-ramp to the rest of the codebase.

Regarding making things more clearly delineated: we could move these to a sugar namespace or some other more professional-sounding word (helpers?).

> I'm not sure how much of this should be treated as sugar, maybe we can just have a "types" object in dialect that contains such constructor functions?

Generating a whole bunch of this from the ODS is feasible; it's just a tedious, corner-case-chasing implementation.

> Overall, this also needs more documentation. Both docstrings and overall design in https://mlir.llvm.org/docs/Bindings/Python/. (btw, magic value downcasting should also be described there).

Good point. Will add in this PR.

Footnotes

  1. though I did notice you ultimately did add scalable_dims.

@makslevental makslevental force-pushed the type_wrappers branch 4 times, most recently from b0e4de5 to 7d1d4cc Compare November 14, 2023 16:55
# This module is NOT a package and so this must be None (rather than throw the RuntimeError below).
return None
try:
Context.current
Contributor Author

I think I would prefer that this just returned None instead of throwing a ValueError - that way I could check is None here instead of try -> except -> raise. And arguably properties shouldn't raise since they're supposed to emulate fields. What say you @ftynse?

Member

I'd second using None. The consensus seems to be that raising exceptions from property getters is undesirable: https://stackoverflow.com/questions/1488472/best-practices-throwing-exceptions-from-properties, https://stackoverflow.com/questions/48778945/by-design-should-a-property-getter-ever-throw-an-exception-in-python with the latter referencing PEP-8.

That being said, having a property that returns something or None depending on whether it is queried within a certain context manager also doesn't quite agree with the idea of fields. Maybe you'll give the idea of consistently using functions a second thought based on that (getter functions are recommended if their body can raise errors). I understand the aesthetic appeal of fields for type annotations, but I would rather not base all decisions exclusively on that. After all, we could have several mechanisms, one built on top of another, that serve different purposes.

For example, we could have type constructors as functions available to all clients, and an additional sugaring with properties somewhere in mlir.sugar.typing that is more specific to from_py_func. If somebody doesn't use from_py_func (I don't in my downstreams as the amount of complexity it comes with isn't compensated by the sweetness of sugaring for my taste), they don't have to pay the complexity cost for something they don't need.
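
The property-versus-getter trade-off discussed above can be sketched with plain functions. Everything here is hypothetical: `Ctx`, `current_or_none`, and `get_current` are illustrative names, not the bindings' `Context` API, and the sketch ignores thread safety.

```python
class Ctx:
    """Hypothetical context manager that tracks the active instance on a stack."""

    _stack = []

    def __enter__(self):
        Ctx._stack.append(self)
        return self

    def __exit__(self, *exc_info):
        Ctx._stack.pop()
        return False  # never swallow exceptions


def current_or_none():
    # Field-like flavor: callers can test `is None`, nothing is raised.
    return Ctx._stack[-1] if Ctx._stack else None


def get_current():
    # Getter-function flavor: raising is unsurprising for an explicit call.
    ctx = current_or_none()
    if ctx is None:
        raise RuntimeError("no active context")
    return ctx


outside = current_or_none()
with Ctx() as active:
    inside = get_current()
```

With the first flavor a caller checks `is None`; with the second the caller gets an exception at an explicit call site, which is the style the review recommends when the body can fail.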

Contributor Author

> Maybe you'll give the idea of consistently using functions a second thought based on that (getter functions are recommended if their body can raise errors)

I'm not opinionated - I just thought it was aesthetically pleasing - but "least surprising" is the practical approach so I'm happy to change it.

Member

So the most surprising thing for me in the original version was that from mlir.(sugar.)types import i32 could fail with something that is not an ImportError. :)
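
The surprise mentioned above follows from how `from module import name` resolves names: it performs an attribute lookup on the module, and only an `AttributeError` is converted into an `ImportError`. A minimal sketch, assuming a hypothetical runtime-built module `lazy_types_demo` whose hook raises `RuntimeError` in place of the real "Context required" failure:

```python
import sys
import types


def _module_getattr(name):
    # Hypothetical stand-in for a builder that needs an active context.
    if name == "i32":
        raise RuntimeError("a context is required but none was provided")
    raise AttributeError(name)


lazy = types.ModuleType("lazy_types_demo")
lazy.__getattr__ = _module_getattr
sys.modules["lazy_types_demo"] = lazy  # make the module importable by name


def try_import(attr):
    """Return the exception type name raised by `from lazy_types_demo import attr`."""
    try:
        exec(f"from lazy_types_demo import {attr}", {})
    except Exception as exc:
        return type(exc).__name__
    return None


outcome_i32 = try_import("i32")          # the RuntimeError escapes the import
outcome_missing = try_import("missing")  # AttributeError becomes ImportError
```

Only the unknown name produces an `ImportError`; the builder's own exception propagates unchanged out of the import statement.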

@makslevental makslevental force-pushed the type_wrappers branch 3 times, most recently from 837c147 to e383d52 Compare November 14, 2023 18:55
@ftynse
Member

ftynse commented Nov 21, 2023

If this can be placed, even temporarily, in a subdirectory with a clear indication that the design may evolve, I have no objection to landing this.

@makslevental
Contributor Author

makslevental commented Nov 22, 2023

> in a subdirectory with a clear indication that the design may evolve, I have no objection to landing this.

How about into runtime since there are various other type-y things in there already?

@ftynse
Member

ftynse commented Nov 22, 2023

> How about into runtime since there are various other type-y things in there already?

Those are specific to interaction with numpy. The things added here are more for IR construction, not for using objects of said type at runtime. It also doesn't convey the "this API is more likely to change than other parts" message. I'm fine with any new directory: extras/helpers/sugar/syntax/lang/use-your-imagination. Let's put it somewhere so we are not bikeshedding the name and can make progress, all I want is to ensure we have the freedom to revise the decision we take now with insufficient information.

Member

@ftynse ftynse left a comment

Nice, thanks! Let's land this, have some mileage with it and iterate if needed.

@makslevental makslevental merged commit 225648e into llvm:main Nov 27, 2023
@makslevental makslevental deleted the type_wrappers branch November 27, 2023 21:58