Skip to content

[mlir][IR] Support op interfaces in HasParent trait #91471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

This commit adds support for op interfaces to HasParent: an op interface can now be specified as a parent.

To produce useful error messages, a new helper function getInterfaceName is generated for every op interface. This is similar to getOperationName, which is generated for operations.

This commit addresses a TODO in TensorOps.td.

This commit adds support for op interfaces to `HasParent`: an op interface can now be specified as a parent.

To produce useful error messages, a new helper function `getInterfaceName` is generated for every op interface. This is similar to `getOperationName`, which is generated for operations.

This commit addresses a TODO in `TensorOps.td`.
@llvmbot
Copy link
Member

llvmbot commented May 8, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds support for op interfaces to HasParent: an op interface can now be specified as a parent.

To produce useful error messages, a new helper function getInterfaceName is generated for every op interface. This is similar to getOperationName, which is generated for operations.

This commit addresses a TODO in TensorOps.td.


Full diff: https://github.com/llvm/llvm-project/pull/91471.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+1-2)
  • (modified) mlir/include/mlir/IR/OpBase.td (+1-1)
  • (modified) mlir/include/mlir/IR/OpDefinition.h (+22-1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (-4)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+9)
  • (modified) mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (+8-2)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98..2d9f4c29f7aad 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1463,8 +1463,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
 def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
        AttrSizedOperandSegments,
        OffsetSizeAndStrideOpInterface,
-       // TODO: Cannot use an interface here atm, verify this manually for now.
-       // HasParent<"ParallelCombiningOpInterface">
+       HasParent<"ParallelCombiningOpInterface">
   ]> {
   let summary = [{
     Specify the tensor slice update of a single thread of a parent
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7866ac24c1ccb..b089e72fe8928 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -133,7 +133,7 @@ class SingleBlockImplicitTerminator<string op>
 // Op's regions don't have terminator.
 def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
 
-// Op's parent operation is the provided one.
+// Op's parent operation or op interface is the provided one.
 class HasParent<string op>
     : ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;
 
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d669099..550f04d9a373b 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1298,7 +1298,9 @@ struct HasParent {
       return op->emitOpError()
              << "expects parent op "
              << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
-             << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
+             << llvm::ArrayRef(
+                    {getOperationOrInterfaceName<ParentOpTypes>()...})
+             << "'";
     }
 
     template <typename ParentOpType =
@@ -1309,6 +1311,25 @@ struct HasParent {
       return llvm::cast<ParentOpType>(parent);
     }
   };
+
+private:
+  /// A class is an op interface if it has a `getInterfaceName` function.
+  template <typename T, typename = int>
+  struct IsInterface : std::false_type {};
+  template <typename T>
+  struct IsInterface<T, decltype((void)T::getInterfaceName(), 0)>
+      : std::true_type {};
+
+  /// Helper function that returns the name of the given operation or interface
+  /// as a string literal.
+  template <typename T>
+  static constexpr StringLiteral getOperationOrInterfaceName() {
+    if constexpr (IsInterface<T>::value) {
+      return T::getInterfaceName();
+    } else {
+      return T::getOperationName();
+    }
+  }
 };
 
 /// A trait for operations that have an attribute specifying operand segments.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7a13f7a7d1355..f45c2e4efdf58 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3455,10 +3455,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
 }
 
 LogicalResult ParallelInsertSliceOp::verify() {
-  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
-    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
-           << *(getOperation()->getParentOp());
-
   RankedTensorType expectedType;
   SliceVerificationResult result =
       verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 41b6529f64afa..4205d9c3dcd31 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -698,3 +698,12 @@ func.func @unpack_mismatch_inner_tile_size_and_output_shape(
   %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @parallel_insert_slice_out_of_context(%a: tensor<5xf32>, %b: tensor<100xf32>) {
+  // expected-error@+1 {{expects parent op 'ParallelCombiningOpInterface'}}
+  tensor.parallel_insert_slice %a into %b[0][5][1]
+      : tensor<5xf32> into tensor<100xf32>
+  return
+}
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34b..17babee913f04 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -537,7 +537,7 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
 
   // Emit the derived trait for the interface.
   os << "template <typename " << valueTemplate << ">\n";
-  os << "struct " << interface.getName() << "Trait;\n";
+  os << "struct " << interfaceName << "Trait;\n";
 
   os << "\n} // namespace detail\n";
 
@@ -548,6 +548,11 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
                       interfaceName, interfaceName, interfaceTraitsName,
                       interfaceBaseType);
 
+  // Insert function that returns the name of the interface as a string.
+  os << "  static constexpr ::llvm::StringLiteral getInterfaceName() {\n"
+     << "    return \"" << interfaceName << "\";\n"
+     << "  }\n\n";
+
   // Emit a utility wrapper trait class.
   os << llvm::formatv("  template <typename {1}>\n"
                       "  struct Trait : public detail::{0}Trait<{1}> {{};\n",
@@ -588,7 +593,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
        << "    auto* interface = getInterfaceFor(base);\n"
        << "    if (!interface)\n"
           "      return false;\n"
-          "    " << interfaceName << " odsInterfaceInstance(base, interface);\n"
+          "    "
+       << interfaceName << " odsInterfaceInstance(base, interface);\n"
        << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
        << "\n  }\n";
   }

@@ -1309,6 +1311,25 @@ struct HasParent {
return llvm::cast<ParentOpType>(parent);
}
};

private:
/// A class is an op interface if it has a `getInterfaceName` function.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ops and op interfaces inherit from OpState. If you know a better way of checking if a class is an interface, let me know. I also thought about calling the function getOperationName instead of getInterfaceName, but that could be confusing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually have logic for this already at

which is used to make DenseMap work with interfaces:
std::enable_if_t<std::is_base_of<mlir::OpState, T>::value &&

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, this was previously attempted in #66196
The review comments there seem to align with the implementation shown here

@@ -1309,6 +1311,25 @@ struct HasParent {
return llvm::cast<ParentOpType>(parent);
}
};

private:
/// A class is an op interface if it has a `getInterfaceName` function.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually have logic for this already at

which is used to make DenseMap work with interfaces:
std::enable_if_t<std::is_base_of<mlir::OpState, T>::value &&

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants