Skip to content

Commit 6ce0474

Browse files
authored
[SDISel] Teach the type legalizer about ADDRSPACECAST (#90969)
Vectorized ADDRSPACECASTs were not supported by the type legalizer. This patch adds support for:
- splitting the vector result: <2 x ptr> => 2 x <1 x ptr>
- scalarization: <1 x ptr> => ptr
- widening: <3 x ptr> => <4 x ptr>
This is all exercised by the added NVPTX tests.
1 parent d838e5b commit 6ce0474

File tree

3 files changed

+160
-0
lines changed

3 files changed

+160
-0
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
786786
SDValue ScalarizeVecRes_InregOp(SDNode *N);
787787
SDValue ScalarizeVecRes_VecInregOp(SDNode *N);
788788

789+
SDValue ScalarizeVecRes_ADDRSPACECAST(SDNode *N);
789790
SDValue ScalarizeVecRes_BITCAST(SDNode *N);
790791
SDValue ScalarizeVecRes_BUILD_VECTOR(SDNode *N);
791792
SDValue ScalarizeVecRes_EXTRACT_SUBVECTOR(SDNode *N);
@@ -853,6 +854,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
853854
void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi);
854855
void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
855856
void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
857+
void SplitVecRes_ADDRSPACECAST(SDNode *N, SDValue &Lo, SDValue &Hi);
856858
void SplitVecRes_FFREXP(SDNode *N, unsigned ResNo, SDValue &Lo, SDValue &Hi);
857859
void SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, SDValue &Hi);
858860
void SplitVecRes_InregOp(SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -956,6 +958,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
956958
// Widen Vector Result Promotion.
957959
void WidenVectorResult(SDNode *N, unsigned ResNo);
958960
SDValue WidenVecRes_MERGE_VALUES(SDNode* N, unsigned ResNo);
961+
SDValue WidenVecRes_ADDRSPACECAST(SDNode *N);
959962
SDValue WidenVecRes_AssertZext(SDNode* N);
960963
SDValue WidenVecRes_BITCAST(SDNode* N);
961964
SDValue WidenVecRes_BUILD_VECTOR(SDNode* N);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/ADT/SmallBitVector.h"
2424
#include "llvm/Analysis/MemoryLocation.h"
2525
#include "llvm/Analysis/VectorUtils.h"
26+
#include "llvm/CodeGen/ISDOpcodes.h"
2627
#include "llvm/IR/DataLayout.h"
2728
#include "llvm/Support/ErrorHandling.h"
2829
#include "llvm/Support/TypeSize.h"
@@ -116,6 +117,9 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
116117
case ISD::FCANONICALIZE:
117118
R = ScalarizeVecRes_UnaryOp(N);
118119
break;
120+
case ISD::ADDRSPACECAST:
121+
R = ScalarizeVecRes_ADDRSPACECAST(N);
122+
break;
119123
case ISD::FFREXP:
120124
R = ScalarizeVecRes_FFREXP(N, ResNo);
121125
break;
@@ -475,6 +479,31 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_VecInregOp(SDNode *N) {
475479
llvm_unreachable("Illegal extend_vector_inreg opcode");
476480
}
477481

482+
/// Scalarize the vector result of an ADDRSPACECAST node.
///
/// Builds a scalar address-space cast whose result type is the element type
/// of N's vector result. The scalar input is either the already-scalarized
/// operand or element 0 extracted from the operand vector. The source and
/// destination address spaces are copied from N.
///
/// \param N the ADDRSPACECAST node whose result is being scalarized.
/// \returns the scalar address-space cast replacing N's result.
SDValue DAGTypeLegalizer::ScalarizeVecRes_ADDRSPACECAST(SDNode *N) {
483+
// The scalar result type is the element type of the vector result.
EVT DestVT = N->getValueType(0).getVectorElementType();
484+
SDValue Op = N->getOperand(0);
485+
EVT OpVT = Op.getValueType();
486+
SDLoc DL(N);
487+
// The result needs scalarizing, but it's not a given that the source does.
488+
// This is a workaround for targets where it's impossible to scalarize the
489+
// result of a conversion, because the source type is legal.
490+
// For instance, this happens on AArch64: v1i1 is illegal but v1i{8,16,32}
491+
// are widened to v8i8, v4i16, and v2i32, which is legal, because v1i64 is
492+
// legal and was not scalarized.
493+
// See the similar logic in ScalarizeVecRes_SETCC
494+
if (getTypeAction(OpVT) == TargetLowering::TypeScalarizeVector) {
495+
// The operand is scalarized too: reuse its scalar form directly.
Op = GetScalarizedVector(Op);
496+
} else {
497+
// The operand is not scalarized: pull out element 0 of the operand vector
// by hand.
EVT VT = OpVT.getVectorElementType();
498+
Op = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op,
499+
DAG.getVectorIdxConstant(0, DL));
500+
}
501+
// Carry the original source/destination address spaces over to the scalar
// cast.
auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
502+
unsigned SrcAS = AddrSpaceCastN->getSrcAddressSpace();
503+
unsigned DestAS = AddrSpaceCastN->getDestAddressSpace();
504+
return DAG.getAddrSpaceCast(DL, DestVT, Op, SrcAS, DestAS);
505+
}
506+
478507
SDValue DAGTypeLegalizer::ScalarizeVecRes_SCALAR_TO_VECTOR(SDNode *N) {
479508
// If the operand is wider than the vector element type then it is implicitly
480509
// truncated. Make that explicit here.
@@ -1122,6 +1151,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
11221151
case ISD::FCANONICALIZE:
11231152
SplitVecRes_UnaryOp(N, Lo, Hi);
11241153
break;
1154+
case ISD::ADDRSPACECAST:
1155+
SplitVecRes_ADDRSPACECAST(N, Lo, Hi);
1156+
break;
11251157
case ISD::FFREXP:
11261158
SplitVecRes_FFREXP(N, ResNo, Lo, Hi);
11271159
break;
@@ -2353,6 +2385,26 @@ void DAGTypeLegalizer::SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo,
23532385
Hi = DAG.getNode(Opcode, dl, HiVT, {Hi, MaskHi, EVLHi}, Flags);
23542386
}
23552387

2388+
/// Split the vector result of an ADDRSPACECAST node into two halves.
///
/// The operand is split into a low and a high part, and each part is cast
/// separately with the source and destination address spaces taken from N.
///
/// \param N  the ADDRSPACECAST node whose result is being split.
/// \param Lo [out] address-space cast of the low half of the operand.
/// \param Hi [out] address-space cast of the high half of the operand.
void DAGTypeLegalizer::SplitVecRes_ADDRSPACECAST(SDNode *N, SDValue &Lo,
2389+
SDValue &Hi) {
2390+
SDLoc dl(N);
2391+
// Result types of the two halves.
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(N->getValueType(0));
2392+
2393+
// If the input also splits, handle it directly for a compile time speedup.
2394+
// Otherwise split it by hand.
2395+
// Note: Lo and Hi first hold the two halves of the *operand*; they are
// overwritten with the cast results below.
EVT InVT = N->getOperand(0).getValueType();
2396+
if (getTypeAction(InVT) == TargetLowering::TypeSplitVector)
2397+
GetSplitVector(N->getOperand(0), Lo, Hi);
2398+
else
2399+
std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
2400+
2401+
// Both halves use the same source/destination address spaces as N.
auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
2402+
unsigned SrcAS = AddrSpaceCastN->getSrcAddressSpace();
2403+
unsigned DestAS = AddrSpaceCastN->getDestAddressSpace();
2404+
Lo = DAG.getAddrSpaceCast(dl, LoVT, Lo, SrcAS, DestAS);
2405+
Hi = DAG.getAddrSpaceCast(dl, HiVT, Hi, SrcAS, DestAS);
2406+
}
2407+
23562408
void DAGTypeLegalizer::SplitVecRes_FFREXP(SDNode *N, unsigned ResNo,
23572409
SDValue &Lo, SDValue &Hi) {
23582410
SDLoc dl(N);
@@ -4121,6 +4173,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
41214173
report_fatal_error("Do not know how to widen the result of this operator!");
41224174

41234175
case ISD::MERGE_VALUES: Res = WidenVecRes_MERGE_VALUES(N, ResNo); break;
4176+
case ISD::ADDRSPACECAST:
4177+
Res = WidenVecRes_ADDRSPACECAST(N);
4178+
break;
41244179
case ISD::AssertZext: Res = WidenVecRes_AssertZext(N); break;
41254180
case ISD::BITCAST: Res = WidenVecRes_BITCAST(N); break;
41264181
case ISD::BUILD_VECTOR: Res = WidenVecRes_BUILD_VECTOR(N); break;
@@ -5086,6 +5141,16 @@ SDValue DAGTypeLegalizer::WidenVecRes_MERGE_VALUES(SDNode *N, unsigned ResNo) {
50865141
return GetWidenedVector(WidenVec);
50875142
}
50885143

5144+
/// Widen the vector result of an ADDRSPACECAST node.
///
/// Emits a single address-space cast from the widened operand to the type
/// the target transforms N's result type to, keeping N's source and
/// destination address spaces.
///
/// \param N the ADDRSPACECAST node whose result is being widened.
/// \returns the address-space cast producing the widened result.
SDValue DAGTypeLegalizer::WidenVecRes_ADDRSPACECAST(SDNode *N) {
5145+
EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
5146+
// NOTE(review): the operand is fetched with GetWidenedVector unconditionally,
// which presumably assumes the operand type is widened as well -- confirm
// this holds when source and destination pointer widths differ.
SDValue InOp = GetWidenedVector(N->getOperand(0));
5147+
auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
5148+
5149+
return DAG.getAddrSpaceCast(SDLoc(N), WidenVT, InOp,
5150+
AddrSpaceCastN->getSrcAddressSpace(),
5151+
AddrSpaceCastN->getDestAddressSpace());
5152+
}
5153+
50895154
SDValue DAGTypeLegalizer::WidenVecRes_BITCAST(SDNode *N) {
50905155
SDValue InOp = N->getOperand(0);
50915156
EVT InVT = InOp.getValueType();

llvm/test/CodeGen/NVPTX/addrspacecast.ll

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,95 @@ define i32 @conv8(ptr %ptr) {
9898
%val = load i32, ptr addrspace(5) %specptr
9999
ret i32 %val
100100
}
101+
102+
; Check that we support addrspacecast when splitting the vector
103+
; result (<2 x ptr> => 2 x <1 x ptr>).
104+
; This also checks that scalarization works for addrspacecast
105+
; (when going from <1 x ptr> to ptr.)
106+
; ALL-LABEL: split1To0
107+
define void @split1To0(ptr nocapture noundef readonly %xs) {
108+
; CLS32: cvta.global.u32
109+
; CLS32: cvta.global.u32
110+
; CLS64: cvta.global.u64
111+
; CLS64: cvta.global.u64
112+
; ALL: st.u32
113+
; ALL: st.u32
114+
%vec_addr = load <2 x ptr addrspace(1)>, ptr %xs, align 16
115+
%addrspacecast = addrspacecast <2 x ptr addrspace(1)> %vec_addr to <2 x ptr>
116+
%extractelement0 = extractelement <2 x ptr> %addrspacecast, i64 0
117+
store float 0.5, ptr %extractelement0, align 4
118+
%extractelement1 = extractelement <2 x ptr> %addrspacecast, i64 1
119+
store float 1.0, ptr %extractelement1, align 4
120+
ret void
121+
}
122+
123+
; Same as split1To0 but from 0 to 1, to make sure the addrspacecast preserves
124+
; the source and destination addrspaces properly.
125+
; ALL-LABEL: split0To1
126+
define void @split0To1(ptr nocapture noundef readonly %xs) {
127+
; CLS32: cvta.to.global.u32
128+
; CLS32: cvta.to.global.u32
129+
; CLS64: cvta.to.global.u64
130+
; CLS64: cvta.to.global.u64
131+
; ALL: st.global.u32
132+
; ALL: st.global.u32
133+
%vec_addr = load <2 x ptr>, ptr %xs, align 16
134+
%addrspacecast = addrspacecast <2 x ptr> %vec_addr to <2 x ptr addrspace(1)>
135+
%extractelement0 = extractelement <2 x ptr addrspace(1)> %addrspacecast, i64 0
136+
store float 0.5, ptr addrspace(1) %extractelement0, align 4
137+
%extractelement1 = extractelement <2 x ptr addrspace(1)> %addrspacecast, i64 1
138+
store float 1.0, ptr addrspace(1) %extractelement1, align 4
139+
ret void
140+
}
141+
142+
; Check that we support addrspacecast when a widening is required
143+
; (<3 x ptr> => <4 x ptr>).
144+
; ALL-LABEL: widen1To0
145+
define void @widen1To0(ptr nocapture noundef readonly %xs) {
146+
; CLS32: cvta.global.u32
147+
; CLS32: cvta.global.u32
148+
; CLS32: cvta.global.u32
149+
150+
; CLS64: cvta.global.u64
151+
; CLS64: cvta.global.u64
152+
; CLS64: cvta.global.u64
153+
154+
; ALL: st.u32
155+
; ALL: st.u32
156+
; ALL: st.u32
157+
%vec_addr = load <3 x ptr addrspace(1)>, ptr %xs, align 16
158+
%addrspacecast = addrspacecast <3 x ptr addrspace(1)> %vec_addr to <3 x ptr>
159+
%extractelement0 = extractelement <3 x ptr> %addrspacecast, i64 0
160+
store float 0.5, ptr %extractelement0, align 4
161+
%extractelement1 = extractelement <3 x ptr> %addrspacecast, i64 1
162+
store float 1.0, ptr %extractelement1, align 4
163+
%extractelement2 = extractelement <3 x ptr> %addrspacecast, i64 2
164+
store float 1.5, ptr %extractelement2, align 4
165+
ret void
166+
}
167+
168+
; Same as widen1To0 but from 0 to 1, to make sure the addrspacecast preserves
169+
; the source and destination addrspaces properly.
170+
; ALL-LABEL: widen0To1
171+
define void @widen0To1(ptr nocapture noundef readonly %xs) {
172+
; CLS32: cvta.to.global.u32
173+
; CLS32: cvta.to.global.u32
174+
; CLS32: cvta.to.global.u32
175+
176+
; CLS64: cvta.to.global.u64
177+
; CLS64: cvta.to.global.u64
178+
; CLS64: cvta.to.global.u64
179+
180+
; ALL: st.global.u32
181+
; ALL: st.global.u32
182+
; ALL: st.global.u32
183+
%vec_addr = load <3 x ptr>, ptr %xs, align 16
184+
%addrspacecast = addrspacecast <3 x ptr> %vec_addr to <3 x ptr addrspace(1)>
185+
%extractelement0 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 0
186+
store float 0.5, ptr addrspace(1) %extractelement0, align 4
187+
%extractelement1 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 1
188+
store float 1.0, ptr addrspace(1) %extractelement1, align 4
189+
%extractelement2 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 2
190+
store float 1.5, ptr addrspace(1) %extractelement2, align 4
191+
ret void
192+
}

0 commit comments

Comments
 (0)