@@ -2771,12 +2771,13 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
2771
2771
return retVal;
2772
2772
}
2773
2773
2774
- static Value BilinearInterpolate (OpBuilder &b,
2775
- Aten__InterpolateSizeListScaleListOp op,
2776
- Location loc, SmallVector<Value> outputSizes,
2777
- Value input, SmallVector<Value> inputSizes,
2778
- SmallVector<Value> scaleValues,
2779
- std::string coordStr) {
2774
+ static SmallVector<Value>
2775
+ CoordinateTransform (OpBuilder &b, Aten__InterpolateSizeListScaleListOp op,
2776
+ Location loc, SmallVector<Value> outputSizes, Value input,
2777
+ SmallVector<Value> inputSizes,
2778
+ SmallVector<Value> scaleValues, std::string coordStr,
2779
+ bool alignCornersBool, SmallVector<Value> indices) {
2780
+
2780
2781
unsigned dimOffset = 2 ;
2781
2782
auto inputType = cast<RankedTensorType>(input.getType ());
2782
2783
auto inputRank = inputType.getRank ();
@@ -2785,15 +2786,7 @@ static Value BilinearInterpolate(OpBuilder &b,
2785
2786
Value cstHalf = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.5 ));
2786
2787
Value zero = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
2787
2788
2788
- bool alignCornersBool;
2789
- matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
2790
-
2791
- SmallVector<Value> indices;
2792
- for (unsigned i = 0 ; i < inputRank; i++) {
2793
- indices.push_back (b.create <linalg::IndexOp>(loc, i));
2794
- }
2795
-
2796
- SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
2789
+ SmallVector<Value> proj;
2797
2790
for (unsigned i = 0 ; i < inputRank - dimOffset; i++) {
2798
2791
// length_original
2799
2792
Value inputFP =
@@ -2863,6 +2856,38 @@ static Value BilinearInterpolate(OpBuilder &b,
2863
2856
// clip to [0,length_original - 1].
2864
2857
// proj is properly within the input image.
2865
2858
proj.push_back (b.create <arith::MinimumFOp>(loc, max, inputSubOne));
2859
+ }
2860
+ return proj;
2861
+ }
2862
+
2863
+ static Value BilinearInterpolate (OpBuilder &b,
2864
+ Aten__InterpolateSizeListScaleListOp op,
2865
+ Location loc, SmallVector<Value> outputSizes,
2866
+ Value input, SmallVector<Value> inputSizes,
2867
+ SmallVector<Value> scaleValues,
2868
+ std::string coordStr) {
2869
+ unsigned dimOffset = 2 ;
2870
+ auto inputType = cast<RankedTensorType>(input.getType ());
2871
+ auto inputRank = inputType.getRank ();
2872
+
2873
+ Value cstOneFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (1.0 ));
2874
+
2875
+ bool alignCornersBool;
2876
+ matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
2877
+
2878
+ SmallVector<Value> indices;
2879
+ for (unsigned i = 0 ; i < inputRank; i++) {
2880
+ indices.push_back (b.create <linalg::IndexOp>(loc, i));
2881
+ }
2882
+
2883
+ SmallVector<Value> proj, high, low, highFP, lowFP;
2884
+ proj = CoordinateTransform (b, op, loc, outputSizes, input, inputSizes,
2885
+ scaleValues, coordStr, alignCornersBool, indices);
2886
+ for (unsigned i = 0 ; i < inputRank - dimOffset; i++) {
2887
+ // length_original
2888
+ Value inputFP =
2889
+ b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[i]);
2890
+ Value inputSubOne = b.create <arith::SubFOp>(loc, inputFP, cstOneFloat);
2866
2891
2867
2892
// for bilinear interpolation, we look for the nearest indices below and
2868
2893
// above proj
@@ -2926,6 +2951,158 @@ static Value BilinearInterpolate(OpBuilder &b,
2926
2951
return b.create <arith::AddFOp>(loc, left, right);
2927
2952
}
2928
2953
2954
+ static Value BicubicInterpolate (OpBuilder &b,
2955
+ Aten__InterpolateSizeListScaleListOp op,
2956
+ Location loc, SmallVector<Value> outputSizes,
2957
+ Value input, SmallVector<Value> inputSizes,
2958
+ SmallVector<Value> scaleValues,
2959
+ std::string coordStr) {
2960
+ unsigned dimOffset = 2 ;
2961
+ auto inputType = cast<RankedTensorType>(input.getType ());
2962
+ auto inputRank = inputType.getRank ();
2963
+
2964
+ Value inputFPH =
2965
+ b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[0 ]);
2966
+ Value inputFPW =
2967
+ b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[1 ]);
2968
+
2969
+ Value a = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (-0.75 ));
2970
+ Value zero = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
2971
+ Value cstOneFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (1.0 ));
2972
+ Value cstTwoFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (2.0 ));
2973
+ Value cstThreeFloat =
2974
+ b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (3.0 ));
2975
+ Value cstFourFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (4.0 ));
2976
+ Value cstFiveFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (5.0 ));
2977
+ Value cstEightFloat =
2978
+ b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (8.0 ));
2979
+
2980
+ auto WeightLessThanEqualOne = [&](Value xDistance) -> Value {
2981
+ Value xDistanceSquared = b.create <arith::MulFOp>(loc, xDistance, xDistance);
2982
+ Value xDistanceCubed =
2983
+ b.create <arith::MulFOp>(loc, xDistanceSquared, xDistance);
2984
+
2985
+ Value lessEqualOne = b.create <arith::AddFOp>(loc, a, cstTwoFloat);
2986
+ lessEqualOne = b.create <arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
2987
+ Value aPlusThree = b.create <arith::AddFOp>(loc, a, cstThreeFloat);
2988
+ aPlusThree = b.create <arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
2989
+ lessEqualOne = b.create <arith::SubFOp>(loc, lessEqualOne, aPlusThree);
2990
+ lessEqualOne = b.create <arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
2991
+
2992
+ return lessEqualOne;
2993
+ };
2994
+
2995
+ auto WeightLessThanTwo = [&](Value xDistance) -> Value {
2996
+ Value xDistanceSquared = b.create <arith::MulFOp>(loc, xDistance, xDistance);
2997
+ Value xDistanceCubed =
2998
+ b.create <arith::MulFOp>(loc, xDistanceSquared, xDistance);
2999
+ Value lessThanTwo = b.create <arith::MulFOp>(loc, xDistanceCubed, a);
3000
+
3001
+ Value fiveA = b.create <arith::MulFOp>(loc, xDistanceSquared, a);
3002
+ fiveA = b.create <arith::MulFOp>(loc, fiveA, cstFiveFloat);
3003
+ lessThanTwo = b.create <arith::SubFOp>(loc, lessThanTwo, fiveA);
3004
+
3005
+ Value eightA = b.create <arith::MulFOp>(loc, a, xDistance);
3006
+ eightA = b.create <arith::MulFOp>(loc, eightA, cstEightFloat);
3007
+ lessThanTwo = b.create <arith::AddFOp>(loc, eightA, lessThanTwo);
3008
+
3009
+ Value fourA = b.create <arith::MulFOp>(loc, a, cstFourFloat);
3010
+ lessThanTwo = b.create <arith::SubFOp>(loc, lessThanTwo, fourA);
3011
+ return lessThanTwo;
3012
+ };
3013
+
3014
+ bool alignCornersBool;
3015
+ matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
3016
+
3017
+ SmallVector<Value> indices;
3018
+ for (unsigned i = 0 ; i < inputRank; i++) {
3019
+ indices.push_back (b.create <linalg::IndexOp>(loc, i));
3020
+ }
3021
+
3022
+ SmallVector<Value> proj;
3023
+ proj = CoordinateTransform (b, op, loc, outputSizes, input, inputSizes,
3024
+ scaleValues, coordStr, alignCornersBool, indices);
3025
+
3026
+ Value x1 = b.create <math::CeilOp>(loc, proj[1 ]);
3027
+ Value x_1 = b.create <arith::SubFOp>(loc, x1, cstOneFloat);
3028
+ Value x_2 = b.create <arith::SubFOp>(loc, x_1, cstOneFloat);
3029
+ Value x2 = b.create <arith::AddFOp>(loc, x1, cstOneFloat);
3030
+
3031
+ Value y1 = b.create <math::CeilOp>(loc, proj[0 ]);
3032
+ Value y_1 = b.create <arith::SubFOp>(loc, y1, cstOneFloat);
3033
+ Value y_2 = b.create <arith::SubFOp>(loc, y_1, cstOneFloat);
3034
+ Value y2 = b.create <arith::AddFOp>(loc, y1, cstOneFloat);
3035
+
3036
+ Value y2Distance = b.create <arith::SubFOp>(loc, proj[0 ], y2);
3037
+ y2Distance = b.create <math::AbsFOp>(loc, y2Distance);
3038
+ Value y1Distance = b.create <arith::SubFOp>(loc, proj[0 ], y1);
3039
+ y1Distance = b.create <math::AbsFOp>(loc, y1Distance);
3040
+ Value y_1Distance = b.create <arith::SubFOp>(loc, proj[0 ], y_1);
3041
+ y_1Distance = b.create <math::AbsFOp>(loc, y_1Distance);
3042
+ Value y_2Distance = b.create <arith::SubFOp>(loc, proj[0 ], y_2);
3043
+ y_2Distance = b.create <math::AbsFOp>(loc, y_2Distance);
3044
+
3045
+ Value x2Distance = b.create <arith::SubFOp>(loc, proj[1 ], x2);
3046
+ x2Distance = b.create <math::AbsFOp>(loc, x2Distance);
3047
+ Value x1Distance = b.create <arith::SubFOp>(loc, proj[1 ], x1);
3048
+ x1Distance = b.create <math::AbsFOp>(loc, x1Distance);
3049
+ Value x_1Distance = b.create <arith::SubFOp>(loc, proj[1 ], x_1);
3050
+ x_1Distance = b.create <math::AbsFOp>(loc, x_1Distance);
3051
+ Value x_2Distance = b.create <arith::SubFOp>(loc, proj[1 ], x_2);
3052
+ x_2Distance = b.create <math::AbsFOp>(loc, x_2Distance);
3053
+
3054
+ SmallVector<Value> y{y_2, y_1, y1, y2};
3055
+ SmallVector<Value> x{x_2, x_1, x1, x2};
3056
+ SmallVector<Value> yDistance{y_2Distance, y_1Distance, y1Distance,
3057
+ y2Distance};
3058
+ SmallVector<Value> xDistance{x_2Distance, x_1Distance, x1Distance,
3059
+ x2Distance};
3060
+ SmallVector<Value> wys{
3061
+ WeightLessThanTwo (y_2Distance), WeightLessThanEqualOne (y_1Distance),
3062
+ WeightLessThanEqualOne (y1Distance), WeightLessThanTwo (y2Distance)};
3063
+ SmallVector<Value> wxs{
3064
+ WeightLessThanTwo (x_2Distance), WeightLessThanEqualOne (x_1Distance),
3065
+ WeightLessThanEqualOne (x1Distance), WeightLessThanTwo (x2Distance)};
3066
+ SmallVector<Value> xInterp{zero, zero, zero, zero};
3067
+
3068
+ // f(x_orig, y_orig) = Sum_y Sum_x W(x_original - x)*input[x,y]
3069
+ // * W(y_original - y)
3070
+ Value fxy = zero;
3071
+
3072
+ for (int j = 0 ; j < 4 ; j++) {
3073
+ Value wy = wys[j];
3074
+ Value xInterpy = xInterp[j];
3075
+ for (int i = 0 ; i < 4 ; i++) {
3076
+ Value wx = wxs[i];
3077
+
3078
+ y[j] = b.create <arith::MaximumFOp>(loc, y[j], zero);
3079
+ Value inputHSubOne = b.create <arith::SubFOp>(loc, inputFPH, cstOneFloat);
3080
+ y[j] = b.create <arith::MinimumFOp>(loc, y[j], inputHSubOne);
3081
+
3082
+ Value yInt = b.create <arith::FPToSIOp>(loc, b.getI64Type (), y[j]);
3083
+ Value yIndex = b.create <arith::IndexCastOp>(loc, b.getIndexType (), yInt);
3084
+ indices[dimOffset] = yIndex;
3085
+
3086
+ x[i] = b.create <arith::MaximumFOp>(loc, x[i], zero);
3087
+ Value inputWSubOne = b.create <arith::SubFOp>(loc, inputFPW, cstOneFloat);
3088
+ x[i] = b.create <arith::MinimumFOp>(loc, x[i], inputWSubOne);
3089
+
3090
+ Value xInt = b.create <arith::FPToSIOp>(loc, b.getI64Type (), x[i]);
3091
+ Value xIndex = b.create <arith::IndexCastOp>(loc, b.getIndexType (), xInt);
3092
+ indices[dimOffset + 1 ] = xIndex;
3093
+
3094
+ Value p = b.create <tensor::ExtractOp>(loc, input, indices);
3095
+
3096
+ Value wxp = b.create <arith::MulFOp>(loc, wx, p);
3097
+ xInterpy = b.create <arith::AddFOp>(loc, xInterpy, wxp);
3098
+ }
3099
+ Value wyXInterpy = b.create <arith::MulFOp>(loc, wy, xInterpy);
3100
+ fxy = b.create <arith::AddFOp>(loc, fxy, wyXInterpy);
3101
+ }
3102
+
3103
+ return fxy;
3104
+ }
3105
+
2929
3106
namespace {
2930
3107
class ConvertInterpolateOp
2931
3108
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
@@ -2941,7 +3118,8 @@ class ConvertInterpolateOp
2941
3118
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
2942
3119
// op with the non-standard mode="bilinear_asymmetric".
2943
3120
matchPattern (op.getMode (), m_TorchConstantStr (mode));
2944
- if (mode.substr (0 , 8 ) != " bilinear" && mode.substr (0 , 7 ) != " nearest" ) {
3121
+ if (mode.substr (0 , 8 ) != " bilinear" && mode.substr (0 , 7 ) != " nearest" &&
3122
+ mode.substr (0 , 5 ) != " cubic" ) {
2945
3123
return failure ();
2946
3124
}
2947
3125
@@ -3030,6 +3208,10 @@ class ConvertInterpolateOp
3030
3208
retVal = BilinearInterpolate (
3031
3209
b, op, loc, outputSizeIntValues, input, inputSizes,
3032
3210
ScaleFactorFloatValues, mode.substr (8 ));
3211
+ } else if (mode.substr (0 , 5 ) == " cubic" ) {
3212
+ retVal = BicubicInterpolate (
3213
+ b, op, loc, outputSizeIntValues, input, inputSizes,
3214
+ ScaleFactorFloatValues, mode.substr (5 ));
3033
3215
}
3034
3216
b.create <linalg::YieldOp>(loc, retVal);
3035
3217
})
0 commit comments