Skip to content

Commit f718e52

Browse files
committed
OnnxToTorch bicubic interpolation
1 parent 0a86deb commit f718e52

File tree

3 files changed

+292
-22
lines changed

3 files changed

+292
-22
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2914,7 +2914,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29142914
llvm::SmallVector<Value> operands;
29152915
std::string mode, nearest_mode, coordTfMode;
29162916
int64_t antialias, exclude_outside;
2917-
float extrapolation_value;
2917+
float extrapolation_value, cubic_coeff_a;
29182918
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
29192919

29202920
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
@@ -2939,7 +2939,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29392939
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
29402940
0.0) ||
29412941
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
2942-
"round_prefer_floor"))
2942+
"round_prefer_floor") ||
2943+
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
29432944
return failure();
29442945
if (antialias != 0) {
29452946
return rewriter.notifyMatchFailure(
@@ -2983,8 +2984,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29832984
Value alignCorners =
29842985
coordTfMode == "align_corners" ? cstTrue : cstFalse;
29852986
if (mode == "cubic") {
2986-
return rewriter.notifyMatchFailure(binder.op,
2987-
"unimplemented: bicubic mode");
2987+
std::string modeStr = "cubic";
2988+
if (coordTfMode != "half_pixel")
2989+
modeStr = modeStr + "_" + coordTfMode;
2990+
modeStrValue =
2991+
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
29882992
}
29892993
// supported modes:
29902994
// bilinear (half_pixel), bilinear with align_corners,

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 208 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2740,12 +2740,13 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
27402740
return retVal;
27412741
}
27422742

2743-
static Value BilinearInterpolate(OpBuilder &b,
2744-
Aten__InterpolateSizeListScaleListOp op,
2745-
Location loc, SmallVector<Value> outputSizes,
2746-
Value input, SmallVector<Value> inputSizes,
2747-
SmallVector<Value> scaleValues,
2748-
std::string coordStr) {
2743+
static SmallVector<Value>
2744+
CoordinateTransform(OpBuilder &b, Aten__InterpolateSizeListScaleListOp op,
2745+
Location loc, SmallVector<Value> outputSizes, Value input,
2746+
SmallVector<Value> inputSizes,
2747+
SmallVector<Value> scaleValues, std::string coordStr,
2748+
bool alignCornersBool, SmallVector<Value> indices) {
2749+
27492750
unsigned dimOffset = 2;
27502751
auto inputType = cast<RankedTensorType>(input.getType());
27512752
auto inputRank = inputType.getRank();
@@ -2754,15 +2755,7 @@ static Value BilinearInterpolate(OpBuilder &b,
27542755
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
27552756
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
27562757

2757-
bool alignCornersBool;
2758-
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2759-
2760-
SmallVector<Value> indices;
2761-
for (unsigned i = 0; i < inputRank; i++) {
2762-
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2763-
}
2764-
2765-
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
2758+
SmallVector<Value> proj;
27662759
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
27672760
// length_original
27682761
Value inputFP =
@@ -2832,6 +2825,38 @@ static Value BilinearInterpolate(OpBuilder &b,
28322825
// clip to [0,length_original - 1].
28332826
// proj is properly within the input image.
28342827
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
2828+
}
2829+
return proj;
2830+
}
2831+
2832+
static Value BilinearInterpolate(OpBuilder &b,
2833+
Aten__InterpolateSizeListScaleListOp op,
2834+
Location loc, SmallVector<Value> outputSizes,
2835+
Value input, SmallVector<Value> inputSizes,
2836+
SmallVector<Value> scaleValues,
2837+
std::string coordStr) {
2838+
unsigned dimOffset = 2;
2839+
auto inputType = cast<RankedTensorType>(input.getType());
2840+
auto inputRank = inputType.getRank();
2841+
2842+
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2843+
2844+
bool alignCornersBool;
2845+
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2846+
2847+
SmallVector<Value> indices;
2848+
for (unsigned i = 0; i < inputRank; i++) {
2849+
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2850+
}
2851+
2852+
SmallVector<Value> proj, high, low, highFP, lowFP;
2853+
proj = CoordinateTransform(b, op, loc, outputSizes, input, inputSizes,
2854+
scaleValues, coordStr, alignCornersBool, indices);
2855+
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
2856+
// length_original
2857+
Value inputFP =
2858+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
2859+
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
28352860

28362861
// for bilinear interpolation, we look for the nearest indices below and
28372862
// above proj
@@ -2895,6 +2920,168 @@ static Value BilinearInterpolate(OpBuilder &b,
28952920
return b.create<arith::AddFOp>(loc, left, right);
28962921
}
28972922

2923+
static Value WeightFunction(OpBuilder &b, Location loc, Value xDistance) {
2924+
Value a = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-0.75));
2925+
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
2926+
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2927+
Value cstTwoFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
2928+
Value cstThreeFloat =
2929+
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(3.0));
2930+
Value cstFourFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(4.0));
2931+
Value cstFiveFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(5.0));
2932+
Value cstEightFloat =
2933+
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(8.0));
2934+
2935+
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
2936+
Value xDistanceCubed =
2937+
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
2938+
Value lessThanTwo = b.create<arith::MulFOp>(loc, xDistanceCubed, a);
2939+
2940+
Value fiveA = b.create<arith::MulFOp>(loc, xDistanceSquared, a);
2941+
fiveA = b.create<arith::MulFOp>(loc, fiveA, cstFiveFloat);
2942+
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fiveA);
2943+
2944+
Value eightA = b.create<arith::MulFOp>(loc, a, xDistance);
2945+
eightA = b.create<arith::MulFOp>(loc, eightA, cstEightFloat);
2946+
lessThanTwo = b.create<arith::AddFOp>(loc, eightA, lessThanTwo);
2947+
2948+
Value fourA = b.create<arith::MulFOp>(loc, a, cstFourFloat);
2949+
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fourA);
2950+
2951+
Value greaterthanOrEqualToTwo = zero;
2952+
2953+
Value lessEqualOne = b.create<arith::AddFOp>(loc, a, cstTwoFloat);
2954+
lessEqualOne = b.create<arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
2955+
Value aPlusThree = b.create<arith::AddFOp>(loc, a, cstThreeFloat);
2956+
aPlusThree = b.create<arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
2957+
lessEqualOne = b.create<arith::SubFOp>(loc, lessEqualOne, aPlusThree);
2958+
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
2959+
2960+
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE, xDistance,
2961+
cstTwoFloat);
2962+
Value greaterThanOne =
2963+
b.create<arith::SelectOp>(loc, cmp, greaterthanOrEqualToTwo, lessThanTwo);
2964+
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, xDistance,
2965+
cstOneFloat);
2966+
Value middle =
2967+
b.create<arith::SelectOp>(loc, cmp, lessEqualOne, greaterThanOne);
2968+
2969+
return middle;
2970+
}
2971+
2972+
static Value BicubicInterpolate(OpBuilder &b,
2973+
Aten__InterpolateSizeListScaleListOp op,
2974+
Location loc, SmallVector<Value> outputSizes,
2975+
Value input, SmallVector<Value> inputSizes,
2976+
SmallVector<Value> scaleValues,
2977+
std::string coordStr) {
2978+
unsigned dimOffset = 2;
2979+
auto inputType = cast<RankedTensorType>(input.getType());
2980+
auto inputRank = inputType.getRank();
2981+
2982+
Value inputFPH =
2983+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[0]);
2984+
Value inputFPW =
2985+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[1]);
2986+
2987+
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2988+
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
2989+
2990+
bool alignCornersBool;
2991+
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2992+
2993+
SmallVector<Value> indices;
2994+
for (unsigned i = 0; i < inputRank; i++) {
2995+
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2996+
}
2997+
2998+
SmallVector<Value> proj;
2999+
proj = CoordinateTransform(b, op, loc, outputSizes, input, inputSizes,
3000+
scaleValues, coordStr, alignCornersBool, indices);
3001+
3002+
Value x1 = b.create<math::CeilOp>(loc, proj[1]);
3003+
Value x_1 = b.create<arith::SubFOp>(loc, x1, cstOneFloat);
3004+
Value x_2 = b.create<arith::SubFOp>(loc, x_1, cstOneFloat);
3005+
Value x2 = b.create<arith::AddFOp>(loc, x1, cstOneFloat);
3006+
3007+
Value y1 = b.create<math::CeilOp>(loc, proj[0]);
3008+
Value y_1 = b.create<arith::SubFOp>(loc, y1, cstOneFloat);
3009+
Value y_2 = b.create<arith::SubFOp>(loc, y_1, cstOneFloat);
3010+
Value y2 = b.create<arith::AddFOp>(loc, y1, cstOneFloat);
3011+
3012+
Value y2Distance = b.create<arith::SubFOp>(loc, proj[0], y2);
3013+
y2Distance = b.create<math::AbsFOp>(loc, y2Distance);
3014+
Value y1Distance = b.create<arith::SubFOp>(loc, proj[0], y1);
3015+
y1Distance = b.create<math::AbsFOp>(loc, y1Distance);
3016+
Value y_1Distance = b.create<arith::SubFOp>(loc, proj[0], y_1);
3017+
y_1Distance = b.create<math::AbsFOp>(loc, y_1Distance);
3018+
Value y_2Distance = b.create<arith::SubFOp>(loc, proj[0], y_2);
3019+
y_2Distance = b.create<math::AbsFOp>(loc, y_2Distance);
3020+
3021+
Value x2Distance = b.create<arith::SubFOp>(loc, proj[1], x2);
3022+
x2Distance = b.create<math::AbsFOp>(loc, x2Distance);
3023+
Value x1Distance = b.create<arith::SubFOp>(loc, proj[1], x1);
3024+
x1Distance = b.create<math::AbsFOp>(loc, x1Distance);
3025+
Value x_1Distance = b.create<arith::SubFOp>(loc, proj[1], x_1);
3026+
x_1Distance = b.create<math::AbsFOp>(loc, x_1Distance);
3027+
Value x_2Distance = b.create<arith::SubFOp>(loc, proj[1], x_2);
3028+
x_2Distance = b.create<math::AbsFOp>(loc, x_2Distance);
3029+
3030+
SmallVector<Value> y{y_2, y_1, y1, y2};
3031+
SmallVector<Value> x{x_2, x_1, x1, x2};
3032+
SmallVector<Value> yDistance{y_2Distance, y_1Distance, y1Distance,
3033+
y2Distance};
3034+
SmallVector<Value> xDistance{x_2Distance, x_1Distance, x1Distance,
3035+
x2Distance};
3036+
SmallVector<Value> xInterp{zero, zero, zero, zero};
3037+
3038+
// f(x_orig, y_orig) = Sum_y Sum_x W(x_original - x)*input[x,y]
3039+
// * W(y_original - y)
3040+
Value fxy = zero;
3041+
3042+
for (int j = 0; j < 4; j++) {
3043+
Value wy = WeightFunction(b, loc, yDistance[j]);
3044+
Value xInterpy = xInterp[j];
3045+
for (int i = 0; i < 4; i++) {
3046+
Value wx = WeightFunction(b, loc, xDistance[i]);
3047+
3048+
Value cmp =
3049+
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE, y[j], zero);
3050+
y[j] = b.create<arith::SelectOp>(loc, cmp, y[j], zero);
3051+
3052+
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputFPW, cstOneFloat);
3053+
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputFPH, cstOneFloat);
3054+
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, y[j],
3055+
inputHSubOne);
3056+
y[j] = b.create<arith::SelectOp>(loc, cmp, inputHSubOne, y[j]);
3057+
3058+
Value yInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), y[j]);
3059+
Value yIndex = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yInt);
3060+
indices[dimOffset] = yIndex;
3061+
3062+
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE, x[i], zero);
3063+
x[i] = b.create<arith::SelectOp>(loc, cmp, x[i], zero);
3064+
3065+
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, x[i],
3066+
inputWSubOne);
3067+
x[i] = b.create<arith::SelectOp>(loc, cmp, inputWSubOne, x[i]);
3068+
3069+
Value xInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), x[i]);
3070+
Value xIndex = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xInt);
3071+
indices[dimOffset + 1] = xIndex;
3072+
3073+
Value p = b.create<tensor::ExtractOp>(loc, input, indices);
3074+
3075+
Value wxp = b.create<arith::MulFOp>(loc, wx, p);
3076+
xInterpy = b.create<arith::AddFOp>(loc, xInterpy, wxp);
3077+
}
3078+
Value wyXInterpy = b.create<arith::MulFOp>(loc, wy, xInterpy);
3079+
fxy = b.create<arith::AddFOp>(loc, fxy, wyXInterpy);
3080+
}
3081+
3082+
return fxy;
3083+
}
3084+
28983085
namespace {
28993086
class ConvertInterpolateOp
29003087
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
@@ -2910,7 +3097,8 @@ class ConvertInterpolateOp
29103097
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
29113098
// op with the non-standard mode="bilinear_asymmetric".
29123099
matchPattern(op.getMode(), m_TorchConstantStr(mode));
2913-
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
3100+
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" &&
3101+
mode.substr(0, 5) != "cubic") {
29143102
return failure();
29153103
}
29163104

@@ -2999,6 +3187,10 @@ class ConvertInterpolateOp
29993187
retVal = BilinearInterpolate(
30003188
b, op, loc, outputSizeIntValues, input, inputSizes,
30013189
ScaleFactorFloatValues, mode.substr(8));
3190+
} else if (mode.substr(0, 5) == "cubic") {
3191+
retVal = BicubicInterpolate(
3192+
b, op, loc, outputSizeIntValues, input, inputSizes,
3193+
ScaleFactorFloatValues, mode.substr(5));
30023194
}
30033195
b.create<linalg::YieldOp>(loc, retVal);
30043196
})

0 commit comments

Comments
 (0)