Commit 50e3616

Improve dpp implementation
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
1 parent 987c4d7 commit 50e3616


mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 107 additions & 108 deletions
```diff
@@ -367,112 +367,112 @@ struct VectorSubgroupReduceToShuffles final
   bool matchClustered = false;
 };
 
-std::optional<Value> createSubgroupDPPReduction(OpBuilder &b, Location loc,
-                                                Value input,
-                                                gpu::AllReduceOperation mode,
-                                                const ClusterInfo &ci,
-                                                amdgpu::Chipset chipset) {
-  Value result = input;
+FailureOr<Value>
+createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp &op,
+                           Value input, gpu::AllReduceOperation mode,
+                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
+  Location loc = op.getLoc();
+  Value dpp;
+  Value res = input;
   constexpr int allRows = 0xf;
   constexpr int allBanks = 0xf;
   const bool boundCtrl = true;
-  Value lane0 =
-      b.create<arith::ConstantOp>(loc, b.getI32Type(), b.getI32IntegerAttr(0));
-  Value lane32 =
-      b.create<arith::ConstantOp>(loc, b.getI32Type(), b.getI32IntegerAttr(32));
-
-  auto dppReduceAcrossLanes = [&](int numLanes,
-                                  Value res) -> std::optional<Value> {
-    Value dppResult, laneVal;
-
-    switch (numLanes) {
-    case 2:
-      // Perform reduction between all lanes N <-> N+1.
-      dppResult = b.create<amdgpu::DPPOp>(
-          loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
-          b.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
-      break;
-    case 4:
-      // Perform reduction between all lanes N <-> N+2.
-      dppResult = b.create<amdgpu::DPPOp>(
-          loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
-          b.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
-      break;
-    case 8:
-      // Perform reduction between all lanes N <-> 7-N,
-      // e.g lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
-      dppResult = b.create<amdgpu::DPPOp>(
-          loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
-          b.getUnitAttr(), allRows, allBanks, boundCtrl);
-      break;
-    case 16:
-      // Perform reduction between all lanes N <-> 15-N,
-      // e.g lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
-      dppResult = b.create<amdgpu::DPPOp>(
-          loc, result.getType(), res, res, amdgpu::DPPPerm::row_mirror,
-          b.getUnitAttr(), allRows, allBanks, boundCtrl);
-      break;
-    case 32:
-      if (chipset.majorVersion <= 9) {
-        // Broadcast last value from each row to next row.
-        // Use row mask to avoid polluting rows 1 and 3.
-        dppResult = b.create<amdgpu::DPPOp>(loc, res.getType(), res, res,
-                                            amdgpu::DPPPerm::row_bcast_15,
-                                            b.getUnitAttr(), 0xa, allBanks,
-                                            /*bound_ctrl*/ false);
-      } else if (chipset.majorVersion <= 12) {
-        // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
-        dppResult = b.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
-                                                   -1, -1, /*fi=*/true,
-                                                   /*bound_ctrl=*/false);
-        if (ci.subgroupSize == 32) {
-          dppResult =
-              b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
-        }
-      } else {
-        return std::nullopt;
-      }
-      break;
-    case 64:
-      if (chipset.majorVersion <= 9) {
-        // Broadcast 31st lane value to rows 2 and 3.
-        // Use row mask to avoid polluting rows 0 and 1.
-        dppResult = b.create<amdgpu::DPPOp>(loc, res.getType(), res, res,
-                                            amdgpu::DPPPerm::row_bcast_31,
-                                            b.getUnitAttr(), 0xc, allBanks,
-                                            /*bound_ctrl*/ false);
-      } else if (chipset.majorVersion <= 12) {
-        // Assume reduction across 32 lanes has been done.
-        // Perform final reduction manually by summing values in lane 0 and
-        // lane 32.
-        dppResult =
-            b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
-        laneVal = b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
-        return vector::makeArithReduction(
-            b, loc, gpu::convertReductionKind(mode), dppResult, laneVal);
-      } else {
-        return std::nullopt;
+  if (ci.clusterSize >= 2) {
+    // Perform reduction between all lanes N <-> N+1.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+
+  if (ci.clusterSize >= 4) {
+    // Perform reduction between all lanes N <-> N+2.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 8) {
+    // Perform reduction between all lanes N <-> 7-N,
+    // e.g lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 16) {
+    // Perform reduction between all lanes N <-> 15-N,
+    // e.g lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 32) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast last value from each row to next row.
+      // Use row mask to avoid polluting rows 1 and 3.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
+          rewriter.getUnitAttr(), 0xa, allBanks,
+          /*bound_ctrl*/ false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+    } else if (chipset.majorVersion <= 12) {
+      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
+      Value uint32Max = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
+      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
+                                                  uint32Max, uint32Max,
+                                                  /*fi=*/true,
+                                                  /*bound_ctrl=*/false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+      if (ci.subgroupSize == 32) {
+        Value lane0 = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+        dpp =
+            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
       }
-      break;
-    default:
-      // Should never reach here given previous validation of ClusterInfo.
-      llvm_unreachable("ERROR: Unexpected cluster size.");
-      return std::nullopt;
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
     }
-    return vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
-                                      res, dppResult);
-  };
-
-  for (unsigned cs = 2; cs <= ci.clusterSize; cs <<= 1) {
-    if (auto dpp = dppReduceAcrossLanes(cs, result)) {
-      result = *dpp;
-      continue;
+  }
+  if (ci.clusterSize >= 64) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast 31st lane value to rows 2 and 3.
+      // Use row mask to avoid polluting rows 0 and 1.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
+          rewriter.getUnitAttr(), 0xc, allBanks,
+          /*bound_ctrl*/ false);
+
+    } else if (chipset.majorVersion <= 12) {
+      // Assume reduction across 32 lanes has been done.
+      // Perform final reduction manually by summing values in lane 0 and
+      // lane 32.
+      Value lane0 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+      Value lane32 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
+      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
     }
-    return std::nullopt;
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
   }
-
-  assert(result.getType() == input.getType());
-  return result;
+  assert(res.getType() == input.getType());
+  return res;
 }
 
 /// Collect a set of patterns to lower `gpu.subgroup_reduce` into `amdgpu.dpp`
```
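As a reading aid for the hunk above, the sketch below simulates on the host the lane-exchange schedule that the chained quad_perm, row_half_mirror, and row_mirror DPP stages implement for clusters of up to 16 contiguous lanes. It is not part of the patch: the function name, the host-side arrays, and the f32 "add" reduction are illustrative assumptions.

```cpp
#include <array>
#include <cstdio>

// XOR masks giving the lane each stage reads from:
//   quad_perm([1,0,3,2]) -> lane ^ 1   (N <-> N+1)
//   quad_perm([2,3,0,1]) -> lane ^ 2   (N <-> N+2)
//   row_half_mirror      -> lane ^ 7   (N <-> 7-N within 8 lanes)
//   row_mirror           -> lane ^ 15  (N <-> 15-N within a 16-lane row)
constexpr int kStageMasks[] = {1, 2, 7, 15};

// Simulates the per-lane "res = res + dpp(res)" chain for every lane of a
// 64-wide wavefront. After the stage for cluster size C, each lane holds the
// sum of the C contiguous lanes of its cluster.
std::array<float, 64> simulateDppStages(std::array<float, 64> lanes,
                                        unsigned clusterSize) {
  unsigned stageCluster = 2;
  for (int mask : kStageMasks) {
    if (clusterSize < stageCluster)
      break;
    std::array<float, 64> permuted; // value delivered by the DPP permutation
    for (int lane = 0; lane < 64; ++lane)
      permuted[lane] = lanes[lane ^ mask];
    for (int lane = 0; lane < 64; ++lane)
      lanes[lane] += permuted[lane]; // vector::makeArithReduction with "add"
    stageCluster *= 2;
  }
  return lanes;
}

int main() {
  std::array<float, 64> lanes;
  lanes.fill(1.0f); // every lane contributes 1
  std::array<float, 64> out = simulateDppStages(lanes, /*clusterSize=*/16);
  std::printf("lane 0 holds %g (expected 16)\n", out[0]);
  return 0;
}
```

The 32- and 64-lane stages are omitted from the sketch because they cross DPP row boundaries, which is exactly why the patch branches on the chipset: gfx9 uses row_bcast_15/row_bcast_31, while gfx10 through gfx12 use permlanex16 followed by readlane.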
```diff
@@ -500,22 +500,21 @@ struct ScalarSubgroupReduceToDPP final
       return failure();
 
     if (ci->clusterStride != 1)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reductions using DPP are currently only available for "
+              "clusters of contiguous lanes.");
 
     Type valueTy = op.getType();
     if (!valueTy.isIntOrFloat())
       return rewriter.notifyMatchFailure(
           op, "value type is not a compatible scalar");
 
-    Location loc = op.getLoc();
-    std::optional<Value> dpp = createSubgroupDPPReduction(
-        rewriter, loc, op.getValue(), op.getOp(), *ci, chipset);
-    if (!dpp)
-      return rewriter.notifyMatchFailure(
-          op, "Subgroup reduce lowering to DPP not currently supported for "
-          "this device.");
+    FailureOr<Value> dpp = createSubgroupDPPReduction(
+        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
+    if (failed(dpp))
+      return failure();
 
-    rewriter.replaceOp(op, *dpp);
+    rewriter.replaceOp(op, dpp.value());
     return success();
   }
 
```
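For readers unfamiliar with the FailureOr plumbing adopted in both hunks, here is a minimal sketch of the idiom in isolation. The helper name, the simplified signatures, and the chosen includes are assumptions made for illustration; only FailureOr, failed, notifyMatchFailure, and replaceOp come from the patch itself.

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Stand-in for createSubgroupDPPReduction; the body is illustrative only.
static FailureOr<Value> lowerToDpp(PatternRewriter &rewriter, Operation *op) {
  // notifyMatchFailure returns a failed LogicalResult, which converts to a
  // failed FailureOr<Value>, so the diagnostic and the error share one return.
  return rewriter.notifyMatchFailure(
      op, "Subgroup reduce lowering to DPP not currently supported for "
          "this device.");
}

static LogicalResult rewriteOnce(PatternRewriter &rewriter, Operation *op) {
  FailureOr<Value> dpp = lowerToDpp(rewriter, op);
  if (failed(dpp))
    return failure(); // The reason was already attached by the helper.
  rewriter.replaceOp(op, *dpp);
  return success();
}
```

Compared with the previous std::optional-based helper, this keeps the failure reason next to the code that detects it, and the pattern only has to forward the failure.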