Skip to content

[mlir] [tblgen-to-irdl] Add types to tblgen-to-irdl script #108558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 17, 2024

Conversation

alexarice
Copy link
Contributor

Adds dialect types to the tblgen-to-irdl script and also allows operations to refer to types by symbol, when possible, and updates tests to do this.

The name of the type is exported with an exclamation mark to avoid name clashes.

@math-fehr

Also allows operations to refer to types by symbol, when possible, and
updates tests to do this.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 13, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 13, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Alex Rice (alexarice)

Changes

Adds dialect types to the tblgen-to-irdl script and also allows operations to refer to types by symbol, when possible, and updates tests to do this.

The name of the type is exported with an exclamation mark to avoid name clashes.

@math-fehr


Full diff: https://github.com/llvm/llvm-project/pull/108558.diff

3 Files Affected:

  • (modified) mlir/test/tblgen-to-irdl/CMathDialect.td (+7-5)
  • (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+10-7)
  • (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+36)
diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td
index 454543e074c489..abda7ca41e9db9 100644
--- a/mlir/test/tblgen-to-irdl/CMathDialect.td
+++ b/mlir/test/tblgen-to-irdl/CMathDialect.td
@@ -19,12 +19,14 @@ class CMath_Op<string mnemonic, list<Trait> traits = []>
 def f32Orf64Type : Or<[CPred<"::llvm::isa<::mlir::F32>">,
                        CPred<"::llvm::isa<::mlir::F64>">]>;
 
+// CHECK: irdl.type @"!complex"
 def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
   let parameters = (ins f32Orf64Type:$elementType);
+  let assemblyFormat = "`<` $elementType `>`";
 }
 
 // CHECK:      irdl.operation @identity {
-// CHECK-NEXT:   %0 = irdl.base "!cmath.complex"
+// CHECK-NEXT:   %0 = irdl.base @cmath::@"!complex"
 // CHECK-NEXT:   irdl.results(%0)
 // CHECK-NEXT: }
 def CMath_IdentityOp : CMath_Op<"identity"> {
@@ -32,9 +34,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> {
 }
 
 // CHECK:      irdl.operation @mul {
-// CHECK-NEXT:   %0 = irdl.base "!cmath.complex"
-// CHECK-NEXT:   %1 = irdl.base "!cmath.complex"
-// CHECK-NEXT:   %2 = irdl.base "!cmath.complex"
+// CHECK-NEXT:   %0 = irdl.base @cmath::@"!complex"
+// CHECK-NEXT:   %1 = irdl.base @cmath::@"!complex"
+// CHECK-NEXT:   %2 = irdl.base @cmath::@"!complex"
 // CHECK-NEXT:   irdl.operands(%0, %1)
 // CHECK-NEXT:   irdl.results(%2)
 // CHECK-NEXT: }
@@ -45,7 +47,7 @@ def CMath_MulOp : CMath_Op<"mul"> {
 
 // CHECK:      irdl.operation @norm {
 // CHECK-NEXT:   %0 = irdl.any
-// CHECK-NEXT:   %1 = irdl.base "!cmath.complex"
+// CHECK-NEXT:   %1 = irdl.base @cmath::@"!complex"
 // CHECK-NEXT:   irdl.operands(%0)
 // CHECK-NEXT:   irdl.results(%1)
 // CHECK-NEXT: }
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 2622c817760762..4fea3d8576e9ab 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -16,8 +16,11 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
 class Test_Op<string mnemonic, list<Trait> traits = []>
     : Op<Test_Dialect, mnemonic, traits>;
 
+// CHECK: irdl.type @"!singleton_a"
 def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
+// CHECK: irdl.type @"!singleton_b"
 def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
+// CHECK: irdl.type @"!singleton_c"
 def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
 
 
@@ -26,7 +29,7 @@ def Test_AndOp : Test_Op<"and"> {
   let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
 }
 // CHECK-LABEL: irdl.operation @and {
-// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base @test::@"!singleton_a"
 // CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.any
 // CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
 // CHECK-NEXT:    irdl.operands(%[[v2]])
@@ -79,9 +82,9 @@ def Test_OrOp : Test_Op<"or"> {
   let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
 }
 // CHECK-LABEL: irdl.operation @or {
-// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
-// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
-// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base @test::@"!singleton_a"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base @test::@"!singleton_b"
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base @test::@"!singleton_c"
 // CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
 // CHECK-NEXT:    irdl.operands(%[[v3]])
 // CHECK-NEXT:  }
@@ -114,8 +117,8 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
                        Test_SingletonCType:$required);
 }
 // CHECK-LABEL: irdl.operation @variadicity {
-// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
-// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
-// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base @test::@"!singleton_a"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base @test::@"!singleton_b"
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base @test::@"!singleton_c"
 // CHECK-NEXT:    irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
 // CHECK-NEXT:  }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index dd0d98de496e86..45957bafc378e3 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -177,6 +177,15 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   }
 
   if (predRec.isSubClassOf("TypeDef")) {
+    auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+    if (dialect == selectedDialect) {
+      std::string combined = ("!" + predRec.getValueAsString("mnemonic")).str();
+      SmallVector<FlatSymbolRefAttr> nested = {
+          SymbolRefAttr::get(ctx, combined)};
+      auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+      auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
+      return op.getOutput();
+    }
     std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
     auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
                                            StringAttr::get(ctx, typeName));
@@ -250,6 +259,12 @@ static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
   return opName;
 }
 
+/// Returns the name of the type without the dialect prefix.
+static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
+  StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
+  return opName;
+}
+
 /// Extract an operation to IRDL.
 irdl::OperationOp createIRDLOperation(OpBuilder &builder,
                                       tblgen::Operator &tblgenOp) {
@@ -300,6 +315,19 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   return op;
 }
 
+irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
+  MLIRContext *ctx = builder.getContext();
+  StringRef typeName = getTypeName(tblgenType);
+  std::string combined = ("!" + typeName).str();
+
+  irdl::TypeOp op = builder.create<irdl::TypeOp>(
+      UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+
+  op.getBody().emplaceBlock();
+
+  return op;
+}
+
 static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -322,6 +350,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
   // Set insertion point to start of DialectOp.
   builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock());
 
+  for (const Record *type :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("TypeDef")) {
+    tblgen::TypeDef tblgenType(type);
+    if (tblgenType.getDialect().getName() != selectedDialect)
+      continue;
+    createIRDLType(builder, tblgenType);
+  }
+
   for (const Record *def :
        recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
     tblgen::Operator tblgenOp(def);

@math-fehr math-fehr self-requested a review September 15, 2024 14:59
Copy link
Contributor

@math-fehr math-fehr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect!
The only issue is that types names are prefixed by ! here, while we should not do that.

@@ -16,8 +16,11 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;

// CHECK: irdl.type @"!singleton_a"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name should just be @singleton_a here.
The reason we had a ! in some places was only in string names (compared to references), to distinguish types and attributes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can still get name clashes in the symbol names. There are dialects which have operations, types, and attributes share name (I think stablehlo is an example)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait you're right yes!

@math-fehr math-fehr merged commit 04575dc into llvm:main Sep 17, 2024
11 checks passed
hamphet pushed a commit to hamphet/llvm-project that referenced this pull request Sep 18, 2024
Adds dialect types to the tblgen-to-irdl script and also allows
operations to refer to types by symbol, when possible, and updates tests
to do this.

The name of the type is exported with an exclamation mark to avoid name
clashes.
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
Adds dialect types to the tblgen-to-irdl script and also allows
operations to refer to types by symbol, when possible, and updates tests
to do this.

The name of the type is exported with an exclamation mark to avoid name
clashes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants