-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][IR] Support op interfaces in HasParent
trait
#91471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][IR] Support op interfaces in HasParent
trait
#91471
Conversation
This commit adds support for op interfaces to `HasParent`: an op interface can now be specified as a parent. To produce useful error messages, a new helper function `getInterfaceName` is generated for every op interface. This is similar to `getOperationName`, which is generated for operations. This commit addresses a TODO in `TensorOps.td`.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit adds support for op interfaces to To produce useful error messages, a new helper function This commit addresses a TODO in Full diff: https://github.com/llvm/llvm-project/pull/91471.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98..2d9f4c29f7aad 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1463,8 +1463,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
- // TODO: Cannot use an interface here atm, verify this manually for now.
- // HasParent<"ParallelCombiningOpInterface">
+ HasParent<"ParallelCombiningOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread of a parent
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7866ac24c1ccb..b089e72fe8928 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -133,7 +133,7 @@ class SingleBlockImplicitTerminator<string op>
// Op's regions don't have terminator.
def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
-// Op's parent operation is the provided one.
+// Op's parent operation or op interface is the provided one.
class HasParent<string op>
: ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d669099..550f04d9a373b 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1298,7 +1298,9 @@ struct HasParent {
return op->emitOpError()
<< "expects parent op "
<< (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
- << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
+ << llvm::ArrayRef(
+ {getOperationOrInterfaceName<ParentOpTypes>()...})
+ << "'";
}
template <typename ParentOpType =
@@ -1309,6 +1311,25 @@ struct HasParent {
return llvm::cast<ParentOpType>(parent);
}
};
+
+private:
+ /// A class is an op interface if it has a `getInterfaceName` function.
+ template <typename T, typename = int>
+ struct IsInterface : std::false_type {};
+ template <typename T>
+ struct IsInterface<T, decltype((void)T::getInterfaceName(), 0)>
+ : std::true_type {};
+
+ /// Helper function that returns the name of the given operation or interface
+ /// as a string literal.
+ template <typename T>
+ static constexpr StringLiteral getOperationOrInterfaceName() {
+ if constexpr (IsInterface<T>::value) {
+ return T::getInterfaceName();
+ } else {
+ return T::getOperationName();
+ }
+ }
};
/// A trait for operations that have an attribute specifying operand segments.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7a13f7a7d1355..f45c2e4efdf58 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3455,10 +3455,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
}
LogicalResult ParallelInsertSliceOp::verify() {
- if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
- return this->emitError("expected ParallelCombiningOpInterface parent, got:")
- << *(getOperation()->getParentOp());
-
RankedTensorType expectedType;
SliceVerificationResult result =
verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 41b6529f64afa..4205d9c3dcd31 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -698,3 +698,12 @@ func.func @unpack_mismatch_inner_tile_size_and_output_shape(
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
+
+// -----
+
+func.func @parallel_insert_slice_out_of_context(%a: tensor<5xf32>, %b: tensor<100xf32>) {
+ // expected-error@+1 {{expects parent op 'ParallelCombiningOpInterface'}}
+ tensor.parallel_insert_slice %a into %b[0][5][1]
+ : tensor<5xf32> into tensor<100xf32>
+ return
+}
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34b..17babee913f04 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -537,7 +537,7 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
// Emit the derived trait for the interface.
os << "template <typename " << valueTemplate << ">\n";
- os << "struct " << interface.getName() << "Trait;\n";
+ os << "struct " << interfaceName << "Trait;\n";
os << "\n} // namespace detail\n";
@@ -548,6 +548,11 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);
+ // Insert function that returns the name of the interface as a string.
+ os << " static constexpr ::llvm::StringLiteral getInterfaceName() {\n"
+ << " return \"" << interfaceName << "\";\n"
+ << " }\n\n";
+
// Emit a utility wrapper trait class.
os << llvm::formatv(" template <typename {1}>\n"
" struct Trait : public detail::{0}Trait<{1}> {{};\n",
@@ -588,7 +593,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
<< " auto* interface = getInterfaceFor(base);\n"
<< " if (!interface)\n"
" return false;\n"
- " " << interfaceName << " odsInterfaceInstance(base, interface);\n"
+ " "
+ << interfaceName << " odsInterfaceInstance(base, interface);\n"
<< " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
<< "\n }\n";
}
|
@@ -1309,6 +1311,25 @@ struct HasParent { | |||
return llvm::cast<ParentOpType>(parent); | |||
} | |||
}; | |||
|
|||
private: | |||
/// A class is an op interface if it has a `getInterfaceName` function. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ops and op interfaces inherit from OpState
. If you know a better way of checking if a class is an interface, let me know. I also thought about calling the function getOperationName
instead of getInterfaceName
, but that could be confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We actually have logic for this already at
void isInterfaceImpl( |
DenseMap
work with interfaces: llvm-project/mlir/include/mlir/IR/OpDefinition.h
Line 2130 in 50b45b2
std::enable_if_t<std::is_base_of<mlir::OpState, T>::value && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For reference, this was previously attempted in #66196
The review comments there seem to align with the implementation shown here
@@ -1309,6 +1311,25 @@ struct HasParent { | |||
return llvm::cast<ParentOpType>(parent); | |||
} | |||
}; | |||
|
|||
private: | |||
/// A class is an op interface if it has a `getInterfaceName` function. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We actually have logic for this already at
void isInterfaceImpl( |
DenseMap
work with interfaces: llvm-project/mlir/include/mlir/IR/OpDefinition.h
Line 2130 in 50b45b2
std::enable_if_t<std::is_base_of<mlir::OpState, T>::value && |
This commit adds support for op interfaces to
HasParent
: an op interface can now be specified as a parent.To produce useful error messages, a new helper function
getInterfaceName
is generated for every op interface. This is similar togetOperationName
, which is generated for operations.This commit addresses a TODO in
TensorOps.td
.