Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OM] Pass Python values back and forth, not Attributes. #7417

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 119 additions & 20 deletions lib/Bindings/Python/OMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "circt-c/Dialect/HW.h"
#include "circt-c/Dialect/OM.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
Expand All @@ -29,18 +30,27 @@ struct Map;
struct BasePath;
struct Path;

/// None is used to by pybind when default initializing a PythonValue. The order
/// of types in the variant matters here, and we want pybind to try casting to
/// the Python classes defined in this file first, before MlirAttribute and the
/// upstream MLIR type casters. If the MlirAttribute is tried first, then we
/// can hit an assert inside the MLIR codebase.
/// These are the Python types that are represented by the different primitive
/// OMEvaluatorValues as Attributes.
using PythonPrimitive = std::variant<py::int_, py::float_, py::str, py::bool_,
py::tuple, py::list, py::dict>;

/// None is used to by pybind when default initializing a PythonValue. The
/// order of types in the variant matters here, and we want pybind to try
/// casting to the Python classes defined in this file first, before
/// MlirAttribute and the upstream MLIR type casters. If the MlirAttribute
/// is tried first, then we can hit an assert inside the MLIR codebase.
struct None {};
using PythonValue =
std::variant<None, Object, List, Tuple, Map, BasePath, Path, MlirAttribute>;
using PythonValue = std::variant<None, Object, List, Tuple, Map, BasePath, Path,
PythonPrimitive>;

/// Map an opaque OMEvaluatorValue into a python value.
PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result);
OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result);
OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result,
MlirContext ctx);
static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr);
static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
MlirContext ctx);

/// Provides a List class by simply wrapping the OMObject CAPI.
struct List {
Expand Down Expand Up @@ -79,13 +89,15 @@ struct Map {
Map(OMEvaluatorValue value) : value(value) {}

/// Return the keys.
std::vector<MlirAttribute> getKeys() {
std::vector<py::str> getKeys() {
auto attr = omEvaluatorMapGetKeys(value);
intptr_t numFieldNames = mlirArrayAttrGetNumElements(attr);

std::vector<MlirAttribute> pyFieldNames;
for (intptr_t i = 0; i < numFieldNames; ++i)
pyFieldNames.emplace_back(mlirArrayAttrGetElement(attr, i));
std::vector<py::str> pyFieldNames;
for (intptr_t i = 0; i < numFieldNames; ++i) {
auto name = mlirStringAttrGetValue(mlirArrayAttrGetElement(attr, i));
pyFieldNames.emplace_back(py::str(name.data, name.length));
}

return pyFieldNames;
}
Expand Down Expand Up @@ -224,7 +236,8 @@ struct Evaluator {
std::vector<PythonValue> actualParams) {
std::vector<OMEvaluatorValue> values;
for (auto &param : actualParams)
values.push_back(pythonValueToOMEvaluatorValue(param));
values.push_back(pythonValueToOMEvaluatorValue(
param, mlirModuleGetContext(getModule())));

// Instantiate the Object via the CAPI.
OMEvaluatorValue result = omEvaluatorInstantiate(
Expand Down Expand Up @@ -288,7 +301,8 @@ class PyMapAttrIterator {
throw py::stop_iteration();

MlirIdentifier key = omMapAttrGetElementKey(attr, nextIndex);
MlirAttribute value = omMapAttrGetElementValue(attr, nextIndex);
PythonValue value =
omPrimitiveToPythonValue(omMapAttrGetElementValue(attr, nextIndex));
nextIndex++;

auto keyName = mlirIdentifierStr(key);
Expand Down Expand Up @@ -349,6 +363,88 @@ Map::dunderGetItem(std::variant<intptr_t, std::string, MlirAttribute> key) {
return dunderGetItemAttr(std::get<MlirAttribute>(key));
}

// Convert a generic MLIR Attribute to a PythonValue. This is basically a C++
// fast path of the parts of attribute_to_var that we use in the OM dialect.
static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
if (omAttrIsAIntegerAttr(attr)) {
auto strRef = omIntegerAttrToString(attr);
return py::int_(py::str(strRef.data, strRef.length));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason, we need to go from int to str to py::int_ ?
(Not super familiar with the pybind API)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's to handle ints larger than 64 bits. This is the pybind way of saying this same Python code:

return int(str(om.OMIntegerAttr(attr)))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was some discussion previously, but this is our workaround for now: llvm/llvm-project#84190

}

if (mlirAttributeIsAFloat(attr)) {
return py::float_(mlirFloatAttrGetValueDouble(attr));
}

if (mlirAttributeIsAString(attr)) {
auto strRef = mlirStringAttrGetValue(attr);
return py::str(strRef.data, strRef.length);
}

if (mlirAttributeIsABool(attr)) {
return py::bool_(mlirBoolAttrGetValue(attr));
}

if (omAttrIsAReferenceAttr(attr)) {
auto innerRef = omReferenceAttrGetInnerRef(attr);
auto moduleStrRef =
mlirStringAttrGetValue(hwInnerRefAttrGetModule(innerRef));
auto nameStrRef = mlirStringAttrGetValue(hwInnerRefAttrGetName(innerRef));
auto moduleStr = py::str(moduleStrRef.data, moduleStrRef.length);
auto nameStr = py::str(nameStrRef.data, nameStrRef.length);
return py::make_tuple(moduleStr, nameStr);
}

if (omAttrIsAListAttr(attr)) {
py::list results;
for (intptr_t i = 0, e = omListAttrGetNumElements(attr); i < e; ++i)
results.append(omPrimitiveToPythonValue(omListAttrGetElement(attr, i)));
return results;
}

if (omAttrIsAMapAttr(attr)) {
py::dict results;
for (intptr_t i = 0, e = omMapAttrGetNumElements(attr); i < e; ++i) {
auto keyStrRef = mlirIdentifierStr(omMapAttrGetElementKey(attr, i));
auto key = py::str(keyStrRef.data, keyStrRef.length);
auto value = omPrimitiveToPythonValue(omMapAttrGetElementValue(attr, i));
results[key] = value;
}
return results;
}

mlirAttributeDump(attr);
throw py::type_error("Unexpected OM primitive attribute");
}

// Convert a primitive PythonValue to a generic MLIR Attribute. This is
// basically a C++ fast path of the parts of var_to_attribute that we use in the
// OM dialect.
static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
MlirContext ctx) {
if (auto *intValue = std::get_if<py::int_>(&value)) {
auto intType = mlirIntegerTypeGet(ctx, 64);
auto intAttr = mlirIntegerAttrGet(intType, intValue->cast<int64_t>());
return omIntegerAttrGet(intAttr);
}

if (auto *attr = std::get_if<py::float_>(&value)) {
auto floatType = mlirF64TypeGet(ctx);
return mlirFloatAttrDoubleGet(ctx, floatType, attr->cast<double>());
}

if (auto *attr = std::get_if<py::str>(&value)) {
auto str = attr->cast<std::string>();
auto strRef = mlirStringRefCreate(str.data(), str.length());
return mlirStringAttrGet(ctx, strRef);
}

if (auto *attr = std::get_if<py::bool_>(&value)) {
return mlirBoolAttrGet(ctx, attr->cast<bool>());
}

throw py::type_error("Unexpected OM primitive value");
}

PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result) {
// If the result is null, something failed. Diagnostic handling is
// implemented in pure Python, so nothing to do here besides throwing an
Expand Down Expand Up @@ -386,13 +482,11 @@ PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result) {

// If the field was a primitive, return the Attribute.
assert(omEvaluatorValueIsAPrimitive(result));
return omEvaluatorValueGetPrimitive(result);
return omPrimitiveToPythonValue(omEvaluatorValueGetPrimitive(result));
}

OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result) {
if (auto *attr = std::get_if<MlirAttribute>(&result))
return omEvaluatorValueFromPrimitive(*attr);

OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result,
MlirContext ctx) {
if (auto *list = std::get_if<List>(&result))
return list->getValue();

Expand All @@ -408,7 +502,12 @@ OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result) {
if (auto *path = std::get_if<Path>(&result))
return path->getValue();

return std::get<Object>(result).getValue();
if (auto *object = std::get_if<Object>(&result))
return object->getValue();

auto primitive = std::get<PythonPrimitive>(result);
return omEvaluatorValueFromPrimitive(
omPythonValueToPrimitive(primitive, ctx));
}

} // namespace
Expand Down
21 changes: 9 additions & 12 deletions lib/Bindings/Python/dialects/om.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

# Wrap a base mlir object with high-level object.
def wrap_mlir_object(value):
# For primitives, return a Python value.
if isinstance(value, Attribute):
return attribute_to_var(value)
# For primitives, return the Python value directly.
if isinstance(value, (int, float, str, bool, tuple, list, dict)):
return value

if isinstance(value, BaseList):
return List(value)
Expand Down Expand Up @@ -52,12 +52,7 @@ def om_var_to_attribute(obj, none_on_fail: bool = False) -> ir.Attrbute:


def unwrap_python_object(value):
# Check if the value is a Primitive.
try:
return om_var_to_attribute(value)
except:
pass

# Check if the value is any of our container or custom types.
if isinstance(value, List):
return BaseList(value)

Expand All @@ -73,9 +68,11 @@ def unwrap_python_object(value):
if isinstance(value, Path):
return value

# Otherwise, it must be an Object. Cast to the mlir object.
assert isinstance(value, Object)
return BaseObject(value)
if isinstance(value, Object):
return BaseObject(value)

# Otherwise, it must be a primitive, so just return it.
return value


class List(BaseList):
Expand Down
Loading