Skip to content

Commit aeffd16

Browse files
author
Tanyo Kwok
committed
Add native_dropout_backward & native_layer_norm_backward decomposition (#15)
1 parent 8429920 commit aeffd16

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,25 @@ class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
13381338
};
13391339
} // namespace
13401340

1341+
// grad_output * mask * scale
1342+
namespace {
1343+
class DecomposeAtenNativeDropoutBackwardOp
1344+
: public OpRewritePattern<AtenNativeDropoutBackwardOp> {
1345+
public:
1346+
using OpRewritePattern::OpRewritePattern;
1347+
LogicalResult matchAndRewrite(AtenNativeDropoutBackwardOp op,
1348+
PatternRewriter &rewriter) const override {
1349+
Location loc = op.getLoc();
1350+
1351+
Value maskedGradOutput = rewriter.create<AtenMulTensorOp>(
1352+
loc, op.getType(), op.grad_output(), op.mask());
1353+
rewriter.replaceOpWithNewOp<AtenMulScalarOp>(op, op.getType(),
1354+
maskedGradOutput, op.scale());
1355+
return success();
1356+
}
1357+
};
1358+
} // namespace
1359+
13411360
// Decompose aten.var into: aten.var.dim op.
13421361
namespace {
13431362
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
@@ -3087,6 +3106,8 @@ class DecomposeComplexOpsPass
30873106
patterns.add<DecomposeAtenLayerNormOp>(context);
30883107
target.addIllegalOp<AtenNativeLayerNormOp>();
30893108
patterns.add<DecomposeAtenNativeLayerNormOp>(context);
3109+
target.addIllegalOp<AtenNativeLayerNormBackwardOp>();
3110+
patterns.add<DecomposeAtenNativeLayerNormBackwardOp>(context);
30903111

30913112
target.addIllegalOp<AtenNativeBatchNormOp>();
30923113
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
@@ -3160,6 +3181,8 @@ class DecomposeComplexOpsPass
31603181
target.addIllegalOp<Aten_ToCopyOp>();
31613182
patterns.add<DecomposeAtenDropoutOp>(context);
31623183
target.addIllegalOp<AtenDropoutOp>();
3184+
patterns.add<DecomposeAtenNativeDropoutBackwardOp>(context);
3185+
target.addIllegalOp<AtenNativeDropoutBackwardOp>();
31633186
target.addIllegalOp<AtenNewEmptyOp>();
31643187
patterns.add<DecomposeAtenNewEmptyOp>(context);
31653188
patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);

0 commit comments

Comments
 (0)