
Commit eada1e3

PadConcat to (MaybePad)ConcatPad (EnzymeAD#639)

* initial commit
* it compiles
* used untyped constructor for pad
* bugfix + cleanup
* added test
* fix test
* add checks
* flippin bits
* keep clang-format happy
1 parent ded5e19 commit eada1e3
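At a high level, the new PadConcatToConcatPad pattern factors padding that is common to every operand of a stablehlo.concatenate out of the concatenation, so the pad is applied once to the concatenated result rather than once per operand; any leftover per-operand padding stays attached to that operand. A minimal sketch of the rewrite in StableHLO pretty-printed syntax (the values %a, %b, %zero and the small shapes are hypothetical, chosen only for illustration and not taken from the commit; the committed behavior is exercised by the lit test below):

// Before: each concatenate operand is padded separately with the same padding value.
// %a, %b : tensor<4x8xf64>, %zero : tensor<f64> (placeholder values)
%p0 = stablehlo.pad %a, %zero, low = [1, 0], high = [0, 0], interior = [0, 0] : (tensor<4x8xf64>, tensor<f64>) -> tensor<5x8xf64>
%p1 = stablehlo.pad %b, %zero, low = [1, 0], high = [0, 0], interior = [0, 0] : (tensor<4x8xf64>, tensor<f64>) -> tensor<5x8xf64>
%c = stablehlo.concatenate %p0, %p1, dim = 1 : (tensor<5x8xf64>, tensor<5x8xf64>) -> tensor<5x16xf64>

// After: concatenate the unpadded operands, then apply the shared padding once.
%cc = stablehlo.concatenate %a, %b, dim = 1 : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x16xf64>
%r = stablehlo.pad %cc, %zero, low = [1, 0], high = [0, 0], interior = [0, 0] : (tensor<4x16xf64>, tensor<f64>) -> tensor<5x16xf64>

When the per-operand paddings differ, only the element-wise minimum of the low/high edge padding is hoisted and the remainder is re-applied to the individual operand, as in the @test_pad_leftover case in the new test file.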

File tree

3 files changed: +158 -4 lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 116 additions & 4 deletions
@@ -12440,6 +12440,119 @@ struct BroadcastInDimIsReshape final
   }
 };
 
+struct PadConcatToConcatPad
+    : public OpRewritePattern<stablehlo::ConcatenateOp> {
+  using OpRewritePattern<stablehlo::ConcatenateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::ConcatenateOp concatOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (concatOp.getNumOperands() <= 1) {
+      return failure();
+    }
+
+    // Check if all operands are pad ops with the same padding value
+    SmallVector<stablehlo::PadOp> padOps;
+    Value padValue;
+
+    for (Value operand : concatOp.getOperands()) {
+      auto padOp = operand.getDefiningOp<stablehlo::PadOp>();
+      if (!padOp)
+        return failure();
+
+      if (padOps.empty()) {
+        padValue = padOp.getPaddingValue();
+      } else if (padValue != padOp.getPaddingValue()) {
+        return failure(); // Different padding values not supported
+      }
+
+      padOps.push_back(padOp);
+    }
+
+    int64_t concatDim = concatOp.getDimension();
+    int64_t rank = padOps[0].getEdgePaddingLow().size();
+
+    // Compute smallest common padding for all tensors
+    SmallVector<int64_t> commonLowPadding(rank,
+                                          std::numeric_limits<int64_t>::max());
+    SmallVector<int64_t> commonHighPadding(rank,
+                                           std::numeric_limits<int64_t>::max());
+    SmallVector<int64_t> interiorPadding(rank, 0);
+
+    // Find minimum padding across all inputs (conservative common padding)
+    for (auto padOp : padOps) {
+      for (int64_t dim = 0; dim < rank; ++dim) {
+        commonLowPadding[dim] =
+            std::min(commonLowPadding[dim], padOp.getEdgePaddingLow()[dim]);
+        commonHighPadding[dim] =
+            std::min(commonHighPadding[dim], padOp.getEdgePaddingHigh()[dim]);
+      }
+    }
+
+    bool commonPad = false;
+
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      if (commonLowPadding[dim] != 0 || commonHighPadding[dim] != 0) {
+        commonPad = true;
+        break;
+      }
+    }
+
+    if (!commonPad) {
+      return failure();
+    }
+
+    // Collect original operands with adjusted padding
+    SmallVector<Value> adjOperands;
+
+    for (auto padOp : padOps) {
+
+      SmallVector<int64_t> diffLowPadding(rank);
+      SmallVector<int64_t> diffHighPadding(rank);
+
+      for (int64_t dim = 0; dim < rank; ++dim) {
+        diffLowPadding[dim] =
+            padOp.getEdgePaddingLow()[dim] - commonLowPadding[dim];
+        diffHighPadding[dim] =
+            padOp.getEdgePaddingHigh()[dim] - commonHighPadding[dim];
+      }
+
+      bool needsExtraPad = false;
+      for (int64_t dim = 0; dim < rank; ++dim) {
+        if (diffLowPadding[dim] > 0 || diffHighPadding[dim] > 0) {
+          needsExtraPad = true;
+          break;
+        }
+      }
+
+      if (needsExtraPad) {
+
+        auto adjustedOp = rewriter.create<stablehlo::PadOp>(
+            padOp.getLoc(),
+            padOp.getOperand(), // we pad the input operand
+            padOp.getPaddingValue(), diffLowPadding, diffHighPadding,
+            padOp.getInteriorPaddingAttr());
+
+        adjOperands.push_back(adjustedOp);
+      } else {
+        // No extra padding needed, use original tensor
+        adjOperands.push_back(padOp.getOperand());
+      }
+    }
+
+    auto newConcatOp = rewriter.create<stablehlo::ConcatenateOp>(
+        concatOp.getLoc(), adjOperands, concatDim);
+
+    // Apply the common padding to get the final result
+    auto result = rewriter.create<stablehlo::PadOp>(
+        concatOp.getLoc(), newConcatOp, padValue, commonLowPadding,
+        commonHighPadding, interiorPadding);
+
+    rewriter.replaceOp(concatOp, result);
+    return success();
+  }
+};
+
 struct ConstPadConcatToConcat : public OpRewritePattern<stablehlo::PadOp> {
   using OpRewritePattern<stablehlo::PadOp>::OpRewritePattern;

@@ -12705,10 +12818,9 @@ struct EnzymeHLOOptPass
                  AssociativeBinaryOpReordering<stablehlo::AndOp>,
                  AssociativeBinaryOpReordering<stablehlo::OrOp>>(context);
 
-    patterns
-        .add<BinopPadToConcat<stablehlo::AddOp>,
-             BinopPadToConcat<stablehlo::MulOp>, ConcatPad, PadReduceWindow>(
-            context);
+    patterns.add<BinopPadToConcat<stablehlo::AddOp>,
+                 BinopPadToConcat<stablehlo::MulOp>, ConcatPad,
+                 PadConcatToConcatPad, PadReduceWindow>(context);
 
     if (passses & 512) {
       patterns.add<TransposeDotReorder, DotTranspose, ConvolutionTranspose,

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 4 additions & 0 deletions
@@ -1222,6 +1222,10 @@ def ConstPadConcatToConcat : EnzymeHLOPatternOp<
   let patterns = ["ConstPadConcatToConcat"];
 }
 
+def PadConcatToConcatPad : EnzymeHLOPatternOp<
+    "pad_concat_to_concat_pad"> {
+  let patterns = ["PadConcatToConcatPad"];
+}
 // TODO: better naming for parameters requires a static interface for
 // constructing them in search.

New lit test file

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(enzyme-hlo-generate-td{patterns=pad_concat_to_concat_pad},transform-interpreter,enzyme-hlo-remove-transform)" | FileCheck %s
+
+func.func @test_pad_leftover(%arg0 : tensor<128x2031x2032xf64>, %arg1 : tensor<1x2032x2032xf64>, %arg2: tensor<1x2032x2032xf64>) -> tensor<130x2033x2032xf64> {
+  %cst_29 = stablehlo.constant dense<0.5> : tensor<f64>
+  %p1 = stablehlo.pad %arg0, %cst_29, low = [0, 1, 0], high = [0, 1, 0], interior = [0, 0, 0] : (tensor<128x2031x2032xf64>, tensor<f64>) -> tensor<128x2033x2032xf64>
+  %p2 = stablehlo.pad %arg1, %cst_29, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x2032x2032xf64>, tensor<f64>) -> tensor<1x2033x2032xf64>
+  %p3 = stablehlo.pad %arg2, %cst_29, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x2032x2032xf64>, tensor<f64>) -> tensor<1x2033x2032xf64>
+
+  %concat = stablehlo.concatenate %p2, %p1, %p3, dim = 0 : (tensor<1x2033x2032xf64>, tensor<128x2033x2032xf64>, tensor<1x2033x2032xf64>) -> tensor<130x2033x2032xf64>
+  return %concat : tensor<130x2033x2032xf64>
+}
+
+
+// CHECK: func.func @test_pad_leftover(%arg0: tensor<128x2031x2032xf64>, %arg1: tensor<1x2032x2032xf64>, %arg2: tensor<1x2032x2032xf64>) -> tensor<130x2033x2032xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<5.000000e-01> : tensor<f64>
+// CHECK-NEXT: %0 = stablehlo.pad %arg0, %cst, low = [0, 0, 0], high = [0, 1, 0], interior = [0, 0, 0] : (tensor<128x2031x2032xf64>, tensor<f64>) -> tensor<128x2032x2032xf64>
+// CHECK-NEXT: %1 = stablehlo.concatenate %arg1, %0, %arg2, dim = 0 : (tensor<1x2032x2032xf64>, tensor<128x2032x2032xf64>, tensor<1x2032x2032xf64>) -> tensor<130x2032x2032xf64>
+// CHECK-NEXT: %2 = stablehlo.pad %1, %cst, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<130x2032x2032xf64>, tensor<f64>) -> tensor<130x2033x2032xf64>
+// CHECK-NEXT: return %2 : tensor<130x2033x2032xf64>
+// CHECK-NEXT: }
+
+func.func @test_pad_clean(%arg0 : tensor<128x2032x2032xf64>, %arg1 : tensor<1x2032x2032xf64>, %arg2: tensor<1x2032x2032xf64>) -> tensor<130x2033x2032xf64> {
+  %cst_29 = stablehlo.constant dense<0.5> : tensor<f64>
+  %p1 = stablehlo.pad %arg0, %cst_29, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<128x2032x2032xf64>, tensor<f64>) -> tensor<128x2033x2032xf64>
+  %p2 = stablehlo.pad %arg1, %cst_29, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x2032x2032xf64>, tensor<f64>) -> tensor<1x2033x2032xf64>
+  %p3 = stablehlo.pad %arg2, %cst_29, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x2032x2032xf64>, tensor<f64>) -> tensor<1x2033x2032xf64>
+
+  %concat = stablehlo.concatenate %p2, %p1, %p3, dim = 0 : (tensor<1x2033x2032xf64>, tensor<128x2033x2032xf64>, tensor<1x2033x2032xf64>) -> tensor<130x2033x2032xf64>
+  return %concat : tensor<130x2033x2032xf64>
+}
+
+
+// CHECK-NEXT: func.func @test_pad_clean(%arg0: tensor<128x2032x2032xf64>, %arg1: tensor<1x2032x2032xf64>, %arg2: tensor<1x2032x2032xf64>) -> tensor<130x2033x2032xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<5.000000e-01> : tensor<f64>
+// CHECK-NEXT: %0 = stablehlo.concatenate %arg1, %arg0, %arg2, dim = 0 : (tensor<1x2032x2032xf64>, tensor<128x2032x2032xf64>, tensor<1x2032x2032xf64>) -> tensor<130x2032x2032xf64>
+// CHECK-NEXT: %1 = stablehlo.pad %0, %cst, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<130x2032x2032xf64>, tensor<f64>) -> tensor<130x2033x2032xf64>
+// CHECK-NEXT: return %1 : tensor<130x2033x2032xf64>
+// CHECK-NEXT: }
