@@ -2740,12 +2740,13 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
2740
2740
return retVal;
2741
2741
}
2742
2742
2743
- static Value BilinearInterpolate (OpBuilder &b,
2744
- Aten__InterpolateSizeListScaleListOp op,
2745
- Location loc, SmallVector<Value> outputSizes,
2746
- Value input, SmallVector<Value> inputSizes,
2747
- SmallVector<Value> scaleValues,
2748
- std::string coordStr) {
2743
+ static SmallVector<Value>
2744
+ CoordinateTransform (OpBuilder &b, Aten__InterpolateSizeListScaleListOp op,
2745
+ Location loc, SmallVector<Value> outputSizes, Value input,
2746
+ SmallVector<Value> inputSizes,
2747
+ SmallVector<Value> scaleValues, std::string coordStr,
2748
+ bool alignCornersBool, SmallVector<Value> indices) {
2749
+
2749
2750
unsigned dimOffset = 2 ;
2750
2751
auto inputType = cast<RankedTensorType>(input.getType ());
2751
2752
auto inputRank = inputType.getRank ();
@@ -2754,15 +2755,7 @@ static Value BilinearInterpolate(OpBuilder &b,
2754
2755
Value cstHalf = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.5 ));
2755
2756
Value zero = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
2756
2757
2757
- bool alignCornersBool;
2758
- matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
2759
-
2760
- SmallVector<Value> indices;
2761
- for (unsigned i = 0 ; i < inputRank; i++) {
2762
- indices.push_back (b.create <linalg::IndexOp>(loc, i));
2763
- }
2764
-
2765
- SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
2758
+ SmallVector<Value> proj;
2766
2759
for (unsigned i = 0 ; i < inputRank - dimOffset; i++) {
2767
2760
// length_original
2768
2761
Value inputFP =
@@ -2832,6 +2825,38 @@ static Value BilinearInterpolate(OpBuilder &b,
2832
2825
// clip to [0,length_original - 1].
2833
2826
// proj is properly within the input image.
2834
2827
proj.push_back (b.create <arith::MinimumFOp>(loc, max, inputSubOne));
2828
+ }
2829
+ return proj;
2830
+ }
2831
+
2832
+ static Value BilinearInterpolate (OpBuilder &b,
2833
+ Aten__InterpolateSizeListScaleListOp op,
2834
+ Location loc, SmallVector<Value> outputSizes,
2835
+ Value input, SmallVector<Value> inputSizes,
2836
+ SmallVector<Value> scaleValues,
2837
+ std::string coordStr) {
2838
+ unsigned dimOffset = 2 ;
2839
+ auto inputType = cast<RankedTensorType>(input.getType ());
2840
+ auto inputRank = inputType.getRank ();
2841
+
2842
+ Value cstOneFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (1.0 ));
2843
+
2844
+ bool alignCornersBool;
2845
+ matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
2846
+
2847
+ SmallVector<Value> indices;
2848
+ for (unsigned i = 0 ; i < inputRank; i++) {
2849
+ indices.push_back (b.create <linalg::IndexOp>(loc, i));
2850
+ }
2851
+
2852
+ SmallVector<Value> proj, high, low, highFP, lowFP;
2853
+ proj = CoordinateTransform (b, op, loc, outputSizes, input, inputSizes,
2854
+ scaleValues, coordStr, alignCornersBool, indices);
2855
+ for (unsigned i = 0 ; i < inputRank - dimOffset; i++) {
2856
+ // length_original
2857
+ Value inputFP =
2858
+ b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[i]);
2859
+ Value inputSubOne = b.create <arith::SubFOp>(loc, inputFP, cstOneFloat);
2835
2860
2836
2861
// for bilinear interpolation, we look for the nearest indices below and
2837
2862
// above proj
@@ -2895,6 +2920,168 @@ static Value BilinearInterpolate(OpBuilder &b,
2895
2920
return b.create <arith::AddFOp>(loc, left, right);
2896
2921
}
2897
2922
2923
+ static Value WeightFunction (OpBuilder &b, Location loc, Value xDistance) {
2924
+ Value a = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (-0.75 ));
2925
+ Value zero = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
2926
+ Value cstOneFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (1.0 ));
2927
+ Value cstTwoFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (2.0 ));
2928
+ Value cstThreeFloat =
2929
+ b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (3.0 ));
2930
+ Value cstFourFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (4.0 ));
2931
+ Value cstFiveFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (5.0 ));
2932
+ Value cstEightFloat =
2933
+ b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (8.0 ));
2934
+
2935
+ Value xDistanceSquared = b.create <arith::MulFOp>(loc, xDistance, xDistance);
2936
+ Value xDistanceCubed =
2937
+ b.create <arith::MulFOp>(loc, xDistanceSquared, xDistance);
2938
+ Value lessThanTwo = b.create <arith::MulFOp>(loc, xDistanceCubed, a);
2939
+
2940
+ Value fiveA = b.create <arith::MulFOp>(loc, xDistanceSquared, a);
2941
+ fiveA = b.create <arith::MulFOp>(loc, fiveA, cstFiveFloat);
2942
+ lessThanTwo = b.create <arith::SubFOp>(loc, lessThanTwo, fiveA);
2943
+
2944
+ Value eightA = b.create <arith::MulFOp>(loc, a, xDistance);
2945
+ eightA = b.create <arith::MulFOp>(loc, eightA, cstEightFloat);
2946
+ lessThanTwo = b.create <arith::AddFOp>(loc, eightA, lessThanTwo);
2947
+
2948
+ Value fourA = b.create <arith::MulFOp>(loc, a, cstFourFloat);
2949
+ lessThanTwo = b.create <arith::SubFOp>(loc, lessThanTwo, fourA);
2950
+
2951
+ Value greaterthanOrEqualToTwo = zero;
2952
+
2953
+ Value lessEqualOne = b.create <arith::AddFOp>(loc, a, cstTwoFloat);
2954
+ lessEqualOne = b.create <arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
2955
+ Value aPlusThree = b.create <arith::AddFOp>(loc, a, cstThreeFloat);
2956
+ aPlusThree = b.create <arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
2957
+ lessEqualOne = b.create <arith::SubFOp>(loc, lessEqualOne, aPlusThree);
2958
+ lessEqualOne = b.create <arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
2959
+
2960
+ Value cmp = b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::UGE, xDistance,
2961
+ cstTwoFloat);
2962
+ Value greaterThanOne =
2963
+ b.create <arith::SelectOp>(loc, cmp, greaterthanOrEqualToTwo, lessThanTwo);
2964
+ cmp = b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, xDistance,
2965
+ cstOneFloat);
2966
+ Value middle =
2967
+ b.create <arith::SelectOp>(loc, cmp, lessEqualOne, greaterThanOne);
2968
+
2969
+ return middle;
2970
+ }
2971
+
2972
+ static Value BicubicInterpolate (OpBuilder &b,
2973
+ Aten__InterpolateSizeListScaleListOp op,
2974
+ Location loc, SmallVector<Value> outputSizes,
2975
+ Value input, SmallVector<Value> inputSizes,
2976
+ SmallVector<Value> scaleValues,
2977
+ std::string coordStr) {
2978
+ unsigned dimOffset = 2 ;
2979
+ auto inputType = cast<RankedTensorType>(input.getType ());
2980
+ auto inputRank = inputType.getRank ();
2981
+
2982
+ Value inputFPH =
2983
+ b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[0 ]);
2984
+ Value inputFPW =
2985
+ b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[1 ]);
2986
+
2987
+ Value cstOneFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (1.0 ));
2988
+ Value zero = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
2989
+
2990
+ bool alignCornersBool;
2991
+ matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
2992
+
2993
+ SmallVector<Value> indices;
2994
+ for (unsigned i = 0 ; i < inputRank; i++) {
2995
+ indices.push_back (b.create <linalg::IndexOp>(loc, i));
2996
+ }
2997
+
2998
+ SmallVector<Value> proj;
2999
+ proj = CoordinateTransform (b, op, loc, outputSizes, input, inputSizes,
3000
+ scaleValues, coordStr, alignCornersBool, indices);
3001
+
3002
+ Value x1 = b.create <math::CeilOp>(loc, proj[1 ]);
3003
+ Value x_1 = b.create <arith::SubFOp>(loc, x1, cstOneFloat);
3004
+ Value x_2 = b.create <arith::SubFOp>(loc, x_1, cstOneFloat);
3005
+ Value x2 = b.create <arith::AddFOp>(loc, x1, cstOneFloat);
3006
+
3007
+ Value y1 = b.create <math::CeilOp>(loc, proj[0 ]);
3008
+ Value y_1 = b.create <arith::SubFOp>(loc, y1, cstOneFloat);
3009
+ Value y_2 = b.create <arith::SubFOp>(loc, y_1, cstOneFloat);
3010
+ Value y2 = b.create <arith::AddFOp>(loc, y1, cstOneFloat);
3011
+
3012
+ Value y2Distance = b.create <arith::SubFOp>(loc, proj[0 ], y2);
3013
+ y2Distance = b.create <math::AbsFOp>(loc, y2Distance);
3014
+ Value y1Distance = b.create <arith::SubFOp>(loc, proj[0 ], y1);
3015
+ y1Distance = b.create <math::AbsFOp>(loc, y1Distance);
3016
+ Value y_1Distance = b.create <arith::SubFOp>(loc, proj[0 ], y_1);
3017
+ y_1Distance = b.create <math::AbsFOp>(loc, y_1Distance);
3018
+ Value y_2Distance = b.create <arith::SubFOp>(loc, proj[0 ], y_2);
3019
+ y_2Distance = b.create <math::AbsFOp>(loc, y_2Distance);
3020
+
3021
+ Value x2Distance = b.create <arith::SubFOp>(loc, proj[1 ], x2);
3022
+ x2Distance = b.create <math::AbsFOp>(loc, x2Distance);
3023
+ Value x1Distance = b.create <arith::SubFOp>(loc, proj[1 ], x1);
3024
+ x1Distance = b.create <math::AbsFOp>(loc, x1Distance);
3025
+ Value x_1Distance = b.create <arith::SubFOp>(loc, proj[1 ], x_1);
3026
+ x_1Distance = b.create <math::AbsFOp>(loc, x_1Distance);
3027
+ Value x_2Distance = b.create <arith::SubFOp>(loc, proj[1 ], x_2);
3028
+ x_2Distance = b.create <math::AbsFOp>(loc, x_2Distance);
3029
+
3030
+ SmallVector<Value> y{y_2, y_1, y1, y2};
3031
+ SmallVector<Value> x{x_2, x_1, x1, x2};
3032
+ SmallVector<Value> yDistance{y_2Distance, y_1Distance, y1Distance,
3033
+ y2Distance};
3034
+ SmallVector<Value> xDistance{x_2Distance, x_1Distance, x1Distance,
3035
+ x2Distance};
3036
+ SmallVector<Value> xInterp{zero, zero, zero, zero};
3037
+
3038
+ // f(x_orig, y_orig) = Sum_y Sum_x W(x_original - x)*input[x,y]
3039
+ // * W(y_original - y)
3040
+ Value fxy = zero;
3041
+
3042
+ for (int j = 0 ; j < 4 ; j++) {
3043
+ Value wy = WeightFunction (b, loc, yDistance[j]);
3044
+ Value xInterpy = xInterp[j];
3045
+ for (int i = 0 ; i < 4 ; i++) {
3046
+ Value wx = WeightFunction (b, loc, xDistance[i]);
3047
+
3048
+ Value cmp =
3049
+ b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::UGE, y[j], zero);
3050
+ y[j] = b.create <arith::SelectOp>(loc, cmp, y[j], zero);
3051
+
3052
+ Value inputWSubOne = b.create <arith::SubFOp>(loc, inputFPW, cstOneFloat);
3053
+ Value inputHSubOne = b.create <arith::SubFOp>(loc, inputFPH, cstOneFloat);
3054
+ cmp = b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, y[j],
3055
+ inputHSubOne);
3056
+ y[j] = b.create <arith::SelectOp>(loc, cmp, inputHSubOne, y[j]);
3057
+
3058
+ Value yInt = b.create <arith::FPToSIOp>(loc, b.getI64Type (), y[j]);
3059
+ Value yIndex = b.create <arith::IndexCastOp>(loc, b.getIndexType (), yInt);
3060
+ indices[dimOffset] = yIndex;
3061
+
3062
+ cmp = b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::UGE, x[i], zero);
3063
+ x[i] = b.create <arith::SelectOp>(loc, cmp, x[i], zero);
3064
+
3065
+ cmp = b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, x[i],
3066
+ inputWSubOne);
3067
+ x[i] = b.create <arith::SelectOp>(loc, cmp, inputWSubOne, x[i]);
3068
+
3069
+ Value xInt = b.create <arith::FPToSIOp>(loc, b.getI64Type (), x[i]);
3070
+ Value xIndex = b.create <arith::IndexCastOp>(loc, b.getIndexType (), xInt);
3071
+ indices[dimOffset + 1 ] = xIndex;
3072
+
3073
+ Value p = b.create <tensor::ExtractOp>(loc, input, indices);
3074
+
3075
+ Value wxp = b.create <arith::MulFOp>(loc, wx, p);
3076
+ xInterpy = b.create <arith::AddFOp>(loc, xInterpy, wxp);
3077
+ }
3078
+ Value wyXInterpy = b.create <arith::MulFOp>(loc, wy, xInterpy);
3079
+ fxy = b.create <arith::AddFOp>(loc, fxy, wyXInterpy);
3080
+ }
3081
+
3082
+ return fxy;
3083
+ }
3084
+
2898
3085
namespace {
2899
3086
class ConvertInterpolateOp
2900
3087
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
@@ -2910,7 +3097,8 @@ class ConvertInterpolateOp
2910
3097
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
2911
3098
// op with the non-standard mode="bilinear_asymmetric".
2912
3099
matchPattern (op.getMode (), m_TorchConstantStr (mode));
2913
- if (mode.substr (0 , 8 ) != " bilinear" && mode.substr (0 , 7 ) != " nearest" ) {
3100
+ if (mode.substr (0 , 8 ) != " bilinear" && mode.substr (0 , 7 ) != " nearest" &&
3101
+ mode.substr (0 , 5 ) != " cubic" ) {
2914
3102
return failure ();
2915
3103
}
2916
3104
@@ -2999,6 +3187,10 @@ class ConvertInterpolateOp
2999
3187
retVal = BilinearInterpolate (
3000
3188
b, op, loc, outputSizeIntValues, input, inputSizes,
3001
3189
ScaleFactorFloatValues, mode.substr (8 ));
3190
+ } else if (mode.substr (0 , 5 ) == " cubic" ) {
3191
+ retVal = BicubicInterpolate (
3192
+ b, op, loc, outputSizeIntValues, input, inputSizes,
3193
+ ScaleFactorFloatValues, mode.substr (5 ));
3002
3194
}
3003
3195
b.create <linalg::YieldOp>(loc, retVal);
3004
3196
})
0 commit comments