@@ -445,16 +445,36 @@ struct LinearizeVectorExtract final
445445 }
446446};
447447
448- // / This pattern converts the InsertOp to a ShuffleOp that works on a
449- // / linearized vector.
450- // / Following,
451- // / vector.insert %source %destination [ position ]
452- // / is converted to :
453- // / %source_1d = vector.shape_cast %source
454- // / %destination_1d = vector.shape_cast %destination
455- // / %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
456- // / ] %out_nd = vector.shape_cast %out_1d
457- // / `shuffle_indices_1d` is computed using the position of the original insert.
448+ // / This pattern linearizes `vector.insert` operations. It generates a 1-D
449+ // / version of the `vector.insert` operation when inserting a scalar into a
450+ // / vector. It generates a 1-D `vector.shuffle` operation when inserting a
451+ // / vector into another vector.
452+ // /
453+ // / Example #1:
454+ // /
455+ // / %0 = vector.insert %source, %destination[0] :
456+ // / vector<2x4xf32> into vector<2x2x4xf32>
457+ // /
458+ // / is converted to:
459+ // /
460+ // / %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
461+ // / %1 = vector.shape_cast %destination :
462+ // / vector<2x2x4xf32> to vector<16xf32>
463+ // / %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23,
464+ // / 8, 9, 10, 11, 12, 13, 14, 15] :
465+ // / vector<16xf32>, vector<8xf32>
466+ // / %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
467+ // /
468+ // / Example #2:
469+ // /
470+ // / %0 = vector.insert %source, %destination[1, 2] : f32 into vector<2x4xf32>
471+ // /
472+ // / is converted to:
473+ // /
474+ // / %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
475+ // / %1 = vector.insert %source, %0[6] : f32 into vector<8xf32>
476+ // / %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
477+ // /
458478struct LinearizeVectorInsert final
459479 : public OpConversionPattern<vector::InsertOp> {
460480 using OpConversionPattern::OpConversionPattern;
@@ -468,48 +488,55 @@ struct LinearizeVectorInsert final
468488 insertOp.getDestVectorType ());
469489 assert (dstTy && " vector type destination expected." );
470490
471- // dynamic position is not supported
491+ // Dynamic position is not supported.
472492 if (insertOp.hasDynamicPosition ())
473493 return rewriter.notifyMatchFailure (insertOp,
474494 " dynamic position is not supported." );
475495 auto srcTy = insertOp.getValueToStoreType ();
476496 auto srcAsVec = dyn_cast<VectorType>(srcTy);
477- uint64_t srcSize = 0 ;
478- if (srcAsVec) {
479- srcSize = srcAsVec.getNumElements ();
480- } else {
481- return rewriter.notifyMatchFailure (insertOp,
482- " scalars are not supported." );
483- }
497+ uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements () : 1 ;
484498
485499 auto dstShape = insertOp.getDestVectorType ().getShape ();
486500 const auto dstSize = insertOp.getDestVectorType ().getNumElements ();
487501 auto dstSizeForOffsets = dstSize;
488502
489- // compute linearized offset
503+ // Compute linearized offset.
490504 int64_t linearizedOffset = 0 ;
491505 auto offsetsNd = insertOp.getStaticPosition ();
492506 for (auto [dim, offset] : llvm::enumerate (offsetsNd)) {
493507 dstSizeForOffsets /= dstShape[dim];
494508 linearizedOffset += offset * dstSizeForOffsets;
495509 }
496510
511+ Location loc = insertOp.getLoc ();
512+ Value valueToStore = adaptor.getValueToStore ();
513+
514+ if (!isa<VectorType>(valueToStore.getType ())) {
515+ // Scalar case: generate a 1-D insert.
516+ Value result = rewriter.createOrFold <vector::InsertOp>(
517+ loc, valueToStore, adaptor.getDest (), linearizedOffset);
518+ rewriter.replaceOp (insertOp, result);
519+ return success ();
520+ }
521+
522+ // Vector case: generate a shuffle.
497523 llvm::SmallVector<int64_t , 2 > indices (dstSize);
498524 auto *origValsUntil = indices.begin ();
499525 std::advance (origValsUntil, linearizedOffset);
500- std::iota (indices.begin (), origValsUntil,
501- 0 ); // original values that remain [0, offset)
526+
527+ // Original values that remain [0, offset).
528+ std::iota (indices.begin (), origValsUntil, 0 );
502529 auto *newValsUntil = origValsUntil;
503530 std::advance (newValsUntil, srcSize);
504- std::iota (origValsUntil, newValsUntil,
505- dstSize); // new values [offset, offset+srcNumElements)
506- std::iota (newValsUntil, indices.end (),
507- linearizedOffset + srcSize); // the rest of original values
508- // [offset+srcNumElements, end)
531+ // New values [offset, offset+srcNumElements).
532+ std::iota (origValsUntil, newValsUntil, dstSize);
533+ // The rest of original values [offset+srcNumElements, end).
534+ std::iota (newValsUntil, indices.end (), linearizedOffset + srcSize);
509535
510- rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
511- insertOp , dstTy, adaptor.getDest (), adaptor. getValueToStore () , indices);
536+ Value result = rewriter.createOrFold <vector::ShuffleOp>(
537+ loc , dstTy, adaptor.getDest (), valueToStore , indices);
512538
539+ rewriter.replaceOp (insertOp, result);
513540 return success ();
514541 }
515542};
0 commit comments