@@ -165,16 +165,23 @@ static mlir::Value emitX86PSLLDQIByteShift(CIRGenFunction &cgf,
165165 CIRGenBuilderTy &builder = cgf.getBuilder ();
166166 unsigned shiftVal = getIntValueFromConstOp (Ops[1 ]) & 0xff ;
167167 mlir::Location loc = cgf.getLoc (E->getExprLoc ());
168- auto resultType = cast<cir::VectorType>(Ops[0 ].getType ());
168+ auto byteVecType = cast<cir::VectorType>(Ops[0 ].getType ());
169+
170+ // Get the original return type from the expression
171+ auto resultType = cast<cir::VectorType>(cgf.convertType (E->getType ()));
169172
170173 // If pslldq is shifting the vector more than 15 bytes, emit zero.
171174 // This matches the hardware behavior where shifting by 16+ bytes
172175 // clears the entire 128-bit lane.
173- if (shiftVal >= 16 )
174- return builder.getZero (loc, resultType);
176+ if (shiftVal >= 16 ) {
177+ mlir::Value zero = builder.getZero (loc, byteVecType);
178+ if (byteVecType != resultType)
179+ return builder.createBitcast (zero, resultType);
180+ return zero;
181+ }
175182
176- // Builtin type is vXi64 so multiply by 8 to get bytes.
177- unsigned numElts = resultType .getSize () * 8 ;
183+ // Builtin type is vXi8 (already in bytes)
184+ unsigned numElts = byteVecType .getSize ();
178185 assert (numElts % 16 == 0 && " Vector size must be multiple of 16 bytes" );
179186
180187 llvm::SmallVector<int64_t , 64 > indices;
@@ -189,17 +196,63 @@ static mlir::Value emitX86PSLLDQIByteShift(CIRGenFunction &cgf,
189196 }
190197 }
191198
192- // Cast to byte vector for shuffle operation
193- auto byteVecTy = cir::VectorType::get (builder.getSInt8Ty (), numElts);
194- mlir::Value byteCast = builder.createBitcast (Ops[0 ], byteVecTy);
195- mlir::Value zero = builder.getZero (loc, byteVecTy);
199+ mlir::Value zero = builder.getZero (loc, byteVecType);
196200
197201 // Perform the shuffle (left shift by inserting zeros)
198- mlir::Value shuffleResult =
199- builder.createVecShuffle (loc, zero, byteCast, indices);
202+ mlir::Value shuffleResult = builder.createVecShuffle (loc, zero, Ops[0 ], indices);
203+
204+ // Cast back to original type if necessary
205+ if (byteVecType != resultType)
206+ return builder.createBitcast (shuffleResult, resultType);
207+ return shuffleResult;
208+ }
209+
210+ static mlir::Value emitX86PSRLDQIByteShift (CIRGenFunction &cgf,
211+ const CallExpr *E,
212+ ArrayRef<mlir::Value> Ops) {
213+ CIRGenBuilderTy &builder = cgf.getBuilder ();
214+ auto byteVecType = cast<cir::VectorType>(Ops[0 ].getType ());
215+ mlir::Location loc = cgf.getLoc (E->getExprLoc ());
216+ unsigned shiftVal = getIntValueFromConstOp (Ops[1 ]) & 0xff ;
217+
218+ // Get the original return type from the expression
219+ auto resultType = cast<cir::VectorType>(cgf.convertType (E->getType ()));
220+
221+ // If psrldq is shifting the vector more than 15 bytes, emit zero.
222+ if (shiftVal >= 16 ) {
223+ mlir::Value zero = builder.getZero (loc, byteVecType);
224+ if (byteVecType != resultType)
225+ return builder.createBitcast (zero, resultType);
226+ return zero;
227+ }
228+
229+ // Builtin type is vXi8 (already in bytes)
230+ uint64_t numElts = byteVecType.getSize ();
231+ assert (numElts % 16 == 0 && " Expected a multiple of 16" );
232+
233+ llvm::SmallVector<int64_t , 64 > indices;
234+
235+ // This correlates to the OG CodeGen
236+ // As stated in the OG, 256/512-bit psrldq operates on 128-bit lanes.
237+ // So we have to make sure we handle it.
238+ for (unsigned l = 0 ; l < numElts; l += 16 ) {
239+ for (unsigned i = 0 ; i < 16 ; ++i) {
240+ unsigned idx = i + shiftVal;
241+ if (idx >= 16 )
242+ idx += numElts - 16 ;
243+ indices.push_back (idx + l);
244+ }
245+ }
246+
247+ mlir::Value zero = builder.getZero (loc, byteVecType);
248+
249+ // Perform the shuffle (right shift by inserting zeros from the left)
250+ mlir::Value shuffleResult = builder.createVecShuffle (loc, Ops[0 ], zero, indices);
200251
201- // Cast back to original type
202- return builder.createBitcast (shuffleResult, resultType);
252+ // Cast back to original type if necessary
253+ if (byteVecType != resultType)
254+ return builder.createBitcast (shuffleResult, resultType);
255+ return shuffleResult;
203256}
204257
205258static mlir::Value emitX86MaskedCompareResult (CIRGenFunction &cgf,
@@ -1366,7 +1419,7 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned BuiltinID,
13661419 case X86::BI__builtin_ia32_psrldqi128_byteshift:
13671420 case X86::BI__builtin_ia32_psrldqi256_byteshift:
13681421 case X86::BI__builtin_ia32_psrldqi512_byteshift:
1369- llvm_unreachable ( " psrldqi NYI " );
1422+ return emitX86PSRLDQIByteShift (* this , E, Ops );
13701423 case X86::BI__builtin_ia32_kshiftliqi:
13711424 case X86::BI__builtin_ia32_kshiftlihi:
13721425 case X86::BI__builtin_ia32_kshiftlisi:
0 commit comments