Skip to content

Commit

Permalink
[Substrait] Implemented folding of chains of emit ops.
Browse files Browse the repository at this point in the history
This extends the folding of `emit` ops to the case where the input is
also an `emit` op, in which case the two are fused into a single one.

Signed-off-by: Ingo Müller <ingomueller@google.com>
  • Loading branch information
ingomueller-net committed May 22, 2024
1 parent 0260aab commit 13f2ac4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
23 changes: 23 additions & 0 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,29 @@ CrossOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
}

OpFoldResult EmitOp::fold(FoldAdaptor adaptor) {
MLIRContext *context = getContext();
Type i64 = IntegerType::get(context, 64);

// If the input is also an `emit`, fold it into this op.
if (auto previousEmit = dyn_cast<EmitOp>(getInput().getDefiningOp())) {
// Compute new mapping.
ArrayAttr previousMapping = previousEmit.getMapping();
SmallVector<Attribute> newMapping;
newMapping.reserve(getMapping().size());
for (auto attr : getMapping().getAsRange<IntegerAttr>()) {
int64_t index = attr.getInt();
int64_t newIndex = cast<IntegerAttr>(previousMapping[index]).getInt();
newMapping.push_back(IntegerAttr::get(i64, newIndex));
}

// Update this op.
setMappingAttr(ArrayAttr::get(context, newMapping));
setOperand(previousEmit.getInput());
return getResult();
}

// Remainder: fold away if the mapping is the identity mapping.

// Return if the mapping is not the identity mapping.
int64_t numFields = cast<TupleType>(getInput().getType()).size();
int64_t numIndices = getMapping().size();
Expand Down
22 changes: 22 additions & 0 deletions test/Dialect/Substrait/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,25 @@ substrait.plan version 0 : 42 : 1 {
yield %1 : tuple<si1>
}
}

// -----

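// Check that chains of `emit` ops are folded into one. The mappings in this
// chain compose to the identity, so the fused op is expected to fold away too.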

// CHECK-LABEL: substrait.plan
// CHECK-NEXT: relation
// CHECK-NEXT: %[[V0:.*]] = named_table
// TODO(ingomueller): check for DCE once implemented.
// CHECK: yield %[[V0]]

// Test input: five chained `emit` ops on top of a two-field table. Writing the
// table's fields as (a: si1, b: si32), the per-op comments below track which
// fields each result carries. Rewriting the last mapping in terms of each
// earlier op's input gives [1, 0] (over %4), [0, 3] (over %3), [0, 1]
// (over %2), [1, 0] (over %1), and finally [0, 1] (over %0), which is the
// identity mapping on the table.
substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a", "b"] : tuple<si1, si32>
// Swap the two fields: (b, a).
%1 = emit [1, 0] from %0 : tuple<si1, si32> -> tuple<si32, si1>
// Swap back: (a, b).
%2 = emit [1, 0] from %1 : tuple<si32, si1> -> tuple<si1, si32>
// Duplicate each field: (a, a, b, b).
%3 = emit [0, 0, 1, 1] from %2 : tuple<si1, si32> -> tuple<si1, si1, si32, si32>
// Select and reorder a subset: (b, a, a).
%4 = emit [3, 0, 1] from %3 : tuple<si1, si1, si32, si32> -> tuple<si32, si1, si1>
// Drop the last field and swap: (a, b).
%5 = emit [1, 0] from %4 : tuple<si32, si1, si1> -> tuple<si1, si32>
yield %5 : tuple<si1, si32>
}
}

0 comments on commit 13f2ac4

Please sign in to comment.