Skip to content

Commit b3f2dfa

Browse files
committed
[DAG] visitBITCAST - fold (bitcast (freeze (load x))) -> (freeze (load (bitcast*)x))
Tweak the existing (bitcast (load x)) -> (load (bitcast*)x) fold to handle freeze as well. Inspired by llvm#163070 — attempt to pass the bitcast through a one-use frozen load. This tries to avoid in-place replacement of frozen nodes, which has caused infinite loops in the past.
1 parent c636a39 commit b3f2dfa

File tree

5 files changed

+128
-138
lines changed

5 files changed

+128
-138
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16693,38 +16693,51 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
1669316693
}
1669416694

1669516695
// fold (conv (load x)) -> (load (conv*)x)
16696+
// fold (conv (freeze (load x))) -> (freeze (load (conv*)x))
1669616697
// If the resultant load doesn't need a higher alignment than the original!
16697-
if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
16698-
// Do not remove the cast if the types differ in endian layout.
16699-
TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
16700-
TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
16701-
// If the load is volatile, we only want to change the load type if the
16702-
// resulting load is legal. Otherwise we might increase the number of
16703-
// memory accesses. We don't care if the original type was legal or not
16704-
// as we assume software couldn't rely on the number of accesses of an
16705-
// illegal type.
16706-
((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
16707-
TLI.isOperationLegal(ISD::LOAD, VT))) {
16708-
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
16698+
auto CastLoad = [this, &VT](SDValue N0, const SDLoc &DL) {
16699+
auto *LN0 = dyn_cast<LoadSDNode>(N0);
16700+
if (!LN0 || !ISD::isNormalLoad(LN0) || !N0.hasOneUse())
16701+
return SDValue();
1670916702

16710-
if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
16711-
*LN0->getMemOperand())) {
16712-
// If the range metadata type does not match the new memory
16713-
// operation type, remove the range metadata.
16714-
if (const MDNode *MD = LN0->getRanges()) {
16715-
ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
16716-
if (Lower->getBitWidth() != VT.getScalarSizeInBits() ||
16717-
!VT.isInteger()) {
16718-
LN0->getMemOperand()->clearRanges();
16719-
}
16703+
// Do not remove the cast if the types differ in endian layout.
16704+
if (TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) !=
16705+
TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()))
16706+
return SDValue();
16707+
16708+
// If the load is volatile, we only want to change the load type if the
16709+
// resulting load is legal. Otherwise we might increase the number of
16710+
// memory accesses. We don't care if the original type was legal or not
16711+
// as we assume software couldn't rely on the number of accesses of an
16712+
// illegal type.
16713+
if (((LegalOperations || !LN0->isSimple()) &&
16714+
!TLI.isOperationLegal(ISD::LOAD, VT)))
16715+
return SDValue();
16716+
16717+
if (!TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
16718+
*LN0->getMemOperand()))
16719+
return SDValue();
16720+
16721+
// If the range metadata type does not match the new memory
16722+
// operation type, remove the range metadata.
16723+
if (const MDNode *MD = LN0->getRanges()) {
16724+
ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
16725+
if (Lower->getBitWidth() != VT.getScalarSizeInBits() || !VT.isInteger()) {
16726+
LN0->getMemOperand()->clearRanges();
1672016727
}
16721-
SDValue Load =
16722-
DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
16723-
LN0->getMemOperand());
16724-
DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
16725-
return Load;
1672616728
}
16727-
}
16729+
SDValue Load = DAG.getLoad(VT, DL, LN0->getChain(), LN0->getBasePtr(),
16730+
LN0->getMemOperand());
16731+
DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
16732+
return Load;
16733+
};
16734+
16735+
if (SDValue NewLd = CastLoad(N0, SDLoc(N)))
16736+
return NewLd;
16737+
16738+
if (N0.getOpcode() == ISD::FREEZE && N0.hasOneUse())
16739+
if (SDValue NewLd = CastLoad(N0.getOperand(0), SDLoc(N)))
16740+
return DAG.getFreeze(NewLd);
1672816741

1672916742
if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
1673016743
return V;

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3453,6 +3453,12 @@ bool X86TargetLowering::isLoadBitCastBeneficial(EVT LoadVT, EVT BitcastVT,
34533453
isTypeLegal(LoadVT) && isTypeLegal(BitcastVT))
34543454
return true;
34553455

3456+
// If we have a large vector type (even if illegal), don't bitcast to large
3457+
// (illegal) scalar types. Better to load fewer vectors and extract.
3458+
if (LoadVT.isVector() && !BitcastVT.isVector() && LoadVT.isInteger() &&
3459+
BitcastVT.isInteger() && (LoadVT.getSizeInBits() % 128) == 0)
3460+
return false;
3461+
34563462
return TargetLowering::isLoadBitCastBeneficial(LoadVT, BitcastVT, DAG, MMO);
34573463
}
34583464

llvm/test/CodeGen/X86/avx10_2_512bf16-arith.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ define <32 x bfloat> @test_int_x86_avx10_maskz_sub_bf16_512(<32 x bfloat> %src,
9494
;
9595
; X86-LABEL: test_int_x86_avx10_maskz_sub_bf16_512:
9696
; X86: # %bb.0:
97-
; X86-NEXT: kmovd {{[0-9]+}}(%esp), %k1 # encoding: [0xc4,0xe1,0xf9,0x90,0x4c,0x24,0x04]
9897
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax # encoding: [0x8b,0x44,0x24,0x08]
98+
; X86-NEXT: kmovd {{[0-9]+}}(%esp), %k1 # encoding: [0xc4,0xe1,0xf9,0x90,0x4c,0x24,0x04]
9999
; X86-NEXT: vsubbf16 %zmm2, %zmm1, %zmm0 {%k1} {z} # encoding: [0x62,0xf5,0x75,0xc9,0x5c,0xc2]
100100
; X86-NEXT: vsubbf16 (%eax), %zmm1, %zmm1 # encoding: [0x62,0xf5,0x75,0x48,0x5c,0x08]
101101
; X86-NEXT: vsubbf16 %zmm1, %zmm0, %zmm0 {%k1} # encoding: [0x62,0xf5,0x7d,0x49,0x5c,0xc1]

llvm/test/CodeGen/X86/avx10_2bf16-arith.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ define <16 x bfloat> @test_int_x86_avx10_maskz_sub_bf16_256(<16 x bfloat> %src,
147147
;
148148
; X86-LABEL: test_int_x86_avx10_maskz_sub_bf16_256:
149149
; X86: # %bb.0:
150-
; X86-NEXT: kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
151150
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax # encoding: [0x8b,0x44,0x24,0x08]
151+
; X86-NEXT: kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
152152
; X86-NEXT: vsubbf16 %ymm2, %ymm1, %ymm0 {%k1} {z} # encoding: [0x62,0xf5,0x75,0xa9,0x5c,0xc2]
153153
; X86-NEXT: vsubbf16 (%eax), %ymm1, %ymm1 # encoding: [0x62,0xf5,0x75,0x28,0x5c,0x08]
154154
; X86-NEXT: vsubbf16 %ymm1, %ymm0, %ymm0 {%k1} # encoding: [0x62,0xf5,0x7d,0x29,0x5c,0xc1]
@@ -201,8 +201,8 @@ define <8 x bfloat> @test_int_x86_avx10_maskz_sub_bf16_128(<8 x bfloat> %src, <8
201201
;
202202
; X86-LABEL: test_int_x86_avx10_maskz_sub_bf16_128:
203203
; X86: # %bb.0:
204-
; X86-NEXT: kmovb {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf9,0x90,0x4c,0x24,0x04]
205204
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax # encoding: [0x8b,0x44,0x24,0x08]
205+
; X86-NEXT: kmovb {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf9,0x90,0x4c,0x24,0x04]
206206
; X86-NEXT: vsubbf16 %xmm2, %xmm1, %xmm0 {%k1} {z} # encoding: [0x62,0xf5,0x75,0x89,0x5c,0xc2]
207207
; X86-NEXT: vsubbf16 (%eax), %xmm1, %xmm1 # encoding: [0x62,0xf5,0x75,0x08,0x5c,0x08]
208208
; X86-NEXT: vsubbf16 %xmm1, %xmm0, %xmm0 {%k1} # encoding: [0x62,0xf5,0x7d,0x09,0x5c,0xc1]

0 commit comments

Comments (0)