@@ -2079,83 +2079,93 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
2079
2079
binder.tensorResultType (resultType)) {
2080
2080
return failure ();
2081
2081
}
2082
-
2083
- // If periodic is zero, subtract one from size before proceeding
2084
- if (periodic == 0 ) {
2085
- Value constantOne = rewriter.create <Torch::ConstantIntOp>(
2086
- binder.getLoc (), rewriter.getI64IntegerAttr (1 ));
2087
- size = rewriter.create <Torch::AtenSubScalarOp>(
2088
- binder.getLoc (), size.getType (), size, constantOne, constantOne);
2089
- }
2090
-
2091
- Value one = rewriter.create <Torch::ConstantIntOp>(
2092
- binder.getLoc (), rewriter.getI64IntegerAttr (1 ));
2093
- Value zero = rewriter.create <Torch::ConstantIntOp>(
2094
- binder.getLoc (), rewriter.getI64IntegerAttr (0 ));
2095
-
2096
- Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
2097
-
2098
- Value none = rewriter.create <Torch::ConstantNoneOp>(binder.getLoc ());
2099
- Value rangeArr = rewriter.create <Torch::AtenArangeStartStepOp>(
2100
- binder.getLoc (), resultType, zero, scalarLimit, one, none, none,
2101
- none, none);
2102
-
2103
- // Required contants
2104
- constexpr double pi = llvm::numbers::pi;
2105
- Value alpha = rewriter.create <Torch::ConstantFloatOp>(
2082
+ double isPeriodicFp = (double )periodic;
2083
+ Value a0 = rewriter.create <Torch::ConstantFloatOp>(
2106
2084
binder.getLoc (),
2107
2085
rewriter.getFloatAttr (rewriter.getF64Type (), 0.42 ));
2108
- Value beta = rewriter.create <Torch::ConstantFloatOp>(
2109
- binder.getLoc (),
2110
- rewriter.getFloatAttr (rewriter.getF64Type (), 0.08 ));
2111
- Value negHalf = rewriter.create <Torch::ConstantFloatOp>(
2086
+ Value a1 = rewriter.create <Torch::ConstantFloatOp>(
2112
2087
binder.getLoc (),
2113
2088
rewriter.getFloatAttr (rewriter.getF64Type (), -0.5 ));
2114
- Value twicePi = rewriter.create <Torch::ConstantFloatOp>(
2089
+ Value a2 = rewriter.create <Torch::ConstantFloatOp>(
2115
2090
binder.getLoc (),
2116
- rewriter.getFloatAttr (rewriter.getF64Type (), 2.0 * pi));
2117
- Value fourPi = rewriter.create <Torch::ConstantFloatOp>(
2091
+ rewriter.getFloatAttr (rewriter.getF64Type (), 0.08 ));
2092
+ Value zero = rewriter.create <Torch::ConstantFloatOp>(
2093
+ binder.getLoc (), rewriter.getF64FloatAttr (0.0 ));
2094
+ Value one = rewriter.create <Torch::ConstantFloatOp>(
2095
+ binder.getLoc (), rewriter.getF64FloatAttr (1.0 ));
2096
+ Value two = rewriter.create <Torch::ConstantFloatOp>(
2097
+ binder.getLoc (), rewriter.getF64FloatAttr (2.0 ));
2098
+
2099
+ constexpr double pi = llvm::numbers::pi;
2100
+ Value tau = rewriter.create <Torch::ConstantFloatOp>(
2118
2101
binder.getLoc (),
2119
- rewriter.getFloatAttr (rewriter.getF64Type (), 4.0 * pi));
2120
-
2121
- // Calculate the window function
2122
- Value productTimesTwoPi = rewriter.create <Torch::AtenMulScalarOp>(
2123
- binder.getLoc (), resultType, rangeArr, twicePi);
2124
- Value productTimesFourPi = rewriter.create <Torch::AtenMulScalarOp>(
2125
- binder.getLoc (), resultType, rangeArr, fourPi);
2126
-
2127
- Value divTimesFourPi = rewriter.create <Torch::AtenDivTensorOp>(
2128
- binder.getLoc (), resultType, productTimesFourPi, size);
2129
- Value divTimesTwoPi = rewriter.create <Torch::AtenDivTensorOp>(
2130
- binder.getLoc (), resultType, productTimesTwoPi, size);
2131
-
2132
- Value cosFunctionInitial = rewriter.create <Torch::AtenCosOp>(
2133
- binder.getLoc (), resultType, divTimesTwoPi);
2134
- Value cosTimesNegHalf = rewriter.create <Torch::AtenMulScalarOp>(
2135
- binder.getLoc (), resultType, cosFunctionInitial, negHalf);
2136
- Value cosFunctionFinal = rewriter.create <Torch::AtenCosOp>(
2137
- binder.getLoc (), resultType, divTimesFourPi);
2138
- Value cosFinalTimesBeta = rewriter.create <Torch::AtenMulScalarOp>(
2139
- binder.getLoc (), resultType, cosFunctionFinal, beta);
2140
-
2141
- Value constOne = rewriter.create <Torch::ConstantIntOp>(
2142
- binder.getLoc (), rewriter.getI64IntegerAttr (1 ));
2143
- Value valAdd = rewriter.create <Torch::AtenAddTensorOp>(
2144
- binder.getLoc (), resultType, cosTimesNegHalf, cosFinalTimesBeta,
2145
- constOne);
2146
- Value finalResult = rewriter.create <Torch::AtenAddScalarOp>(
2147
- binder.getLoc (), resultType, valAdd, alpha, constOne);
2102
+ rewriter.getFloatAttr (rewriter.getF64Type (), 2.0 * pi));
2103
+
2104
+ Value noneVal = rewriter.create <Torch::ConstantNoneOp>(binder.getLoc ());
2105
+ Value cstFalse =
2106
+ rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), false );
2107
+ Value float32Type = rewriter.create <Torch::ConstantIntOp>(
2108
+ binder.getLoc (), rewriter.getI64IntegerAttr (/* float32Type*/ 6 ));
2109
+
2110
+ Value periodicSizeFloat = rewriter.create <Torch::AtenToDtypeOp>(
2111
+ binder.getLoc (), size.getType (), size, float32Type, cstFalse,
2112
+ cstFalse, noneVal);
2113
+ Value symmetricSizeFloat = rewriter.create <Torch::AtenSubScalarOp>(
2114
+ binder.getLoc (), periodicSizeFloat.getType (), periodicSizeFloat,
2115
+ one, one);
2116
+
2117
+ Value isPeriodic = rewriter.create <Torch::ConstantFloatOp>(
2118
+ binder.getLoc (), rewriter.getF64FloatAttr (isPeriodicFp));
2119
+ Value isSymmetricFloat = rewriter.create <Torch::ConstantFloatOp>(
2120
+ binder.getLoc (), rewriter.getF64FloatAttr (1.0 - isPeriodicFp));
2121
+
2122
+ Value periodicComponent = rewriter.create <Torch::AtenMulScalarOp>(
2123
+ binder.getLoc (), periodicSizeFloat.getType (), periodicSizeFloat,
2124
+ isPeriodic);
2125
+ Value symmetricComponent = rewriter.create <Torch::AtenMulScalarOp>(
2126
+ binder.getLoc (), symmetricSizeFloat.getType (), symmetricSizeFloat,
2127
+ isSymmetricFloat);
2128
+ Value sizeFloat = rewriter.create <Torch::AtenAddTensorOp>(
2129
+ binder.getLoc (), symmetricComponent.getType (), symmetricComponent,
2130
+ periodicComponent, one);
2131
+
2132
+ Value scalarLimit =
2133
+ getItemOp<Torch::IntType>(binder, rewriter, periodicSizeFloat);
2134
+
2135
+ Value rangeArr = rewriter.create <Torch::AtenArangeStartStepOp>(
2136
+ binder.getLoc (), resultType, zero, scalarLimit, one, noneVal,
2137
+ noneVal, noneVal, noneVal);
2138
+
2139
+ Value rangeTimesTau = rewriter.create <Torch::AtenMulScalarOp>(
2140
+ binder.getLoc (), resultType, rangeArr, tau);
2141
+ Value rangeAngular = rewriter.create <Torch::AtenDivTensorOp>(
2142
+ binder.getLoc (), resultType, rangeTimesTau, sizeFloat);
2143
+ Value twoRangeAngular = rewriter.create <Torch::AtenMulScalarOp>(
2144
+ binder.getLoc (), resultType, rangeAngular, two);
2145
+
2146
+ Value cosRangeAngular = rewriter.create <Torch::AtenCosOp>(
2147
+ binder.getLoc (), resultType, rangeAngular);
2148
+ Value cosTwoRangeAngular = rewriter.create <Torch::AtenCosOp>(
2149
+ binder.getLoc (), resultType, twoRangeAngular);
2150
+
2151
+ Value a1Component = rewriter.create <Torch::AtenMulScalarOp>(
2152
+ binder.getLoc (), resultType, cosRangeAngular, a1);
2153
+ Value a2Component = rewriter.create <Torch::AtenMulScalarOp>(
2154
+ binder.getLoc (), resultType, cosTwoRangeAngular, a2);
2155
+
2156
+ Value subA1Component = rewriter.create <Torch::AtenAddScalarOp>(
2157
+ binder.getLoc (), resultType, a1Component, a0, one);
2158
+ Value result = rewriter.create <Torch::AtenAddTensorOp>(
2159
+ binder.getLoc (), resultType, subA1Component, a2Component, one);
2148
2160
2149
2161
int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt (output_datatype);
2150
2162
Value outputDtype = rewriter.create <Torch::ConstantIntOp>(
2151
2163
binder.getLoc (), rewriter.getType <Torch::IntType>(),
2152
2164
rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
2153
2165
dtypeIntTorch));
2154
- Value noneVal = rewriter.create <Torch::ConstantNoneOp>(binder.getLoc ());
2155
- Value cstFalse =
2156
- rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), false );
2166
+
2157
2167
rewriter.replaceOpWithNewOp <Torch::AtenToDtypeOp>(
2158
- binder.op , resultType, finalResult , outputDtype,
2168
+ binder.op , resultType, result , outputDtype,
2159
2169
/* non_blocking=*/ cstFalse, /* copy=*/ cstFalse,
2160
2170
/* memory_format=*/ noneVal);
2161
2171
return success ();
0 commit comments