@@ -2388,6 +2388,97 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenSqueezeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op.getLoc();
+    Value input = adaptor.self();
+    auto inputType = input.getType().cast<RankedTensorType>();
+    int64_t inputRank = inputType.getRank();
+    TypeConverter *typeConverter = getTypeConverter();
+    auto resultType =
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+    int64_t resultRank = resultType.getRank();
+
+    if (inputRank == 0) {
+      return rewriter.notifyMatchFailure(
+          op, "zero input rank should have been handled by the folder");
+    }
+
+    // In case the operand tensor type is statically shaped with all dimensions
+    // being unit extent, it will be collapsed to a 0-D tensor.
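+    // E.g., squeezing a tensor<1x1x1xf32> produces a tensor<f32> via a
+    // collapse with an empty reassociation list.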
+    if (resultRank == 0) {
+      SmallVector<ReassociationIndices> reassociation;
+      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+          op, resultType, input, reassociation);
+      return success();
+    }
+
+    // All the static size-1 dimensions at the beginning (going from higher to
+    // lower dimensions) will be collapsed into the first dynamic or first
+    // non-size-1 static dimension. All the other static size-1 dimensions will
+    // be collapsed into their previous dynamic or non-size-1 static dimension.
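+    // E.g., squeezing tensor<1x5x1x6xf32> into tensor<5x6xf32> produces the
+    // reassociation [[0, 1, 2], [3]]: the leading unit dimension and the unit
+    // dimension at index 2 are folded into the group of the kept dimension 1.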
+    SmallVector<ReassociationIndices> reassociation(resultRank);
+    bool isSqueezed = false;
+    int64_t headOnesCount = 0;
+    while (headOnesCount < inputRank &&
+           inputType.getDimSize(headOnesCount) == 1) {
+      isSqueezed = true;
+      reassociation[0].push_back(headOnesCount++);
+    }
+
+    // TODO: Add support for size-1 dynamic dimensions.
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
+    int64_t j = -1;
+    for (auto i : llvm::seq<int64_t>(headOnesCount, inputRank)) {
+      if (inputType.isDynamicDim(i)) {
+        // Make sure that a size-1 dynamic dimension does not exist.
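+        // The reassociation indices are fixed at compile time, so a dynamic
+        // dimension that turns out to be 1 at runtime cannot be folded away;
+        // guard against that case with a runtime assert instead.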
+        Value dimSize = getDimOp(rewriter, loc, input, i);
+        Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::ne, dimSize, one);
+        rewriter.create<AssertOp>(
+            loc, dimSizeNotOne,
+            rewriter.getStringAttr(
+                "unimplemented: size 1 dynamic dimension is not supported"));
+        ++j;
+      } else if (inputType.getDimSize(i) != 1) {
+        ++j;
+      } else {
+        // `isSqueezed` records that the operand tensor type contains at least
+        // one unit dimension.
+        isSqueezed = true;
+      }
+      if (j == resultRank)
+        break;
+      reassociation[j].push_back(i);
+    }
+
+    // Make sure that the result type rank is compatible with the squeezed size.
+    if (j != resultRank - 1)
+      return rewriter.notifyMatchFailure(
+          op, "expected output size does not match the result type rank");
+
+    if (isSqueezed) {
+      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+          op, resultType, input, reassociation);
+    } else {
+      // If the operand tensor type does not have any unit dimension,
+      // `aten.squeeze` will behave as an identity operation.
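+      // E.g., tensor<2x3xf32> has no unit dimension, so the op reduces to a
+      // cast to the converted result type.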
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+    }
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
 public:
@@ -3057,6 +3148,8 @@ class ConvertTorchToLinalg
                         AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
                         AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
+    target.addIllegalOp<AtenSqueezeOp>();
+    patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
     target.addIllegalOp<AtenUnsqueezeOp>();
     patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
     target.addIllegalOp<AtenConv2dOp>();