2929using namespace mlir ;
3030using namespace mlir ::torch;
3131using namespace mlir ::torch::Torch;
32+ using namespace mlir ::torch::TorchConversion;
3233using namespace mlir ::torch::torch_to_mhlo;
3334
3435LogicalResult broadcastRanks (PatternRewriter &rewriter, Operation *op,
@@ -166,16 +167,19 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
166167 if (!selfTy)
167168 return op.emitError (" only Tensor types supported in MHLO" );
168169
169- if (selfTy.getElementType ().isa <mlir::FloatType>()) {
170+ auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
171+ op.getType ());
172+ if (selfTy != outTy) {
173+ auto out = rewriter.create <MhloOpT>(op.getLoc (), selfTy, self);
174+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, outTy, out);
175+ return success ();
176+ } else {
170177 rewriter.replaceOpWithNewOp <MhloOpT>(
171178 op,
172179 OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
173180 op.getType ()),
174181 self);
175182 return success ();
176- } else {
177- return op.emitError (
178- " only floating-point datatype legalization supported" );
179183 }
180184 }
181185};
@@ -345,15 +349,10 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
345349 } else if (!rhsType) {
346350 rhs = mhlo::scalarToMhloTensor (rewriter, op, adaptor.getOther (), outElemTy);
347351 }
348- DenseIntElementsAttr bcastDimensions;
349- lhs = mhlo::promoteType (rewriter, lhs, outType);
350- rhs = mhlo::promoteType (rewriter, rhs, outType);
351- auto loc = op.getLoc ();
352- Value result =
353- rewriter.create <ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
354-
355352 if (!isa<AtenDivTensorModeOp>(op)) {
356- rewriter.replaceOp (op, result);
353+ lhs = mhlo::promoteType (rewriter, lhs, outType);
354+ rhs = mhlo::promoteType (rewriter, rhs, outType);
355+ rewriter.replaceOpWithNewOp <ChloOpT>(op, outType, lhs, rhs, nullptr );
357356 return success ();
358357 }
359358
@@ -365,6 +364,17 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
365364 return rewriter.notifyMatchFailure (
366365 op, " only support constant str rounding mode" );
367366
367+ auto computeTy = outType;
368+ if (outElemTy.isIntOrIndex ()) {
369+ computeTy =
370+ RankedTensorType::get (outType.getShape (), rewriter.getF32Type ());
371+ }
372+ lhs = mhlo::promoteType (rewriter, lhs, computeTy);
373+ rhs = mhlo::promoteType (rewriter, rhs, computeTy);
374+ auto loc = op.getLoc ();
375+ auto result =
376+ rewriter.create <ChloOpT>(loc, computeTy, lhs, rhs, nullptr ).getResult ();
377+
368378 if (roundingMode == " trunc" ) {
369379 // "trunc" - rounds the results of the division towards zero. Equivalent
370380 // to C-style integer division.
@@ -378,7 +388,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
378388 // floor division in Python (the // operator)
379389 result = rewriter.create <mhlo::FloorOp>(loc, result).getResult ();
380390 }
381- rewriter.replaceOp (op, result);
391+ rewriter.replaceOpWithNewOp <mhlo::ConvertOp> (op, outType , result);
382392 return success ();
383393 }
384394};
@@ -836,7 +846,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
836846 APFloat::getZero (lhsElemTy.cast <mlir::FloatType>().getFloatSemantics (),
837847 false ),
838848 lhs);
839- rewriter.replaceOpWithNewOp <mhlo::MaxOp>(op, lhs, zeroTensor);
849+ auto outType = getTypeConverter ()
850+ ->convertType (op.getType ())
851+ .template dyn_cast <TensorType>();
852+
853+ rewriter.replaceOpWithNewOp <mhlo::MaxOp>(op, outType, lhs, zeroTensor);
840854 return success ();
841855}
842856
@@ -862,7 +876,11 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
862876 auto erf = rewriter.create <mlir::chlo::ErfOp>(loc, erfElement);
863877 auto erfAdd = rewriter.create <mhlo::AddOp>(loc, erf, one);
864878 auto halfMul = rewriter.create <mhlo::MulOp>(loc, erfAdd, half);
865- rewriter.replaceOpWithNewOp <mhlo::MulOp>(op, input, halfMul);
879+ auto outType = getTypeConverter ()
880+ ->convertType (op.getType ())
881+ .template dyn_cast <TensorType>();
882+
883+ rewriter.replaceOpWithNewOp <mhlo::MulOp>(op, outType, input, halfMul);
866884 return success ();
867885}
868886
@@ -1463,7 +1481,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
14631481 INSERT_ATENOP_PATTERN (ValueTensorLiteralOp);
14641482 INSERT_ATENOP_PATTERN (AtenReciprocalOp);
14651483 INSERT_ATENOP_PATTERN (PrimNumToTensorScalarOp);
1466- INSERT_ATENOP_PATTERN (AtenContiguousOp);
14671484
14681485 INSERT_ATENOP_PATTERN (AtenReluOp);
14691486 INSERT_ATENOP_PATTERN (AtenGeluOp);
0 commit comments