Skip to content

Commit 3cdf6e7

Browse files
committed
[mlir][ODS] Fix default inferReturnTypes generation for variadic operands
1 parent 7c26407 commit 3cdf6e7

File tree

2 files changed

+30
-39
lines changed

2 files changed

+30
-39
lines changed

mlir/test/mlir-tblgen/op-result.td

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
136136

137137
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
138138
// CHECK-NOT: }
139-
// CHECK: if (operands.size() <= 0)
140-
// CHECK-NEXT: return ::mlir::failure();
141-
// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
139+
// CHECK: OpL1::Adaptor adaptor
140+
// CHECK: ::mlir::Type odsInferredType0 = adaptor.getA().getType();
142141
// CHECK: inferredReturnTypes[0] = odsInferredType0;
143142

144143
def OpL2 : NS_Op<"op_with_all_types_constraint",
@@ -149,11 +148,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
149148

150149
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
151150
// CHECK-NOT: }
152-
// CHECK: if (operands.size() <= 2)
153-
// CHECK-NEXT: return ::mlir::failure();
154-
// CHECK-NOT: if (operands.size() <= 0)
155-
// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
156-
// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
151+
// CHECK: OpL2::Adaptor adaptor
152+
// CHECK: ::mlir::Type odsInferredType0 = adaptor.getC().getType();
153+
// CHECK: ::mlir::Type odsInferredType1 = adaptor.getA().getType();
157154
// CHECK: inferredReturnTypes[0] = odsInferredType0;
158155
// CHECK: inferredReturnTypes[1] = odsInferredType1;
159156

@@ -177,9 +174,8 @@ def OpL4 : NS_Op<"two_inference_edges", [
177174
}
178175

179176
// CHECK-LABEL: LogicalResult OpL4::inferReturnTypes
180-
// CHECK: if (operands.size() <= 0)
181-
// CHECK-NEXT: return ::mlir::failure();
182-
// CHECK: odsInferredType0 = fromInput(operands[0].getType())
177+
// CHECK: OpL4::Adaptor adaptor
178+
// CHECK: odsInferredType0 = fromInput(adaptor.getInput().getType())
183179
// CHECK: odsInferredType1 = infer0(odsInferredType0)
184180
// CHECK: odsInferredType2 = infer1(odsInferredType1)
185181
// CHECK: inferredReturnTypes[0] = odsInferredType0

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
26412641

26422642
// Avoid emitting "resultTypes.size() >= 0u" which is always true.
26432643
if (!hasVariadicResult || numNonVariadicResults != 0)
2644-
body << " "
2645-
<< "assert(resultTypes.size() "
2644+
body << " " << "assert(resultTypes.size() "
26462645
<< (hasVariadicResult ? ">=" : "==") << " "
26472646
<< numNonVariadicResults
26482647
<< "u && \"mismatched number of results\");\n";
@@ -3751,29 +3750,24 @@ void OpEmitter::genTypeInterfaceMethods() {
37513750
fctx.addSubst("_ctxt", "context");
37523751
body << " ::mlir::Builder odsBuilder(context);\n";
37533752

3754-
// Preprocessing stage to verify all accesses to operands are valid.
3755-
int maxAccessedIndex = -1;
3756-
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3757-
const InferredResultType &infer = op.getInferredResultType(i);
3758-
if (!infer.isArg())
3759-
continue;
3760-
Operator::OperandOrAttribute arg =
3761-
op.getArgToOperandOrAttribute(infer.getIndex());
3762-
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3763-
maxAccessedIndex =
3764-
std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
3765-
}
3766-
}
3767-
if (maxAccessedIndex != -1) {
3768-
body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
3769-
body << " return ::mlir::failure();\n";
3770-
}
3753+
// Emit an adaptor to access right ranges for ods operands.
3754+
body << " " << op.getCppClassName()
3755+
<< "::Adaptor adaptor(operands, attributes, properties, regions);\n";
37713756

3772-
// Process the type inference graph in topological order, starting from types
3773-
// that are always fully-inferred: operands and results with constructible
3774-
// types. The type inference graph here will always be a DAG, so this gives
3775-
// us the correct order for generating the types. -1 is a placeholder to
3776-
// indicate the type for a result has not been generated.
3757+
// TODO: Ideally, we should be doing some sort of verification here. This
3758+
// is however problemetic due to 2 reasons:
3759+
//
3760+
// 1. Adaptor::verify only verifies attributes. It really should verify
3761+
// if the number of given attributes is right too.
3762+
// 2. PDL passes empty properties to inferReturnTypes, which does not verify.
3763+
// Without properties, it's not really possible to verify the number of
3764+
// operands as we do not know the variadic operand segment sizes.
3765+
3766+
// Process the type inference graph in topological order, starting from
3767+
// types that are always fully-inferred: operands and results with
3768+
// constructible types. The type inference graph here will always be a
3769+
// DAG, so this gives us the correct order for generating the types. -1 is
3770+
// a placeholder to indicate the type for a result has not been generated.
37773771
SmallVector<int> constructedIndices(op.getNumResults(), -1);
37783772
int inferredTypeIdx = 0;
37793773
for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
@@ -3788,10 +3782,11 @@ void OpEmitter::genTypeInterfaceMethods() {
37883782
Operator::OperandOrAttribute arg =
37893783
op.getArgToOperandOrAttribute(infer.getIndex());
37903784
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3791-
typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
3792-
"].getType()")
3793-
.str();
3794-
3785+
std::string getter =
3786+
"adaptor." +
3787+
op.getGetterName(
3788+
op.getOperand(arg.operandOrAttributeIndex()).name);
3789+
typeStr = (getter + "().getType()");
37953790
// If this is an attribute, index into the attribute dictionary.
37963791
} else {
37973792
auto *attr =

0 commit comments

Comments
 (0)