AtenDivTensorModeOp folder for 'trunc' rounding mode
Vremold committed Aug 12, 2022
1 parent 4a096a6 commit 4bffccd
Showing 2 changed files with 58 additions and 10 deletions.
32 changes: 24 additions & 8 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -843,22 +843,20 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
auto outType = op->getResult(0).getType();
Value alpha =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));

if (!lhs || !rhs) {
return rewriter.notifyMatchFailure(
op, "only int scalar lhs or rhs is supported");
}
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
op)) {
alpha = getScalarValue(op->getOperand(2), loc, rewriter);
Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
if (!alpha) {
return rewriter.notifyMatchFailure(op,
"only int scalar alpha is supported");
}
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
}
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);

// AtenDivTensorModeOp
if (isa<AtenDivTensorModeOp>(op)) {
@@ -880,10 +878,28 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
quotient);
return success();
}
// For now, "trunc" rounding mode is not supported,
// as it introduces aten.abs, aten.floor, aten.sign ops,
// which adds complexity but helps little in optimization, such as constant
// folding
// For "trunc" rounding mode, instead of canonicalizing it into
// aten.abs, aten.floor, aten.sign and aten.mul.int ops, which adds
// complexity but helps little in optimization (such as constant folding),
// we fold it directly.
if (roundingMode == "trunc") {
int64_t lhsInt;
int64_t rhsInt;
if (!matchPattern(lhs, m_TorchConstantInt(&lhsInt))) {
return failure();
}
if (!matchPattern(rhs, m_TorchConstantInt(&rhsInt))) {
return failure();
}

int64_t result = (int64_t)std::trunc((double)lhsInt / rhsInt);
Value resultScalar =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(result));
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
resultScalar);
return success();
}

return failure();
}
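As background for the comment above: the decomposition it mentions would compute trunc(a / b) as sign(a / b) * floor(|a / b|) via aten.abs, aten.floor, aten.sign and aten.mul.int, while this folder simply evaluates the quotient and rounds it toward zero. Below is a minimal standalone sketch of that arithmetic, assuming plain C++17; divTrunc and divFloor are illustrative names, not torch-mlir helpers.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Same computation as the new fold: round the real quotient toward zero.
int64_t divTrunc(int64_t lhs, int64_t rhs) {
  return static_cast<int64_t>(std::trunc(static_cast<double>(lhs) / rhs));
}

// The existing "floor" mode rounds toward negative infinity instead.
int64_t divFloor(int64_t lhs, int64_t rhs) {
  return static_cast<int64_t>(std::floor(static_cast<double>(lhs) / rhs));
}

int main() {
  std::printf("trunc(6/2)  = %lld\n", static_cast<long long>(divTrunc(6, 2)));   // 3
  std::printf("trunc(-7/2) = %lld\n", static_cast<long long>(divTrunc(-7, 2)));  // -3
  std::printf("floor(-7/2) = %lld\n", static_cast<long long>(divFloor(-7, 2)));  // -4
  return 0;
}

The two modes differ only when the real quotient is negative and fractional; for the non-negative operands in the tests below (6 divided by 2), both agree on 3, which is the constant the CHECK lines expect.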

36 changes: 34 additions & 2 deletions test/Dialect/Torch/canonicalize.mlir
@@ -1432,8 +1432,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtenso

// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
@@ -1574,8 +1574,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],

// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
@@ -1589,3 +1589,35 @@ func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor
return %2 : !torch.vtensor<[],si64>
}

// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> {
%int6 = torch.constant.int 6
%int2 = torch.constant.int 2
%str = torch.constant.str "trunc"
%0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
%1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64>
%2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
return %2 : !torch.vtensor<[],si64>
}

// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> {
%int6 = torch.constant.int 6
%str = torch.constant.str "trunc"
%0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64>
%2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
return %2 : !torch.vtensor<[],si64>
}
