Skip to content

Commit 870cad5

Browse files
wizardengineerlanza
authored andcommitted
[CIR] Add support for __builtin_ia32_psrldqi_byteshift (#1886)
**Related Issue**: #1885
1 parent c4847c0 commit 870cad5

File tree

3 files changed

+252
-23
lines changed

3 files changed

+252
-23
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -165,16 +165,23 @@ static mlir::Value emitX86PSLLDQIByteShift(CIRGenFunction &cgf,
165165
CIRGenBuilderTy &builder = cgf.getBuilder();
166166
unsigned shiftVal = getIntValueFromConstOp(Ops[1]) & 0xff;
167167
mlir::Location loc = cgf.getLoc(E->getExprLoc());
168-
auto resultType = cast<cir::VectorType>(Ops[0].getType());
168+
auto byteVecType = cast<cir::VectorType>(Ops[0].getType());
169+
170+
// Get the original return type from the expression
171+
auto resultType = cast<cir::VectorType>(cgf.convertType(E->getType()));
169172

170173
// If pslldq is shifting the vector more than 15 bytes, emit zero.
171174
// This matches the hardware behavior where shifting by 16+ bytes
172175
// clears the entire 128-bit lane.
173-
if (shiftVal >= 16)
174-
return builder.getZero(loc, resultType);
176+
if (shiftVal >= 16) {
177+
mlir::Value zero = builder.getZero(loc, byteVecType);
178+
if (byteVecType != resultType)
179+
return builder.createBitcast(zero, resultType);
180+
return zero;
181+
}
175182

176-
// Builtin type is vXi64 so multiply by 8 to get bytes.
177-
unsigned numElts = resultType.getSize() * 8;
183+
// Builtin type is vXi8 (already in bytes)
184+
unsigned numElts = byteVecType.getSize();
178185
assert(numElts % 16 == 0 && "Vector size must be multiple of 16 bytes");
179186

180187
llvm::SmallVector<int64_t, 64> indices;
@@ -189,17 +196,63 @@ static mlir::Value emitX86PSLLDQIByteShift(CIRGenFunction &cgf,
189196
}
190197
}
191198

192-
// Cast to byte vector for shuffle operation
193-
auto byteVecTy = cir::VectorType::get(builder.getSInt8Ty(), numElts);
194-
mlir::Value byteCast = builder.createBitcast(Ops[0], byteVecTy);
195-
mlir::Value zero = builder.getZero(loc, byteVecTy);
199+
mlir::Value zero = builder.getZero(loc, byteVecType);
196200

197201
// Perform the shuffle (left shift by inserting zeros)
198-
mlir::Value shuffleResult =
199-
builder.createVecShuffle(loc, zero, byteCast, indices);
202+
mlir::Value shuffleResult = builder.createVecShuffle(loc, zero, Ops[0], indices);
203+
204+
// Cast back to original type if necessary
205+
if (byteVecType != resultType)
206+
return builder.createBitcast(shuffleResult, resultType);
207+
return shuffleResult;
208+
}
209+
210+
static mlir::Value emitX86PSRLDQIByteShift(CIRGenFunction &cgf,
211+
const CallExpr *E,
212+
ArrayRef<mlir::Value> Ops) {
213+
CIRGenBuilderTy &builder = cgf.getBuilder();
214+
auto byteVecType = cast<cir::VectorType>(Ops[0].getType());
215+
mlir::Location loc = cgf.getLoc(E->getExprLoc());
216+
unsigned shiftVal = getIntValueFromConstOp(Ops[1]) & 0xff;
217+
218+
// Get the original return type from the expression
219+
auto resultType = cast<cir::VectorType>(cgf.convertType(E->getType()));
220+
221+
// If psrldq is shifting the vector more than 15 bytes, emit zero.
222+
if (shiftVal >= 16) {
223+
mlir::Value zero = builder.getZero(loc, byteVecType);
224+
if (byteVecType != resultType)
225+
return builder.createBitcast(zero, resultType);
226+
return zero;
227+
}
228+
229+
// Builtin type is vXi8 (already in bytes)
230+
uint64_t numElts = byteVecType.getSize();
231+
assert(numElts % 16 == 0 && "Expected a multiple of 16");
232+
233+
llvm::SmallVector<int64_t, 64> indices;
234+
235+
// This correlates to the OG CodeGen
236+
// As stated in the OG, 256/512-bit psrldq operates on 128-bit lanes.
237+
// So we have to make sure we handle it.
238+
for (unsigned l = 0; l < numElts; l += 16) {
239+
for (unsigned i = 0; i < 16; ++i) {
240+
unsigned idx = i + shiftVal;
241+
if (idx >= 16)
242+
idx += numElts - 16;
243+
indices.push_back(idx + l);
244+
}
245+
}
246+
247+
mlir::Value zero = builder.getZero(loc, byteVecType);
248+
249+
// Perform the shuffle (right shift by inserting zeros from the left)
250+
mlir::Value shuffleResult = builder.createVecShuffle(loc, Ops[0], zero, indices);
200251

201-
// Cast back to original type
202-
return builder.createBitcast(shuffleResult, resultType);
252+
// Cast back to original type if necessary
253+
if (byteVecType != resultType)
254+
return builder.createBitcast(shuffleResult, resultType);
255+
return shuffleResult;
203256
}
204257

205258
static mlir::Value emitX86MaskedCompareResult(CIRGenFunction &cgf,
@@ -1366,7 +1419,7 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned BuiltinID,
13661419
case X86::BI__builtin_ia32_psrldqi128_byteshift:
13671420
case X86::BI__builtin_ia32_psrldqi256_byteshift:
13681421
case X86::BI__builtin_ia32_psrldqi512_byteshift:
1369-
llvm_unreachable("psrldqi NYI");
1422+
return emitX86PSRLDQIByteShift(*this, E, Ops);
13701423
case X86::BI__builtin_ia32_kshiftliqi:
13711424
case X86::BI__builtin_ia32_kshiftlihi:
13721425
case X86::BI__builtin_ia32_kshiftlisi:

clang/test/CIR/CodeGen/builtin-x86-pslldqi.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@ typedef long long __m128i __attribute__((__vector_size__(16)));
1212
typedef long long __m256i __attribute__((__vector_size__(32)));
1313
typedef long long __m512i __attribute__((__vector_size__(64)));
1414

15-
// Declare the builtins directly
16-
extern __m128i __builtin_ia32_pslldqi128_byteshift(__m128i, int);
17-
extern __m256i __builtin_ia32_pslldqi256_byteshift(__m256i, int);
18-
extern __m512i __builtin_ia32_pslldqi512_byteshift(__m512i, int);
19-
2015
// ============================================================================
2116
// Core Functionality Tests
2217
// ============================================================================
@@ -48,7 +43,8 @@ __m128i test_pslldqi128_shift0(__m128i a) {
4843
// OGCG-LABEL: @_Z23test_pslldqi128_shift16Dv2_x
4944
__m128i test_pslldqi128_shift16(__m128i a) {
5045
// Entire vector shifted out, should return zero
51-
// CIR: %{{.*}} = cir.const #cir.zero : !cir.vector<!s64i x 2>
46+
// CIR: %{{.*}} = cir.const #cir.zero : !cir.vector<!s8i x 16>
47+
// CIR: %{{.*}} = cir.cast bitcast %{{.*}} : !cir.vector<!s8i x 16> -> !cir.vector<!s64i x 2>
5248
// LLVM: store <2 x i64> zeroinitializer, ptr %{{.*}}, align 16
5349
// OGCG: ret <2 x i64> zeroinitializer
5450
return __builtin_ia32_pslldqi128_byteshift(a, 16);
@@ -74,7 +70,8 @@ __m256i test_pslldqi256_shift4(__m256i a) {
7470
// OGCG-LABEL: @_Z23test_pslldqi256_shift16Dv4_x
7571
__m256i test_pslldqi256_shift16(__m256i a) {
7672
// Both lanes completely shifted out, returns zero
77-
// CIR: %{{.*}} = cir.const #cir.zero : !cir.vector<!s64i x 4>
73+
// CIR: %{{.*}} = cir.const #cir.zero : !cir.vector<!s8i x 32>
74+
// CIR: %{{.*}} = cir.cast bitcast %{{.*}} : !cir.vector<!s8i x 32> -> !cir.vector<!s64i x 4>
7875
// LLVM: store <4 x i64> zeroinitializer, ptr %{{.*}}, align 32
7976
// OGCG: ret <4 x i64> zeroinitializer
8077
return __builtin_ia32_pslldqi256_byteshift(a, 16);
@@ -100,7 +97,8 @@ __m512i test_pslldqi512_shift4(__m512i a) {
10097
// OGCG-LABEL: @_Z23test_pslldqi512_shift16Dv8_x
10198
__m512i test_pslldqi512_shift16(__m512i a) {
10299
// All 4 lanes completely cleared
103-
// CIR: %{{.*}} = cir.const #cir.zero : !cir.vector<!s64i x 8>
100+
// CIR: %{{.*}} = cir.const #cir.zero : !cir.vector<!s8i x 64>
101+
// CIR: %{{.*}} = cir.cast bitcast %{{.*}} : !cir.vector<!s8i x 64> -> !cir.vector<!s64i x 8>
104102
// LLVM: store <8 x i64> zeroinitializer, ptr %{{.*}}, align 64
105103
// OGCG: ret <8 x i64> zeroinitializer
106104
return __builtin_ia32_pslldqi512_byteshift(a, 16);
@@ -170,7 +168,8 @@ __m128i test_concrete_input_constant() {
170168
// OGCG-LABEL: @_Z22test_large_shift_valueDv2_x
171169
__m128i test_large_shift_value(__m128i a) {
172170
// 240 & 0xFF = 240, so this should return zero (240 > 16)
173-
// CIR: %{{.*}} = cir.const #cir.zero : !cir.vector<!s64i x 2>
171+
// CIR: %{{.*}} = cir.const #cir.zero : !cir.vector<!s8i x 16>
172+
// CIR: %{{.*}} = cir.cast bitcast %{{.*}} : !cir.vector<!s8i x 16> -> !cir.vector<!s64i x 2>
174173
// LLVM: store <2 x i64> zeroinitializer, ptr %{{.*}}, align 16
175174
// OGCG: ret <2 x i64> zeroinitializer
176175
return __builtin_ia32_pslldqi128_byteshift(a, 240);

0 commit comments

Comments
 (0)