@@ -1338,6 +1338,25 @@ class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
13381338};
13391339} //  namespace
13401340
1341+ //  grad_output * mask * scale
1342+ namespace  {
1343+ class  DecomposeAtenNativeDropoutBackwardOp 
1344+     : public OpRewritePattern<AtenNativeDropoutBackwardOp> {
1345+ public: 
1346+   using  OpRewritePattern::OpRewritePattern;
1347+   LogicalResult matchAndRewrite (AtenNativeDropoutBackwardOp op,
1348+                                 PatternRewriter &rewriter) const  override  {
1349+     Location loc = op.getLoc ();
1350+ 
1351+     Value maskedGradOutput = rewriter.create <AtenMulTensorOp>(
1352+         loc, op.getType (), op.grad_output (), op.mask ());
1353+     rewriter.replaceOpWithNewOp <AtenMulScalarOp>(op, op.getType (),
1354+                                                  maskedGradOutput, op.scale ());
1355+     return  success ();
1356+   }
1357+ };
1358+ } //  namespace
1359+ 
13411360//  Decompose aten.var into: aten.var.dim op.
13421361namespace  {
13431362class  DecomposeAtenVarOp  : public  OpRewritePattern <AtenVarOp> {
@@ -3087,6 +3106,8 @@ class DecomposeComplexOpsPass
30873106    patterns.add <DecomposeAtenLayerNormOp>(context);
30883107    target.addIllegalOp <AtenNativeLayerNormOp>();
30893108    patterns.add <DecomposeAtenNativeLayerNormOp>(context);
3109+     target.addIllegalOp <AtenNativeLayerNormBackwardOp>();
3110+     patterns.add <DecomposeAtenNativeLayerNormBackwardOp>(context);
30903111
30913112    target.addIllegalOp <AtenNativeBatchNormOp>();
30923113    patterns.add <DecomposeAtenNativeBatchNormOp>(context);
@@ -3160,6 +3181,8 @@ class DecomposeComplexOpsPass
31603181    target.addIllegalOp <Aten_ToCopyOp>();
31613182    patterns.add <DecomposeAtenDropoutOp>(context);
31623183    target.addIllegalOp <AtenDropoutOp>();
3184+     patterns.add <DecomposeAtenNativeDropoutBackwardOp>(context);
3185+     target.addIllegalOp <AtenNativeDropoutBackwardOp>();
31633186    target.addIllegalOp <AtenNewEmptyOp>();
31643187    patterns.add <DecomposeAtenNewEmptyOp>(context);
31653188    patterns.add <DecomposeAtenIndexPutHackedTwinOp>(context);
0 commit comments