
Commit da5cb46

Check for parallel IV in affineMapToSlice (EnzymeAD#658)
* Check for parallel IV in affineMapToSlice
* remove log
* fmt
* fmt2
* Update test
* remove duplicate test
* literal
1 parent fbb7296 commit da5cb46
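
Summary of the change: affineMapToSlice previously treated every affine.for induction variable as a unit-stride dimension. It now also receives the ParallelContext and, for IVs the context marks as parallel, computes the expression range like any other expression, so a negative step is normalized to a positive stride and the dimension is recorded in reverseDims. To take ParallelContext as a parameter, the function definition moves below the struct, and both call sites in tryRaisingOpToStableHLO now pass pc through. The forred2 lit test is updated to expect the resulting stablehlo.reverse operations.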

2 files changed (+62, -58 lines)


src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp

Lines changed: 50 additions & 49 deletions
@@ -200,53 +200,6 @@ emitIVToStableHLO(OpBuilder &builder, Value iv, InductionVariableRange range,
   maps[iota] = accessMap;
 }

-// Given an affine map for a load/store operation, compute the startIndices,
-// limitIndices and strides corresponding in the memref based on the loop
-// induction variables.
-//
-// (i) -> (0, i, 10) will give [0:1:1, begin:end:step, 10:11:1]
-// (i) -> (2 * i, i + 2, 10) will give [begin*2:end*2:2*step,
-// begin+2:end+2:step, 10:11:1]
-//
-// with begin:end:step corresponding to the range of the iv i.
-static LogicalResult affineMapToSlice(affine::AffineValueMap accessValueMap,
-                                      SmallVectorImpl<int64_t> &strides,
-                                      SmallVectorImpl<int64_t> &reverseDims) {
-  auto rank = accessValueMap.getNumResults();
-
-  strides.reserve(rank);
-
-  for (unsigned i = 0; i < rank; i++) {
-    auto expr = accessValueMap.getResult(i);
-
-    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
-      strides.push_back(1);
-      continue;
-    }
-
-    Value iv = getIVForExpr(accessValueMap, expr);
-    if (affine::isAffineForInductionVar(iv)) {
-      strides.push_back(1);
-      continue;
-    }
-
-    auto range = computeExprRange(accessValueMap, expr);
-
-    if (!range.has_value())
-      return failure();
-
-    if (range->step < 0) {
-      // 0:-1:-180 -> -179:1:1
-      strides.push_back(-range->step);
-      reverseDims.push_back(i);
-    } else {
-      strides.push_back(range->step);
-    }
-  }
-
-  return success();
-}
-
 // The name is parallel context but a more accurate description would be
 // LockStepContext
 struct ParallelContext {
@@ -337,6 +290,54 @@ struct ParallelContext {
   }
 };

+// Given an affine map for a load/store operation, compute the startIndices,
+// limitIndices and strides corresponding in the memref based on the loop
+// induction variables.
+//
+// (i) -> (0, i, 10) will give [0:1:1, begin:end:step, 10:11:1]
+// (i) -> (2 * i, i + 2, 10) will give [begin*2:end*2:2*step,
+// begin+2:end+2:step, 10:11:1]
+//
+// with begin:end:step corresponding to the range of the iv i.
+static LogicalResult affineMapToSlice(affine::AffineValueMap accessValueMap,
+                                      SmallVectorImpl<int64_t> &strides,
+                                      SmallVectorImpl<int64_t> &reverseDims,
+                                      ParallelContext pc) {
+  auto rank = accessValueMap.getNumResults();
+
+  strides.reserve(rank);
+
+  for (unsigned i = 0; i < rank; i++) {
+    auto expr = accessValueMap.getResult(i);
+
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      strides.push_back(1);
+      continue;
+    }
+
+    Value iv = getIVForExpr(accessValueMap, expr);
+    if (affine::isAffineForInductionVar(iv) && !pc.isParallelIV(iv)) {
+      strides.push_back(1);
+      continue;
+    }
+
+    auto range = computeExprRange(accessValueMap, expr);
+
+    if (!range.has_value())
+      return failure();
+
+    if (range->step < 0) {
+      // 0:-1:-180 -> -179:1:1
+      strides.push_back(-range->step);
+      reverseDims.push_back(i);
+    } else {
+      strides.push_back(range->step);
+    }
+  }
+
+  return success();
+}
+
 static SmallVector<int64_t>
 affineMapShape(affine::AffineValueMap accessValueMap, ParallelContext pc) {
   AffineMap map = accessValueMap.getAffineMap();
@@ -1597,7 +1598,7 @@ tryRaisingOpToStableHLO(Operation *op, IRMapping &mapping, OpBuilder &builder,
   SmallVector<int64_t> strides;
   SmallVector<int64_t> reverseDims;

-  if (affineMapToSlice(accessValueMap, strides, reverseDims).failed()) {
+  if (affineMapToSlice(accessValueMap, strides, reverseDims, pc).failed()) {
     LLVM_DEBUG(llvm::dbgs()
                << "Failed to affine map to slice: " << *op << "\n");
     return failure();
@@ -1733,7 +1734,7 @@ tryRaisingOpToStableHLO(Operation *op, IRMapping &mapping, OpBuilder &builder,
   SmallVector<int64_t> strides;
   SmallVector<int64_t> reverseDims;

-  if (affineMapToSlice(accessValueMap, strides, reverseDims).failed()) {
+  if (affineMapToSlice(accessValueMap, strides, reverseDims, pc).failed()) {
     LLVM_DEBUG(llvm::dbgs()
                << "Failed to affine map to slice: " << *op << "\n");
     return failure();
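
For intuition, here is a minimal standalone sketch of the begin:end:step arithmetic documented in the comment above affineMapToSlice. It is not the pass's API; sliceForAffineExpr and the driver are hypothetical, written only to work through the examples from that comment.

// sketch.cpp -- standalone model of the slice arithmetic (hypothetical helper).
#include <cstdint>
#include <iostream>

struct Slice {
  int64_t start, limit, stride;
};

// For an access expression a*i + b over an induction variable with range
// begin:end:step, the slice triple is begin*a+b : end*a+b : a*step.
Slice sliceForAffineExpr(int64_t a, int64_t b, int64_t begin, int64_t end,
                         int64_t step) {
  return {a * begin + b, a * end + b, a * step};
}

int main() {
  // (i) -> (2 * i, i + 2, 10) with i in 0:8:1 yields [0:16:2, 2:10:1, 10:11:1],
  // matching the comment's [begin*2:end*2:2*step, begin+2:end+2:step, 10:11:1].
  Slice dims[] = {
      sliceForAffineExpr(2, 0, 0, 8, 1), // 2 * i
      sliceForAffineExpr(1, 2, 0, 8, 1), // i + 2
      {10, 11, 1},                       // constant index 10 -> 10:11:1
  };
  for (auto [start, limit, stride] : dims)
    std::cout << start << ":" << limit << ":" << stride << "\n";
}

A negative a*step corresponds to the range->step < 0 branch above: the stride is negated and the dimension recorded in reverseDims, so a stablehlo.reverse can restore the original orientation.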

test/lit_tests/raising/affine_to_stablehlo_forred2.mlir

Lines changed: 12 additions & 9 deletions
@@ -36,15 +36,18 @@ module {
 // CHECK-NEXT: %5 = stablehlo.reshape %4 : (tensor<32x16xf64>) -> tensor<1x32x16xf64>
 // CHECK-NEXT: %6 = stablehlo.slice %arg0 [8:9, 0:32, 0:16] : (tensor<9x32x16xf64>) -> tensor<1x32x16xf64>
 // CHECK-NEXT: %7 = stablehlo.slice %arg1 [0:7, 0:32, 0:16] : (tensor<9x32x16xf64>) -> tensor<7x32x16xf64>
-// CHECK-NEXT: %8 = stablehlo.slice %arg1 [1:8, 0:32, 0:16] : (tensor<9x32x16xf64>) -> tensor<7x32x16xf64>
-// CHECK-NEXT: %9 = arith.addf %8, %7 : tensor<7x32x16xf64>
-// CHECK-NEXT: %10 = stablehlo.broadcast_in_dim %4, dims = [1, 2] : (tensor<32x16xf64>) -> tensor<7x32x16xf64>
-// CHECK-NEXT{LITERAL}: %11 = "stablehlo.reduce_window"(%9, %cst) <{base_dilations = array<i64: 1, 1, 1>, padding = dense<[[6, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 7, 1, 1>, window_strides = array<i64: 1, 1, 1>}> ({
+// CHECK-NEXT: %8 = stablehlo.reverse %7, dims = [0] : tensor<7x32x16xf64>
+// CHECK-NEXT: %9 = stablehlo.slice %arg1 [1:8, 0:32, 0:16] : (tensor<9x32x16xf64>) -> tensor<7x32x16xf64>
+// CHECK-NEXT: %10 = stablehlo.reverse %9, dims = [0] : tensor<7x32x16xf64>
+// CHECK-NEXT: %11 = arith.addf %10, %8 : tensor<7x32x16xf64>
+// CHECK-NEXT: %12 = stablehlo.broadcast_in_dim %4, dims = [1, 2] : (tensor<32x16xf64>) -> tensor<7x32x16xf64>
+// CHECK-NEXT{LITERAL}: %13 = "stablehlo.reduce_window"(%11, %cst) <{base_dilations = array<i64: 1, 1, 1>, padding = dense<[[6, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 7, 1, 1>, window_strides = array<i64: 1, 1, 1>}> ({
 // CHECK-NEXT: ^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>):
-// CHECK-NEXT: %14 = stablehlo.add %arg2, %arg3 : tensor<f64>
-// CHECK-NEXT: stablehlo.return %14 : tensor<f64>
+// CHECK-NEXT: %17 = stablehlo.add %arg2, %arg3 : tensor<f64>
+// CHECK-NEXT: stablehlo.return %17 : tensor<f64>
 // CHECK-NEXT: }) : (tensor<7x32x16xf64>, tensor<f64>) -> tensor<7x32x16xf64>
-// CHECK-NEXT: %12 = stablehlo.add %11, %10 : tensor<7x32x16xf64>
-// CHECK-NEXT: %13 = stablehlo.concatenate %12, %5, %6, dim = 0 : (tensor<7x32x16xf64>, tensor<1x32x16xf64>, tensor<1x32x16xf64>) -> tensor<9x32x16xf64>
-// CHECK-NEXT: return %13, %arg1 : tensor<9x32x16xf64>, tensor<9x32x16xf64>
+// CHECK-NEXT: %14 = stablehlo.add %13, %12 : tensor<7x32x16xf64>
+// CHECK-NEXT: %15 = stablehlo.reverse %14, dims = [0] : tensor<7x32x16xf64>
+// CHECK-NEXT: %16 = stablehlo.concatenate %15, %5, %6, dim = 0 : (tensor<7x32x16xf64>, tensor<1x32x16xf64>, tensor<1x32x16xf64>) -> tensor<9x32x16xf64>
+// CHECK-NEXT: return %16, %arg1 : tensor<9x32x16xf64>, tensor<9x32x16xf64>
 // CHECK-NEXT: }
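
The updated expectations reflect the reversed parallel dimension: the two slices of %arg1 are each reversed along dim 0 before the arith.addf, and the reduce_window result is reversed back before the concatenate, which shifts the SSA numbering of the remaining values.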
