@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend the i4 elements using shifts and masking. The low i4 elements of
+  // each byte are placed in one vector and the high i4 elements in another.
+  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask.
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
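For intuition, here is a minimal scalar model of the two nibble-extraction strategies: the unsigned path added above masks the low nibble and logical-shifts the high one, while the signed path in rewriteI4ToI8SignedExt shifts left and then arithmetic-shifts right to sign-extend. This sketch is illustrative only, not code from the patch; the pattern emits the equivalent vector ops, and the sketch assumes two's-complement arithmetic right shifts:

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t packed = 0xAB; // two packed i4 values: high nibble 0xA, low nibble 0xB

  // Unsigned extension: mask out the low nibble, logical-shift the high one.
  int lowU = packed & 0x0F; // 0x0B -> 11
  int highU = packed >> 4;  // 0x0A -> 10 (zero-filled shift)

  // Signed extension: shift left, then arithmetic-shift right to sign-extend.
  int lowS = static_cast<int8_t>(packed << 4) >> 4; // 0xB -> -5
  int highS = static_cast<int8_t>(packed) >> 4;     // 0xA -> -6

  std::printf("unsigned: low=%d high=%d\n", lowU, highU);
  std::printf("signed:   low=%d high=%d\n", lowS, highS);
}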
@@ -1048,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// LLVM to scramble with peephole optimizations. Templated to choose between
+/// signed and unsigned conversions.
 ///
-/// For example:
+/// For example (signed):
 ///   arith.extsi %in : vector<8xi4> to vector<8xi32>
 /// is rewritten as
 ///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1069,16 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///   %4 = vector.interleave %2, %3 : vector<4xi8>
 ///   %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+/// For example (unsigned):
+///   arith.extui %in : vector<8xi4> to vector<8xi32>
+/// is rewritten as
+///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///   %1 = arith.andi %0, 15 : vector<4xi8>
+///   %2 = arith.shrui %0, 4 : vector<4xi8>
+///   %3 = vector.interleave %1, %2 : vector<4xi8>
+///   %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///
+template <typename ConversionOpType, bool isSigned>
+struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    auto srcVecType = cast<VectorType>(srcValue.getType());
+    auto dstVecType = cast<VectorType>(conversionOp.getType());
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
@@ -1089,8 +1131,14 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Perform the rewrite.
-    Value subByteExt =
-        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    Value subByteExt;
+    if (isSigned) {
+      subByteExt =
+          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    } else {
+      subByteExt =
+          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    }
 
     // Finalize the rewrite.
     rewriter.replaceOpWithNewOp<ConversionOpType>(
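Since isSigned is a template parameter, the if (isSigned) branch in this hunk is a compile-time constant in every instantiation, and the untaken helper call is trivially dead-code-eliminated. A C++17 if constexpr would make that explicit; an illustrative alternative, not part of the patch:

  // Hypothetical variant of the dispatch above using if constexpr.
  Value subByteExt;
  if constexpr (isSigned)
    subByteExt = rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
  else
    subByteExt = rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);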
@@ -1229,10 +1277,12 @@ void vector::populateVectorNarrowTypeRewritePatterns(
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
-  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
+               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
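For reference, a minimal sketch of how these rewrites are typically driven, assuming the standard greedy driver; only populateVectorNarrowTypeRewritePatterns comes from the file above, and the wrapper function is hypothetical:

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative driver: collect the narrow-type rewrites (including the new
// RewriteAlignedSubByteIntExt instantiations) and apply them greedily.
static LogicalResult runNarrowTypeRewrites(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorNarrowTypeRewritePatterns(patterns, PatternBenefit(1));
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

Note that the unsigned instantiation is registered only for arith.extui; arith.uitofp is not covered by this change.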