Skip to content

Commit

Permalink
[Arc] Make the canonicalizer shuffle the input vector elements before…
Browse files Browse the repository at this point in the history
… merging (#7394)
  • Loading branch information
elhewaty authored Aug 4, 2024
1 parent 17545bc commit 17e85f1
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 19 deletions.
50 changes: 45 additions & 5 deletions lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,18 +619,58 @@ MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
// Ensure that the input vector matches the output of the `otherVecOp`
// Make sure that the results of the otherVecOp have only one use
auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>();
if (!otherVecOp || inputVec != otherVecOp.getResults() ||
otherVecOp == vecOp ||
if (!otherVecOp || otherVecOp == vecOp ||
!llvm::all_of(otherVecOp.getResults(),
[](auto result) { return result.hasOneUse(); })) {
[](auto result) { return result.hasOneUse(); }) ||
!llvm::all_of(inputVec, [&](auto result) {
return result.template getDefiningOp<VectorizeOp>() == otherVecOp;
})) {
newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
continue;
}

// Here, all elements of the input vector come from the same `VectorizeOp`.
// Sort them by the index of the result they correspond to in that op.
DenseMap<Value, size_t> resultIdxMap;
for (auto [resultIdx, result] : llvm::enumerate(otherVecOp.getResults()))
resultIdxMap[result] = resultIdx;

SmallVector<Value> tempVec(inputVec.begin(), inputVec.end());
llvm::sort(tempVec, [&](Value a, Value b) {
return resultIdxMap[a] < resultIdxMap[b];
});

// Check if inputVec matches the result after sorting.
if (tempVec != SmallVector<Value>(otherVecOp.getResults().begin(),
otherVecOp.getResults().end())) {
newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
continue;
}

DenseMap<size_t, size_t> fromRealIdxToSortedIdx;
for (auto [inIdx, in] : llvm::enumerate(inputVec))
fromRealIdxToSortedIdx[inIdx] = resultIdxMap[in];

// If this flag is set, it means we changed the IR, so we cannot return
// failure.
canBeMerged = true;
newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
otherVecOp.getOperands().end());

// If the results got shuffled, then shuffle the operands before merging.
if (inputVec != otherVecOp.getResults()) {
for (auto otherVecOpInputVec : otherVecOp.getInputs()) {
// Reuse tempVec instead of declaring another vector.
tempVec = SmallVector<Value>(inputVec.size());
for (auto [realIdx, opernad] : llvm::enumerate(otherVecOpInputVec))
tempVec[realIdx] =
otherVecOpInputVec[fromRealIdxToSortedIdx[realIdx]];

newOperands.insert(newOperands.end(), tempVec.begin(), tempVec.end());
}

} else
newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
otherVecOp.getOperands().end());

auto &otherBlock = otherVecOp.getBody().front();
for (auto &otherArg : otherBlock.getArguments()) {
Expand Down
53 changes: 39 additions & 14 deletions test/Dialect/Arc/arc-canonicalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -453,22 +453,47 @@ in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {
}

// CHECK-LABEL: hw.module @Needs_Shuffle(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) {
// CHECK-NEXT: [[VEC0:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC1:%.+]]:4 = arc.vectorize ([[VEC0]]#1, [[VEC0]]#0, [[VEC0]]#2, [[VEC0]]#3), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = comb.and %arg0, %arg1 : i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%p, %n, %r, %t), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8):
// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD]], %arg2 : i8
// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg3) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC2:%.+]]:4 = arc.vectorize ([[VEC1]]#1, [[VEC1]]#0, [[VEC1]]#2, [[VEC1]]#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }

// Test input: three chained `arc.vectorize` ops in which each op's first
// input vector is made up of the previous op's results in a permuted order.
hw.module @Needs_Shuffle_2(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8,
in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1,
in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {
// First stage: lane-wise `comb.add` of (%b, %e, %h, %k) and (%c, %f, %i, %l).
%R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.add %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
// Second stage: consumes the results of %R in fully reversed order
// (#3, #2, #1, #0) and combines them lane-wise with (%n, %p, %r, %t) via `comb.and`.
%L:4 = arc.vectorize(%R#3, %R#2, %R#1, %R#0), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.and %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
// Third stage: consumes the results of %L with the first two lanes swapped
// (#1, #0, #2, #3) and passes them with (%o, %v, %q, %s) to @Just_A_Dummy_Func.
%C:4 = arc.vectorize(%L#1, %L#0, %L#2, %L#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0 : i8, %arg1: i8):
%1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
arc.vectorize.return %1692 : i8
}
// Uses lane 0 of the final vectorize op; the state also takes its own result as an operand.
%4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8
}

// CHECK-LABEL: hw.module @Needs_Shuffle_2(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) {
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%h, %k, %e, %b), (%i, %l, %f, %c), (%p, %n, %r, %t), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8):
// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD]], %arg2 : i8
// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg3) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC2]]#0, [[STATE:%.+]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }

Expand Down

0 comments on commit 17e85f1

Please sign in to comment.