Skip to content

Commit

Permalink
[Python][OM] Handle BoolAttr's before IntegerAttr's. (#7438)
Browse files Browse the repository at this point in the history
BoolAttr's are IntegerAttr's, check them first.

IntegerAttr's that happen to have the characteristics of
BoolAttr will accordingly become Python boolean values.

Unclear where these come from but we do lower booleans
to MLIR bool constants so make sure to handle that.

Add test for object model IR with bool constants.
  • Loading branch information
dtzSiFive authored Aug 5, 2024
1 parent cbdee94 commit bec0dea
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
12 changes: 11 additions & 1 deletion integration_test/Bindings/Python/dialects/om.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@
%map = om.map_create %entry1, %entry2: !om.string, !om.integer
om.class.field @map_create, %map : !om.map<!om.string, !om.integer>
%true = om.constant true
om.class.field @true, %true : i1
%false = om.constant false
om.class.field @false, %false : i1
}
om.class @Child(%0: !om.integer) {
Expand Down Expand Up @@ -157,7 +162,7 @@

# CHECK: 14
print(obj.child.foo)
# CHECK: loc("-":60:7)
# CHECK: loc("-":65:7)
print(obj.child.get_field_loc("foo"))
# CHECK: ('Root', 'x')
print(obj.reference)
Expand Down Expand Up @@ -224,6 +229,11 @@
# CHECK-NEXT: Y 15
print(k, v)

# CHECK: True
print(obj.true)
# CHECK: False
print(obj.false)

obj = evaluator.instantiate("Client")
object_dict: Dict[om.Object, str] = {}
for field_name, data in obj:
Expand Down
19 changes: 10 additions & 9 deletions lib/Bindings/Python/OMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,6 @@ Map::dunderGetItem(std::variant<intptr_t, std::string, MlirAttribute> key) {
// Convert a generic MLIR Attribute to a PythonValue. This is basically a C++
// fast path of the parts of attribute_to_var that we use in the OM dialect.
static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
if (mlirAttributeIsAInteger(attr)) {
MlirType type = mlirAttributeGetType(attr);
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
return py::int_(mlirIntegerAttrGetValueInt(attr));
if (mlirIntegerTypeIsSigned(type))
return py::int_(mlirIntegerAttrGetValueSInt(attr));
return py::int_(mlirIntegerAttrGetValueUInt(attr));
}

if (omAttrIsAIntegerAttr(attr)) {
auto strRef = omIntegerAttrToString(attr);
return py::int_(py::str(strRef.data, strRef.length));
Expand All @@ -389,10 +380,20 @@ static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
return py::str(strRef.data, strRef.length);
}

// BoolAttr's are IntegerAttr's, check this first.
if (mlirAttributeIsABool(attr)) {
return py::bool_(mlirBoolAttrGetValue(attr));
}

if (mlirAttributeIsAInteger(attr)) {
MlirType type = mlirAttributeGetType(attr);
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
return py::int_(mlirIntegerAttrGetValueInt(attr));
if (mlirIntegerTypeIsSigned(type))
return py::int_(mlirIntegerAttrGetValueSInt(attr));
return py::int_(mlirIntegerAttrGetValueUInt(attr));
}

if (omAttrIsAReferenceAttr(attr)) {
auto innerRef = omReferenceAttrGetInnerRef(attr);
auto moduleStrRef =
Expand Down

0 comments on commit bec0dea

Please sign in to comment.