Skip to content

Commit aedb725

Browse files
committed
Rewrite op to follow onnx implementation
1 parent ebceb63 commit aedb725

File tree

2 files changed

+130
-112
lines changed

2 files changed

+130
-112
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 74 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2079,83 +2079,93 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
20792079
binder.tensorResultType(resultType)) {
20802080
return failure();
20812081
}
2082-
2083-
// If periodic is zero, subtract one from size before proceeding
2084-
if (periodic == 0) {
2085-
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
2086-
binder.getLoc(), rewriter.getI64IntegerAttr(1));
2087-
size = rewriter.create<Torch::AtenSubScalarOp>(
2088-
binder.getLoc(), size.getType(), size, constantOne, constantOne);
2089-
}
2090-
2091-
Value one = rewriter.create<Torch::ConstantIntOp>(
2092-
binder.getLoc(), rewriter.getI64IntegerAttr(1));
2093-
Value zero = rewriter.create<Torch::ConstantIntOp>(
2094-
binder.getLoc(), rewriter.getI64IntegerAttr(0));
2095-
2096-
Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
2097-
2098-
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
2099-
Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
2100-
binder.getLoc(), resultType, zero, scalarLimit, one, none, none,
2101-
none, none);
2102-
2103-
// Required constants
2104-
constexpr double pi = llvm::numbers::pi;
2105-
Value alpha = rewriter.create<Torch::ConstantFloatOp>(
2082+
double isPeriodicFp = (double)periodic;
2083+
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
21062084
binder.getLoc(),
21072085
rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
2108-
Value beta = rewriter.create<Torch::ConstantFloatOp>(
2109-
binder.getLoc(),
2110-
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
2111-
Value negHalf = rewriter.create<Torch::ConstantFloatOp>(
2086+
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
21122087
binder.getLoc(),
21132088
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
2114-
Value twicePi = rewriter.create<Torch::ConstantFloatOp>(
2089+
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
21152090
binder.getLoc(),
2116-
rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
2117-
Value fourPi = rewriter.create<Torch::ConstantFloatOp>(
2091+
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
2092+
Value zero = rewriter.create<Torch::ConstantFloatOp>(
2093+
binder.getLoc(), rewriter.getF64FloatAttr(0.0));
2094+
Value one = rewriter.create<Torch::ConstantFloatOp>(
2095+
binder.getLoc(), rewriter.getF64FloatAttr(1.0));
2096+
Value two = rewriter.create<Torch::ConstantFloatOp>(
2097+
binder.getLoc(), rewriter.getF64FloatAttr(2.0));
2098+
2099+
constexpr double pi = llvm::numbers::pi;
2100+
Value tau = rewriter.create<Torch::ConstantFloatOp>(
21182101
binder.getLoc(),
2119-
rewriter.getFloatAttr(rewriter.getF64Type(), 4.0 * pi));
2120-
2121-
// Calculate the window function
2122-
Value productTimesTwoPi = rewriter.create<Torch::AtenMulScalarOp>(
2123-
binder.getLoc(), resultType, rangeArr, twicePi);
2124-
Value productTimesFourPi = rewriter.create<Torch::AtenMulScalarOp>(
2125-
binder.getLoc(), resultType, rangeArr, fourPi);
2126-
2127-
Value divTimesFourPi = rewriter.create<Torch::AtenDivTensorOp>(
2128-
binder.getLoc(), resultType, productTimesFourPi, size);
2129-
Value divTimesTwoPi = rewriter.create<Torch::AtenDivTensorOp>(
2130-
binder.getLoc(), resultType, productTimesTwoPi, size);
2131-
2132-
Value cosFunctionInitial = rewriter.create<Torch::AtenCosOp>(
2133-
binder.getLoc(), resultType, divTimesTwoPi);
2134-
Value cosTimesNegHalf = rewriter.create<Torch::AtenMulScalarOp>(
2135-
binder.getLoc(), resultType, cosFunctionInitial, negHalf);
2136-
Value cosFunctionFinal = rewriter.create<Torch::AtenCosOp>(
2137-
binder.getLoc(), resultType, divTimesFourPi);
2138-
Value cosFinalTimesBeta = rewriter.create<Torch::AtenMulScalarOp>(
2139-
binder.getLoc(), resultType, cosFunctionFinal, beta);
2140-
2141-
Value constOne = rewriter.create<Torch::ConstantIntOp>(
2142-
binder.getLoc(), rewriter.getI64IntegerAttr(1));
2143-
Value valAdd = rewriter.create<Torch::AtenAddTensorOp>(
2144-
binder.getLoc(), resultType, cosTimesNegHalf, cosFinalTimesBeta,
2145-
constOne);
2146-
Value finalResult = rewriter.create<Torch::AtenAddScalarOp>(
2147-
binder.getLoc(), resultType, valAdd, alpha, constOne);
2102+
rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
2103+
2104+
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
2105+
Value cstFalse =
2106+
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
2107+
Value float32Type = rewriter.create<Torch::ConstantIntOp>(
2108+
binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6));
2109+
2110+
Value periodicSizeFloat = rewriter.create<Torch::AtenToDtypeOp>(
2111+
binder.getLoc(), size.getType(), size, float32Type, cstFalse,
2112+
cstFalse, noneVal);
2113+
Value symmetricSizeFloat = rewriter.create<Torch::AtenSubScalarOp>(
2114+
binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
2115+
one, one);
2116+
2117+
Value isPeriodic = rewriter.create<Torch::ConstantFloatOp>(
2118+
binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp));
2119+
Value isSymmetricFloat = rewriter.create<Torch::ConstantFloatOp>(
2120+
binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
2121+
2122+
Value periodicComponent = rewriter.create<Torch::AtenMulScalarOp>(
2123+
binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
2124+
isPeriodic);
2125+
Value symmetricComponent = rewriter.create<Torch::AtenMulScalarOp>(
2126+
binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat,
2127+
isSymmetricFloat);
2128+
Value sizeFloat = rewriter.create<Torch::AtenAddTensorOp>(
2129+
binder.getLoc(), symmetricComponent.getType(), symmetricComponent,
2130+
periodicComponent, one);
2131+
2132+
Value scalarLimit =
2133+
getItemOp<Torch::IntType>(binder, rewriter, periodicSizeFloat);
2134+
2135+
Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
2136+
binder.getLoc(), resultType, zero, scalarLimit, one, noneVal,
2137+
noneVal, noneVal, noneVal);
2138+
2139+
Value rangeTimesTau = rewriter.create<Torch::AtenMulScalarOp>(
2140+
binder.getLoc(), resultType, rangeArr, tau);
2141+
Value rangeAngular = rewriter.create<Torch::AtenDivTensorOp>(
2142+
binder.getLoc(), resultType, rangeTimesTau, sizeFloat);
2143+
Value twoRangeAngular = rewriter.create<Torch::AtenMulScalarOp>(
2144+
binder.getLoc(), resultType, rangeAngular, two);
2145+
2146+
Value cosRangeAngular = rewriter.create<Torch::AtenCosOp>(
2147+
binder.getLoc(), resultType, rangeAngular);
2148+
Value cosTwoRangeAngular = rewriter.create<Torch::AtenCosOp>(
2149+
binder.getLoc(), resultType, twoRangeAngular);
2150+
2151+
Value a1Component = rewriter.create<Torch::AtenMulScalarOp>(
2152+
binder.getLoc(), resultType, cosRangeAngular, a1);
2153+
Value a2Component = rewriter.create<Torch::AtenMulScalarOp>(
2154+
binder.getLoc(), resultType, cosTwoRangeAngular, a2);
2155+
2156+
Value subA1Component = rewriter.create<Torch::AtenAddScalarOp>(
2157+
binder.getLoc(), resultType, a1Component, a0, one);
2158+
Value result = rewriter.create<Torch::AtenAddTensorOp>(
2159+
binder.getLoc(), resultType, subA1Component, a2Component, one);
21482160

21492161
int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(output_datatype);
21502162
Value outputDtype = rewriter.create<Torch::ConstantIntOp>(
21512163
binder.getLoc(), rewriter.getType<Torch::IntType>(),
21522164
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
21532165
dtypeIntTorch));
2154-
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
2155-
Value cstFalse =
2156-
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
2166+
21572167
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
2158-
binder.op, resultType, finalResult, outputDtype,
2168+
binder.op, resultType, result, outputDtype,
21592169
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
21602170
/*memory_format=*/noneVal);
21612171
return success();

0 commit comments

Comments
 (0)