Skip to content

[mlir][python] add type wrappers #71218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2538,8 +2538,8 @@ void mlir::python::populateIRCore(py::module &m) {
[](py::object & /*class*/) {
auto *context = PyThreadContextEntry::getDefaultContext();
if (!context)
throw py::value_error("No current Context");
return context;
return py::none().cast<py::object>();
return py::cast(context);
},
"Gets the Context bound to the current thread or raises ValueError")
.def_property_readonly(
Expand Down
24 changes: 8 additions & 16 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {

static void bindDerived(ClassTy &c) {
c.def_static("get", &PyVectorType::get, py::arg("shape"),
py::arg("elementType"), py::kw_only(),
py::arg("element_type"), py::kw_only(),
py::arg("scalable") = py::none(),
py::arg("scalable_dims") = py::none(),
py::arg("loc") = py::none(), "Create a vector type")
Expand Down Expand Up @@ -689,13 +689,9 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get_tuple",
[](py::list elementList, DefaultingPyMlirContext context) {
intptr_t num = py::len(elementList);
// Mapping py::list to SmallVector.
SmallVector<MlirType, 4> elements;
for (auto element : elementList)
elements.push_back(element.cast<PyType>());
MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
[](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
elements.data());
return PyTupleType(context->getRef(), t);
},
py::arg("elements"), py::arg("context") = py::none(),
Expand Down Expand Up @@ -727,13 +723,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](std::vector<PyType> inputs, std::vector<PyType> results,
[](std::vector<MlirType> inputs, std::vector<MlirType> results,
DefaultingPyMlirContext context) {
SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
inputsRaw.data(), resultsRaw.size(),
resultsRaw.data());
MlirType t =
mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
Expand All @@ -742,7 +736,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
"inputs",
[](PyFunctionType &self) {
MlirType t = self;
auto contextRef = self.getContext();
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
Expand All @@ -754,7 +747,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
c.def_property_readonly(
"results",
[](PyFunctionType &self) {
auto contextRef = self.getContext();
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
Expand Down
1 change: 1 addition & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
extras/types.py
dialects/_ods_common.py

# The main _mlir module has submodules: include stubs from each.
Expand Down
Empty file.
165 changes: 165 additions & 0 deletions mlir/python/mlir/extras/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from functools import partial
from typing import Optional, List

from ..ir import (
Attribute,
BF16Type,
ComplexType,
F16Type,
F32Type,
F64Type,
Float8E4M3B11FNUZType,
Float8E4M3FNType,
Float8E5M2Type,
FunctionType,
IndexType,
IntegerType,
MemRefType,
NoneType,
OpaqueType,
RankedTensorType,
StridedLayoutAttr,
StringAttr,
TupleType,
Type,
UnrankedMemRefType,
UnrankedTensorType,
VectorType,
)

index = lambda: IndexType.get()


def i(width):
return IntegerType.get_signless(width)


def si(width):
return IntegerType.get_signed(width)


def ui(width):
return IntegerType.get_unsigned(width)


bool = lambda: i(1)
i8 = lambda: i(8)
i16 = lambda: i(16)
i32 = lambda: i(32)
i64 = lambda: i(64)

si8 = lambda: si(8)
si16 = lambda: si(16)
si32 = lambda: si(32)
si64 = lambda: si(64)

ui8 = lambda: ui(8)
ui16 = lambda: ui(16)
ui32 = lambda: ui(32)
ui64 = lambda: ui(64)

f16 = lambda: F16Type.get()
f32 = lambda: F32Type.get()
f64 = lambda: F64Type.get()
bf16 = lambda: BF16Type.get()

f8E5M2 = lambda: Float8E5M2Type.get()
f8E4M3 = lambda: Float8E4M3FNType.get()
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()

none = lambda: NoneType.get()


def complex(type):
return ComplexType.get(type)


def opaque(dialect_namespace, type_data):
return OpaqueType.get(dialect_namespace, type_data)


def _shaped(*shape, element_type: Type = None, type_constructor=None):
if type_constructor is None:
raise ValueError("shaped is an abstract base class - cannot be constructed.")
if (element_type is None and shape and not isinstance(shape[-1], Type)) or (
shape and isinstance(shape[-1], Type) and element_type is not None
):
raise ValueError(
f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
)
if element_type is not None:
type = element_type
sizes = shape
else:
type = shape[-1]
sizes = shape[:-1]
if sizes:
return type_constructor(sizes, type)
else:
return type_constructor(type)


def vector(
*shape,
element_type: Type = None,
scalable: Optional[List[bool]] = None,
scalable_dims: Optional[List[int]] = None,
):
return _shaped(
*shape,
element_type=element_type,
type_constructor=partial(
VectorType.get, scalable=scalable, scalable_dims=scalable_dims
),
)


def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
if encoding is not None:
encoding = StringAttr.get(encoding)
if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
if encoding is not None:
raise ValueError("UnrankedTensorType does not support encoding.")
return _shaped(
*shape, element_type=element_type, type_constructor=UnrankedTensorType.get
)
return _shaped(
*shape,
element_type=element_type,
type_constructor=partial(RankedTensorType.get, encoding=encoding),
)


def memref(
*shape,
element_type: Type = None,
memory_space: Optional[int] = None,
layout: Optional[StridedLayoutAttr] = None,
):
if memory_space is not None:
memory_space = Attribute.parse(str(memory_space))
if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
return _shaped(
*shape,
element_type=element_type,
type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
)
return _shaped(
*shape,
element_type=element_type,
type_constructor=partial(
MemRefType.get, memory_space=memory_space, layout=layout
),
)


def tuple(*elements):
return TupleType.get_tuple(elements)


def function(*, inputs, results):
return FunctionType.get(inputs, results)
99 changes: 99 additions & 0 deletions mlir/test/python/ir/builtin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import gc
from mlir.ir import *
from mlir.dialects import arith, tensor, func, memref
import mlir.extras.types as T


def run(f):
Expand Down Expand Up @@ -772,3 +773,101 @@ def testCustomTypeTypeCaster():
print(t)
# CHECK: OperationType(!transform.op<"foo.bar">)
print(repr(t))


# CHECK-LABEL: TEST: testTypeWrappers
@run
def testTypeWrappers():
def stride(strides, offset=0):
return StridedLayoutAttr.get(offset, strides)

with Context(), Location.unknown():
ia = T.i(5)
sia = T.si(6)
uia = T.ui(7)
assert repr(ia) == "IntegerType(i5)"
assert repr(sia) == "IntegerType(si6)"
assert repr(uia) == "IntegerType(ui7)"

assert T.i(16) == T.i16()
assert T.si(16) == T.si16()
assert T.ui(16) == T.ui16()

c1 = T.complex(T.f16())
c2 = T.complex(T.i32())
assert repr(c1) == "ComplexType(complex<f16>)"
assert repr(c2) == "ComplexType(complex<i32>)"

vec_1 = T.vector(2, 3, T.f32())
vec_2 = T.vector(2, 3, 4, T.f32())
assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"

m1 = T.memref(2, 3, 4, T.f64())
assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"

m2 = T.memref(2, 3, 4, T.f64(), memory_space=1)
assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"

m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13]))
assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"

m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42))
assert (
repr(m4)
== "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
)

S = ShapedType.get_dynamic_size()

t1 = T.tensor(S, 3, S, T.f64())
assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
ut1 = T.tensor(T.f64())
assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
t2 = T.tensor(S, 3, S, element_type=T.f64())
assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
ut2 = T.tensor(element_type=T.f64())
assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"

t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding")
assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'

v = T.vector(3, 3, 3, T.f64())
assert repr(v) == "VectorType(vector<3x3x3xf64>)"

m5 = T.memref(S, 3, S, T.f64())
assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
um1 = T.memref(T.f64())
assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
m6 = T.memref(S, 3, S, element_type=T.f64())
assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
um2 = T.memref(element_type=T.f64())
assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"

m7 = T.memref(S, 3, S, T.f64())
assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
um3 = T.memref(T.f64())
assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"

scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True])
scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True])
assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"

scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1])
scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2])
assert scalable_3 == scalable_1
assert scalable_4 == scalable_2

opaq = T.opaque("scf", "placeholder")
assert repr(opaq) == "OpaqueType(!scf.placeholder)"

tup1 = T.tuple(T.i16(), T.i32(), T.i64())
tup2 = T.tuple(T.f16(), T.f32(), T.f64())
assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"

func = T.function(
inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64())
)
assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"
8 changes: 1 addition & 7 deletions mlir/test/python/ir/context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,7 @@ def run(f):
def testContextEnterExit():
with Context() as ctx:
assert Context.current is ctx
try:
_ = Context.current
except ValueError as e:
# CHECK: No current Context
print(e)
else:
assert False, "Expected exception"
assert Context.current is None


run(testContextEnterExit)
Expand Down