Skip to content

Commit 0804843

Browse files
committed
Rebase and address review
1 parent d7baf48 commit 0804843

File tree

2 files changed

+101
-21
lines changed

2 files changed

+101
-21
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3653,28 +3653,23 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36533653
VPValue *&ValB, VPWidenRecipe *Mul) {
36543654
if (ExtA && !ExtB && ValB->isLiveIn()) {
36553655
Type *NarrowTy = Ctx.Types.inferScalarType(ExtA->getOperand(0));
3656-
Type *WideTy = Ctx.Types.inferScalarType(ExtA);
36573656
Instruction::CastOps ExtOpc = ExtA->getOpcode();
3658-
auto *Const = dyn_cast<ConstantInt>(ValB->getLiveInIRValue());
3659-
if (Const &&
3660-
llvm::canConstantBeExtended(
3661-
Const, NarrowTy, TTI::getPartialReductionExtendKind(ExtOpc))) {
3662-
// The truncate ensures that the type of each extended operand is the
3663-
// same, and it's been proven that the constant can be extended from
3664-
// NarrowTy safely. Necessary since ExtA's extended operand would be
3665-
// e.g. an i8, while the const will likely be an i32. This will be
3666-
// elided by later optimisations.
3667-
auto *Trunc =
3668-
new VPWidenCastRecipe(Instruction::CastOps::Trunc, ValB, NarrowTy);
3669-
Trunc->insertBefore(*ExtA->getParent(), std::next(ExtA->getIterator()));
3670-
3671-
VPWidenCastRecipe *NewCast =
3672-
new VPWidenCastRecipe(ExtOpc, Trunc, WideTy);
3673-
NewCast->insertAfter(Trunc);
3674-
ExtB = NewCast;
3675-
ValB = NewCast;
3676-
Mul->setOperand(1, NewCast);
3677-
}
3657+
const APInt *Const;
3658+
if (!match(ValB, m_APInt(Const)) ||
3659+
!llvm::canConstantBeExtended(
3660+
Const, NarrowTy, TTI::getPartialReductionExtendKind(ExtOpc)))
3661+
return;
3662+
// The truncate ensures that the type of each extended operand is the
3663+
// same, and it's been proven that the constant can be extended from
3664+
// NarrowTy safely. Necessary since ExtA's extended operand would be
3665+
// e.g. an i8, while the const will likely be an i32. This will be
3666+
// elided by later optimisations.
3667+
VPBuilder Builder(Mul);
3668+
auto *Trunc =
3669+
Builder.createWidenCast(Instruction::CastOps::Trunc, ValB, NarrowTy);
3670+
Type *WideTy = Ctx.Types.inferScalarType(ExtA);
3671+
ValB = ExtB = Builder.createWidenCast(ExtOpc, Trunc, WideTy);
3672+
Mul->setOperand(1, ExtB);
36783673
}
36793674
};
36803675

llvm/test/Transforms/LoopVectorize/reduction-inloop.ll

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2800,6 +2800,91 @@ exit:
28002800
ret i64 %r.0.lcssa
28012801
}
28022802

2803+
define i32 @reduction_expression_ext_mulacc_livein(ptr %a, ptr %b, i16 %c) {
2804+
; CHECK-LABEL: define i32 @reduction_expression_ext_mulacc_livein(
2805+
; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]], i16 [[C:%.*]]) {
2806+
; CHECK-NEXT: [[ENTRY:.*:]]
2807+
; CHECK-NEXT: br label %[[VECTOR_PH:.*]]
2808+
; CHECK: [[VECTOR_PH]]:
2809+
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i16> poison, i16 [[C]], i64 0
2810+
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16> [[BROADCAST_SPLATINSERT]], <4 x i16> poison, <4 x i32> zeroinitializer
2811+
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
2812+
; CHECK: [[VECTOR_BODY]]:
2813+
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
2814+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
2815+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
2816+
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP0]], align 1
2817+
; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i16>
2818+
; CHECK-NEXT: [[TMP2:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP1]]
2819+
; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i16> [[TMP2]] to <4 x i32>
2820+
; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
2821+
; CHECK-NEXT: [[TMP5]] = add i32 [[VEC_PHI]], [[TMP4]]
2822+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
2823+
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
2824+
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
2825+
; CHECK: [[MIDDLE_BLOCK]]:
2826+
; CHECK-NEXT: br label %[[FOR_EXIT:.*]]
2827+
; CHECK: [[FOR_EXIT]]:
2828+
; CHECK-NEXT: ret i32 [[TMP5]]
2829+
;
2830+
; CHECK-INTERLEAVED-LABEL: define i32 @reduction_expression_ext_mulacc_livein(
2831+
; CHECK-INTERLEAVED-SAME: ptr [[A:%.*]], ptr [[B:%.*]], i16 [[C:%.*]]) {
2832+
; CHECK-INTERLEAVED-NEXT: [[ENTRY:.*:]]
2833+
; CHECK-INTERLEAVED-NEXT: br label %[[VECTOR_PH:.*]]
2834+
; CHECK-INTERLEAVED: [[VECTOR_PH]]:
2835+
; CHECK-INTERLEAVED-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i16> poison, i16 [[C]], i64 0
2836+
; CHECK-INTERLEAVED-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16> [[BROADCAST_SPLATINSERT]], <4 x i16> poison, <4 x i32> zeroinitializer
2837+
; CHECK-INTERLEAVED-NEXT: br label %[[VECTOR_BODY:.*]]
2838+
; CHECK-INTERLEAVED: [[VECTOR_BODY]]:
2839+
; CHECK-INTERLEAVED-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
2840+
; CHECK-INTERLEAVED-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP8:%.*]], %[[VECTOR_BODY]] ]
2841+
; CHECK-INTERLEAVED-NEXT: [[VEC_PHI1:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
2842+
; CHECK-INTERLEAVED-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
2843+
; CHECK-INTERLEAVED-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[TMP0]], i32 4
2844+
; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP0]], align 1
2845+
; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD2:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
2846+
; CHECK-INTERLEAVED-NEXT: [[TMP2:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i16>
2847+
; CHECK-INTERLEAVED-NEXT: [[TMP3:%.*]] = zext <4 x i8> [[WIDE_LOAD2]] to <4 x i16>
2848+
; CHECK-INTERLEAVED-NEXT: [[TMP4:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP2]]
2849+
; CHECK-INTERLEAVED-NEXT: [[TMP5:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP3]]
2850+
; CHECK-INTERLEAVED-NEXT: [[TMP6:%.*]] = zext <4 x i16> [[TMP4]] to <4 x i32>
2851+
; CHECK-INTERLEAVED-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
2852+
; CHECK-INTERLEAVED-NEXT: [[TMP8]] = add i32 [[VEC_PHI]], [[TMP7]]
2853+
; CHECK-INTERLEAVED-NEXT: [[TMP9:%.*]] = zext <4 x i16> [[TMP5]] to <4 x i32>
2854+
; CHECK-INTERLEAVED-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP9]])
2855+
; CHECK-INTERLEAVED-NEXT: [[TMP11]] = add i32 [[VEC_PHI1]], [[TMP10]]
2856+
; CHECK-INTERLEAVED-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
2857+
; CHECK-INTERLEAVED-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
2858+
; CHECK-INTERLEAVED-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
2859+
; CHECK-INTERLEAVED: [[MIDDLE_BLOCK]]:
2860+
; CHECK-INTERLEAVED-NEXT: [[BIN_RDX:%.*]] = add i32 [[TMP11]], [[TMP8]]
2861+
; CHECK-INTERLEAVED-NEXT: br label %[[FOR_EXIT:.*]]
2862+
; CHECK-INTERLEAVED: [[FOR_EXIT]]:
2863+
; CHECK-INTERLEAVED-NEXT: ret i32 [[BIN_RDX]]
2864+
;
2865+
entry:
2866+
br label %for.body
2867+
2868+
for.body: ; preds = %for.body, %entry
2869+
%iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
2870+
%accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
2871+
%gep.a = getelementptr i8, ptr %a, i64 %iv
2872+
%load.a = load i8, ptr %gep.a, align 1
2873+
%ext.a = zext i8 %load.a to i16
2874+
%gep.b = getelementptr i8, ptr %b, i64 %iv
2875+
%load.b = load i8, ptr %gep.b, align 1
2876+
%ext.b = zext i8 %load.b to i16
2877+
%mul = mul i16 %c, %ext.a
2878+
%mul.ext = zext i16 %mul to i32
2879+
%add = add i32 %mul.ext, %accum
2880+
%iv.next = add i64 %iv, 1
2881+
%exitcond.not = icmp eq i64 %iv.next, 1024
2882+
br i1 %exitcond.not, label %for.exit, label %for.body
2883+
2884+
for.exit: ; preds = %for.body
2885+
ret i32 %add
2886+
}
2887+
28032888
declare float @llvm.fmuladd.f32(float, float, float)
28042889

28052890
!6 = distinct !{!6, !7, !8}

0 commit comments

Comments
 (0)