@@ -367,112 +367,112 @@ struct VectorSubgroupReduceToShuffles final
   bool matchClustered = false;
 };
 
-std::optional<Value> createSubgroupDPPReduction(OpBuilder &b, Location loc,
-                                                Value input,
-                                                gpu::AllReduceOperation mode,
-                                                const ClusterInfo &ci,
-                                                amdgpu::Chipset chipset) {
-  Value result = input;
+FailureOr<Value>
+createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp &op,
+                           Value input, gpu::AllReduceOperation mode,
+                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
+  Location loc = op.getLoc();
+  Value dpp;
+  Value res = input;
   constexpr int allRows = 0xf;
   constexpr int allBanks = 0xf;
   const bool boundCtrl = true;
-  Value lane0 =
-      b.create<arith::ConstantOp>(loc, b.getI32Type(), b.getI32IntegerAttr(0));
-  Value lane32 =
-      b.create<arith::ConstantOp>(loc, b.getI32Type(), b.getI32IntegerAttr(32));
-
-  auto dppReduceAcrossLanes = [&](int numLanes,
-                                  Value res) -> std::optional<Value> {
-    Value dppResult, laneVal;
-
-    switch (numLanes) {
-    case 2:
-      // Perform reduction between all lanes N <-> N+1.
-      dppResult = b.create<amdgpu::DPPOp>(
-          loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
-          b.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
-      break;
-    case 4:
-      // Perform reduction between all lanes N <-> N+2.
-      dppResult = b.create<amdgpu::DPPOp>(
-          loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
-          b.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
-      break;
-    case 8:
-      // Perform reduction between all lanes N <-> 7-N,
-      // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
-      dppResult = b.create<amdgpu::DPPOp>(
-          loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
-          b.getUnitAttr(), allRows, allBanks, boundCtrl);
-      break;
-    case 16:
-      // Perform reduction between all lanes N <-> 15-N,
-      // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
-      dppResult = b.create<amdgpu::DPPOp>(
-          loc, result.getType(), res, res, amdgpu::DPPPerm::row_mirror,
-          b.getUnitAttr(), allRows, allBanks, boundCtrl);
-      break;
-    case 32:
-      if (chipset.majorVersion <= 9) {
-        // Broadcast last value from each row to next row.
-        // Use row mask to avoid polluting rows 1 and 3.
-        dppResult = b.create<amdgpu::DPPOp>(loc, res.getType(), res, res,
-                                            amdgpu::DPPPerm::row_bcast_15,
-                                            b.getUnitAttr(), 0xa, allBanks,
-                                            /*bound_ctrl*/ false);
-      } else if (chipset.majorVersion <= 12) {
-        // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
-        dppResult = b.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
-                                                   -1, -1, /*fi=*/true,
-                                                   /*bound_ctrl=*/false);
-        if (ci.subgroupSize == 32) {
-          dppResult =
-              b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
-        }
-      } else {
-        return std::nullopt;
-      }
-      break;
-    case 64:
-      if (chipset.majorVersion <= 9) {
-        // Broadcast 31st lane value to rows 2 and 3.
-        // Use row mask to avoid polluting rows 0 and 1.
-        dppResult = b.create<amdgpu::DPPOp>(loc, res.getType(), res, res,
-                                            amdgpu::DPPPerm::row_bcast_31,
-                                            b.getUnitAttr(), 0xc, allBanks,
-                                            /*bound_ctrl*/ false);
-      } else if (chipset.majorVersion <= 12) {
-        // Assume reduction across 32 lanes has been done.
-        // Perform final reduction manually by summing values in lane 0 and
-        // lane 32.
-        dppResult =
-            b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
-        laneVal = b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
-        return vector::makeArithReduction(
-            b, loc, gpu::convertReductionKind(mode), dppResult, laneVal);
-      } else {
-        return std::nullopt;
+  if (ci.clusterSize >= 2) {
+    // Perform reduction between all lanes N <-> N+1.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+
+  if (ci.clusterSize >= 4) {
+    // Perform reduction between all lanes N <-> N+2.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 8) {
+    // Perform reduction between all lanes N <-> 7-N,
+    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 16) {
+    // Perform reduction between all lanes N <-> 15-N,
+    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 32) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast last value from each row to next row.
+      // Use row mask to avoid polluting rows 1 and 3.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
+          rewriter.getUnitAttr(), 0xa, allBanks,
+          /*bound_ctrl*/ false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+    } else if (chipset.majorVersion <= 12) {
+      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
+      Value uint32Max = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
+      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
+                                                  uint32Max, uint32Max,
+                                                  /*fi=*/true,
+                                                  /*bound_ctrl=*/false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+      if (ci.subgroupSize == 32) {
+        Value lane0 = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+        dpp =
+            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
       }
-      break;
-    default:
-      // Should never reach here given previous validation of ClusterInfo.
-      llvm_unreachable("ERROR: Unexpected cluster size.");
-      return std::nullopt;
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
     }
-    return vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
-                                      res, dppResult);
-  };
-
-  for (unsigned cs = 2; cs <= ci.clusterSize; cs <<= 1) {
-    if (auto dpp = dppReduceAcrossLanes(cs, result)) {
-      result = *dpp;
-      continue;
+  }
+  if (ci.clusterSize >= 64) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast 31st lane value to rows 2 and 3.
+      // Use row mask to avoid polluting rows 0 and 1.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
+          rewriter.getUnitAttr(), 0xc, allBanks,
+          /*bound_ctrl*/ false);
+
+    } else if (chipset.majorVersion <= 12) {
+      // Assume reduction across 32 lanes has been done.
+      // Perform final reduction manually by summing values in lane 0 and
+      // lane 32.
+      Value lane0 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+      Value lane32 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
+      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
     }
-    return std::nullopt;
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
   }
-
-  assert(result.getType() == input.getType());
-  return result;
+  assert(res.getType() == input.getType());
+  return res;
 }
 
 /// Collect a set of patterns to lower `gpu.subgroup_reduce` into `amdgpu.dpp`
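The chain of `if (ci.clusterSize >= N)` blocks above implements a log2-step butterfly: after the step for width N, every lane holds the reduction of its N-wide cluster, so each later step only has to combine two already-reduced halves. Below is a minimal host-side sketch (not part of the patch; `Wave` and `step` are illustrative names) that models each hardware step as a lane permutation over a 64-lane wavefront, with integer addition standing in for the reduction. The XOR masks 1, 2, 7, and 15 reproduce the quad_perm and mirror permutations exactly (e.g. N <-> 7-N within an 8-lane group is `lane ^ 7`); the last two steps model the cross-row and cross-half combines that the patch performs with permlanex16, row_bcast, and readlane.

#include <array>
#include <cassert>
#include <cstdint>

constexpr int kWaveSize = 64;
using Wave = std::array<uint32_t, kWaveSize>;

// One butterfly step: every lane combines with the lane at `lane ^ xorMask`.
static Wave step(const Wave &in, int xorMask) {
  Wave out;
  for (int lane = 0; lane < kWaveSize; ++lane)
    out[lane] = in[lane] + in[lane ^ xorMask];
  return out;
}

int main() {
  Wave w;
  for (int lane = 0; lane < kWaveSize; ++lane)
    w[lane] = lane; // Expected total: 64 * 63 / 2 = 2016.

  w = step(w, 1);  // quad_perm {1, 0, 3, 2}: N <-> N+1
  w = step(w, 2);  // quad_perm {2, 3, 0, 1}: N <-> N+2
  w = step(w, 7);  // row_half_mirror: N <-> 7-N within each 8-lane group
  w = step(w, 15); // row_mirror: N <-> 15-N within each 16-lane row
  w = step(w, 16); // row_bcast_15 / permlanex16: combine across rows
  w = step(w, 32); // row_bcast_31 / readlane: combine the two 32-lane halves

  // After log2(64) = 6 steps, every lane holds the full reduction.
  for (uint32_t v : w)
    assert(v == 2016);
  return 0;
}

This is also why the rewrite from a switch over `numLanes` to cumulative `>=` checks is safe: a cluster of size N needs exactly the first log2(N) steps, and each step only assumes the previous ones have run.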
@@ -500,22 +500,21 @@ struct ScalarSubgroupReduceToDPP final
       return failure();
 
     if (ci->clusterStride != 1)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reductions using DPP are currently only available for "
+              "clusters of contiguous lanes.");
 
     Type valueTy = op.getType();
     if (!valueTy.isIntOrFloat())
       return rewriter.notifyMatchFailure(
           op, "value type is not a compatible scalar");
 
-    Location loc = op.getLoc();
-    std::optional<Value> dpp = createSubgroupDPPReduction(
-        rewriter, loc, op.getValue(), op.getOp(), *ci, chipset);
-    if (!dpp)
-      return rewriter.notifyMatchFailure(
-          op, "Subgroup reduce lowering to DPP not currently supported for "
-              "this device.");
+    FailureOr<Value> dpp = createSubgroupDPPReduction(
+        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
+    if (failed(dpp))
+      return failure();
 
-    rewriter.replaceOp(op, *dpp);
+    rewriter.replaceOp(op, dpp.value());
     return success();
   }
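With `createSubgroupDPPReduction` returning `FailureOr<Value>` and reporting its own `notifyMatchFailure`, the pattern driver receives a precise diagnostic and simply leaves unsupported ops in place. A sketch of how a pass might apply these patterns; the populate helper's name and signature here are assumptions for illustration (the actual declaration is the one the "Collect a set of patterns..." comment above documents), while `amdgpu::Chipset::parse` and the greedy driver (`applyPatternsGreedily`, or `applyPatternsAndFoldGreedily` on older MLIR) are existing MLIR APIs:

#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: lower gpu.subgroup_reduce ops under `root` to DPP.
LogicalResult lowerSubgroupReducesToDPP(Operation *root, StringRef chipName) {
  // Chipset::parse maps a target name such as "gfx942" to its version triple,
  // which createSubgroupDPPReduction consults via chipset.majorVersion.
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipName);
  if (failed(maybeChipset))
    return failure();

  RewritePatternSet patterns(root->getContext());
  // Assumed populate helper registering ScalarSubgroupReduceToDPP.
  populateGpuLowerSubgroupReduceToDPPPatterns(patterns, /*subgroupSize=*/64,
                                              *maybeChipset);
  // Ops whose patterns bail out via notifyMatchFailure are left unchanged.
  return applyPatternsGreedily(root, std::move(patterns));
}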