Skip to content

Commit

Permalink
Rewrite op to follow onnx implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakdsci committed Apr 22, 2024
1 parent ebceb63 commit aedb725
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 112 deletions.
138 changes: 74 additions & 64 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2079,83 +2079,93 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType)) {
return failure();
}

// If periodic is zero, subtract one from size before proceeding
if (periodic == 0) {
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
size = rewriter.create<Torch::AtenSubScalarOp>(
binder.getLoc(), size.getType(), size, constantOne, constantOne);
}

Value one = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));

Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);

Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
binder.getLoc(), resultType, zero, scalarLimit, one, none, none,
none, none);

// Required contants
constexpr double pi = llvm::numbers::pi;
Value alpha = rewriter.create<Torch::ConstantFloatOp>(
double isPeriodicFp = (double)periodic;
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
Value beta = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
Value negHalf = rewriter.create<Torch::ConstantFloatOp>(
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
Value twicePi = rewriter.create<Torch::ConstantFloatOp>(
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
Value fourPi = rewriter.create<Torch::ConstantFloatOp>(
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
Value zero = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(0.0));
Value one = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(1.0));
Value two = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(2.0));

constexpr double pi = llvm::numbers::pi;
Value tau = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 4.0 * pi));

// Calculate the window function
Value productTimesTwoPi = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, rangeArr, twicePi);
Value productTimesFourPi = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, rangeArr, fourPi);

Value divTimesFourPi = rewriter.create<Torch::AtenDivTensorOp>(
binder.getLoc(), resultType, productTimesFourPi, size);
Value divTimesTwoPi = rewriter.create<Torch::AtenDivTensorOp>(
binder.getLoc(), resultType, productTimesTwoPi, size);

Value cosFunctionInitial = rewriter.create<Torch::AtenCosOp>(
binder.getLoc(), resultType, divTimesTwoPi);
Value cosTimesNegHalf = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, cosFunctionInitial, negHalf);
Value cosFunctionFinal = rewriter.create<Torch::AtenCosOp>(
binder.getLoc(), resultType, divTimesFourPi);
Value cosFinalTimesBeta = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, cosFunctionFinal, beta);

Value constOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value valAdd = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, cosTimesNegHalf, cosFinalTimesBeta,
constOne);
Value finalResult = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, valAdd, alpha, constOne);
rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));

Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value float32Type = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6));

Value periodicSizeFloat = rewriter.create<Torch::AtenToDtypeOp>(
binder.getLoc(), size.getType(), size, float32Type, cstFalse,
cstFalse, noneVal);
Value symmetricSizeFloat = rewriter.create<Torch::AtenSubScalarOp>(
binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
one, one);

Value isPeriodic = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp));
Value isSymmetricFloat = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp));

Value periodicComponent = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
isPeriodic);
Value symmetricComponent = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat,
isSymmetricFloat);
Value sizeFloat = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), symmetricComponent.getType(), symmetricComponent,
periodicComponent, one);

Value scalarLimit =
getItemOp<Torch::IntType>(binder, rewriter, periodicSizeFloat);

Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
binder.getLoc(), resultType, zero, scalarLimit, one, noneVal,
noneVal, noneVal, noneVal);

Value rangeTimesTau = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, rangeArr, tau);
Value rangeAngular = rewriter.create<Torch::AtenDivTensorOp>(
binder.getLoc(), resultType, rangeTimesTau, sizeFloat);
Value twoRangeAngular = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, rangeAngular, two);

Value cosRangeAngular = rewriter.create<Torch::AtenCosOp>(
binder.getLoc(), resultType, rangeAngular);
Value cosTwoRangeAngular = rewriter.create<Torch::AtenCosOp>(
binder.getLoc(), resultType, twoRangeAngular);

Value a1Component = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, cosRangeAngular, a1);
Value a2Component = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, cosTwoRangeAngular, a2);

Value subA1Component = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, a1Component, a0, one);
Value result = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, subA1Component, a2Component, one);

int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(output_datatype);
Value outputDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch));
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);

rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, finalResult, outputDtype,
binder.op, resultType, result, outputDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/noneVal);
return success();
Expand Down
Loading

0 comments on commit aedb725

Please sign in to comment.