Skip to content

Commit 1ed2b47

Browse files
committed
OnnxToTorch bicubic interpolation
1 parent 7058f45 commit 1ed2b47

File tree

3 files changed

+282
-22
lines changed

3 files changed

+282
-22
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2922,7 +2922,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29222922
llvm::SmallVector<Value> operands;
29232923
std::string mode, nearest_mode, coordTfMode;
29242924
int64_t antialias, exclude_outside;
2925-
float extrapolation_value;
2925+
float extrapolation_value, cubic_coeff_a;
29262926
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
29272927

29282928
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
@@ -2947,7 +2947,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29472947
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
29482948
0.0) ||
29492949
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
2950-
"round_prefer_floor"))
2950+
"round_prefer_floor") ||
2951+
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
29512952
return failure();
29522953
if (antialias != 0) {
29532954
return rewriter.notifyMatchFailure(
@@ -2991,8 +2992,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29912992
Value alignCorners =
29922993
coordTfMode == "align_corners" ? cstTrue : cstFalse;
29932994
if (mode == "cubic") {
2994-
return rewriter.notifyMatchFailure(binder.op,
2995-
"unimplemented: bicubic mode");
2995+
std::string modeStr = "cubic";
2996+
if (coordTfMode != "half_pixel")
2997+
modeStr = modeStr + "_" + coordTfMode;
2998+
modeStrValue =
2999+
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
29963000
}
29973001
// supported modes:
29983002
// bilinear (half_pixel), bilinear with align_corners,

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 198 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2771,12 +2771,13 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
27712771
return retVal;
27722772
}
27732773

2774-
static Value BilinearInterpolate(OpBuilder &b,
2775-
Aten__InterpolateSizeListScaleListOp op,
2776-
Location loc, SmallVector<Value> outputSizes,
2777-
Value input, SmallVector<Value> inputSizes,
2778-
SmallVector<Value> scaleValues,
2779-
std::string coordStr) {
2774+
static SmallVector<Value>
2775+
CoordinateTransform(OpBuilder &b, Aten__InterpolateSizeListScaleListOp op,
2776+
Location loc, SmallVector<Value> outputSizes, Value input,
2777+
SmallVector<Value> inputSizes,
2778+
SmallVector<Value> scaleValues, std::string coordStr,
2779+
bool alignCornersBool, SmallVector<Value> indices) {
2780+
27802781
unsigned dimOffset = 2;
27812782
auto inputType = cast<RankedTensorType>(input.getType());
27822783
auto inputRank = inputType.getRank();
@@ -2785,15 +2786,7 @@ static Value BilinearInterpolate(OpBuilder &b,
27852786
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
27862787
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
27872788

2788-
bool alignCornersBool;
2789-
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2790-
2791-
SmallVector<Value> indices;
2792-
for (unsigned i = 0; i < inputRank; i++) {
2793-
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2794-
}
2795-
2796-
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
2789+
SmallVector<Value> proj;
27972790
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
27982791
// length_original
27992792
Value inputFP =
@@ -2863,6 +2856,38 @@ static Value BilinearInterpolate(OpBuilder &b,
28632856
// clip to [0,length_original - 1].
28642857
// proj is properly within the input image.
28652858
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
2859+
}
2860+
return proj;
2861+
}
2862+
2863+
static Value BilinearInterpolate(OpBuilder &b,
2864+
Aten__InterpolateSizeListScaleListOp op,
2865+
Location loc, SmallVector<Value> outputSizes,
2866+
Value input, SmallVector<Value> inputSizes,
2867+
SmallVector<Value> scaleValues,
2868+
std::string coordStr) {
2869+
unsigned dimOffset = 2;
2870+
auto inputType = cast<RankedTensorType>(input.getType());
2871+
auto inputRank = inputType.getRank();
2872+
2873+
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2874+
2875+
bool alignCornersBool;
2876+
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2877+
2878+
SmallVector<Value> indices;
2879+
for (unsigned i = 0; i < inputRank; i++) {
2880+
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2881+
}
2882+
2883+
SmallVector<Value> proj, high, low, highFP, lowFP;
2884+
proj = CoordinateTransform(b, op, loc, outputSizes, input, inputSizes,
2885+
scaleValues, coordStr, alignCornersBool, indices);
2886+
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
2887+
// length_original
2888+
Value inputFP =
2889+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
2890+
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
28662891

28672892
// for bilinear interpolation, we look for the nearest indices below and
28682893
// above proj
@@ -2926,6 +2951,158 @@ static Value BilinearInterpolate(OpBuilder &b,
29262951
return b.create<arith::AddFOp>(loc, left, right);
29272952
}
29282953

2954+
static Value BicubicInterpolate(OpBuilder &b,
2955+
Aten__InterpolateSizeListScaleListOp op,
2956+
Location loc, SmallVector<Value> outputSizes,
2957+
Value input, SmallVector<Value> inputSizes,
2958+
SmallVector<Value> scaleValues,
2959+
std::string coordStr) {
2960+
unsigned dimOffset = 2;
2961+
auto inputType = cast<RankedTensorType>(input.getType());
2962+
auto inputRank = inputType.getRank();
2963+
2964+
Value inputFPH =
2965+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[0]);
2966+
Value inputFPW =
2967+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[1]);
2968+
2969+
Value a = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-0.75));
2970+
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
2971+
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2972+
Value cstTwoFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
2973+
Value cstThreeFloat =
2974+
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(3.0));
2975+
Value cstFourFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(4.0));
2976+
Value cstFiveFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(5.0));
2977+
Value cstEightFloat =
2978+
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(8.0));
2979+
2980+
auto WeightLessThanEqualOne = [&](Value xDistance) -> Value {
2981+
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
2982+
Value xDistanceCubed =
2983+
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
2984+
2985+
Value lessEqualOne = b.create<arith::AddFOp>(loc, a, cstTwoFloat);
2986+
lessEqualOne = b.create<arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
2987+
Value aPlusThree = b.create<arith::AddFOp>(loc, a, cstThreeFloat);
2988+
aPlusThree = b.create<arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
2989+
lessEqualOne = b.create<arith::SubFOp>(loc, lessEqualOne, aPlusThree);
2990+
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
2991+
2992+
return lessEqualOne;
2993+
};
2994+
2995+
auto WeightLessThanTwo = [&](Value xDistance) -> Value {
2996+
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
2997+
Value xDistanceCubed =
2998+
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
2999+
Value lessThanTwo = b.create<arith::MulFOp>(loc, xDistanceCubed, a);
3000+
3001+
Value fiveA = b.create<arith::MulFOp>(loc, xDistanceSquared, a);
3002+
fiveA = b.create<arith::MulFOp>(loc, fiveA, cstFiveFloat);
3003+
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fiveA);
3004+
3005+
Value eightA = b.create<arith::MulFOp>(loc, a, xDistance);
3006+
eightA = b.create<arith::MulFOp>(loc, eightA, cstEightFloat);
3007+
lessThanTwo = b.create<arith::AddFOp>(loc, eightA, lessThanTwo);
3008+
3009+
Value fourA = b.create<arith::MulFOp>(loc, a, cstFourFloat);
3010+
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fourA);
3011+
return lessThanTwo;
3012+
};
3013+
3014+
bool alignCornersBool;
3015+
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
3016+
3017+
SmallVector<Value> indices;
3018+
for (unsigned i = 0; i < inputRank; i++) {
3019+
indices.push_back(b.create<linalg::IndexOp>(loc, i));
3020+
}
3021+
3022+
SmallVector<Value> proj;
3023+
proj = CoordinateTransform(b, op, loc, outputSizes, input, inputSizes,
3024+
scaleValues, coordStr, alignCornersBool, indices);
3025+
3026+
Value x1 = b.create<math::CeilOp>(loc, proj[1]);
3027+
Value x_1 = b.create<arith::SubFOp>(loc, x1, cstOneFloat);
3028+
Value x_2 = b.create<arith::SubFOp>(loc, x_1, cstOneFloat);
3029+
Value x2 = b.create<arith::AddFOp>(loc, x1, cstOneFloat);
3030+
3031+
Value y1 = b.create<math::CeilOp>(loc, proj[0]);
3032+
Value y_1 = b.create<arith::SubFOp>(loc, y1, cstOneFloat);
3033+
Value y_2 = b.create<arith::SubFOp>(loc, y_1, cstOneFloat);
3034+
Value y2 = b.create<arith::AddFOp>(loc, y1, cstOneFloat);
3035+
3036+
Value y2Distance = b.create<arith::SubFOp>(loc, proj[0], y2);
3037+
y2Distance = b.create<math::AbsFOp>(loc, y2Distance);
3038+
Value y1Distance = b.create<arith::SubFOp>(loc, proj[0], y1);
3039+
y1Distance = b.create<math::AbsFOp>(loc, y1Distance);
3040+
Value y_1Distance = b.create<arith::SubFOp>(loc, proj[0], y_1);
3041+
y_1Distance = b.create<math::AbsFOp>(loc, y_1Distance);
3042+
Value y_2Distance = b.create<arith::SubFOp>(loc, proj[0], y_2);
3043+
y_2Distance = b.create<math::AbsFOp>(loc, y_2Distance);
3044+
3045+
Value x2Distance = b.create<arith::SubFOp>(loc, proj[1], x2);
3046+
x2Distance = b.create<math::AbsFOp>(loc, x2Distance);
3047+
Value x1Distance = b.create<arith::SubFOp>(loc, proj[1], x1);
3048+
x1Distance = b.create<math::AbsFOp>(loc, x1Distance);
3049+
Value x_1Distance = b.create<arith::SubFOp>(loc, proj[1], x_1);
3050+
x_1Distance = b.create<math::AbsFOp>(loc, x_1Distance);
3051+
Value x_2Distance = b.create<arith::SubFOp>(loc, proj[1], x_2);
3052+
x_2Distance = b.create<math::AbsFOp>(loc, x_2Distance);
3053+
3054+
SmallVector<Value> y{y_2, y_1, y1, y2};
3055+
SmallVector<Value> x{x_2, x_1, x1, x2};
3056+
SmallVector<Value> yDistance{y_2Distance, y_1Distance, y1Distance,
3057+
y2Distance};
3058+
SmallVector<Value> xDistance{x_2Distance, x_1Distance, x1Distance,
3059+
x2Distance};
3060+
SmallVector<Value> wys{
3061+
WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance),
3062+
WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)};
3063+
SmallVector<Value> wxs{
3064+
WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance),
3065+
WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)};
3066+
SmallVector<Value> xInterp{zero, zero, zero, zero};
3067+
3068+
// f(x_orig, y_orig) = Sum_y Sum_x W(x_original - x)*input[x,y]
3069+
// * W(y_original - y)
3070+
Value fxy = zero;
3071+
3072+
for (int j = 0; j < 4; j++) {
3073+
Value wy = wys[j];
3074+
Value xInterpy = xInterp[j];
3075+
for (int i = 0; i < 4; i++) {
3076+
Value wx = wxs[i];
3077+
3078+
y[j] = b.create<arith::MaximumFOp>(loc, y[j], zero);
3079+
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputFPH, cstOneFloat);
3080+
y[j] = b.create<arith::MinimumFOp>(loc, y[j], inputHSubOne);
3081+
3082+
Value yInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), y[j]);
3083+
Value yIndex = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yInt);
3084+
indices[dimOffset] = yIndex;
3085+
3086+
x[i] = b.create<arith::MaximumFOp>(loc, x[i], zero);
3087+
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputFPW, cstOneFloat);
3088+
x[i] = b.create<arith::MinimumFOp>(loc, x[i], inputWSubOne);
3089+
3090+
Value xInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), x[i]);
3091+
Value xIndex = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xInt);
3092+
indices[dimOffset + 1] = xIndex;
3093+
3094+
Value p = b.create<tensor::ExtractOp>(loc, input, indices);
3095+
3096+
Value wxp = b.create<arith::MulFOp>(loc, wx, p);
3097+
xInterpy = b.create<arith::AddFOp>(loc, xInterpy, wxp);
3098+
}
3099+
Value wyXInterpy = b.create<arith::MulFOp>(loc, wy, xInterpy);
3100+
fxy = b.create<arith::AddFOp>(loc, fxy, wyXInterpy);
3101+
}
3102+
3103+
return fxy;
3104+
}
3105+
29293106
namespace {
29303107
class ConvertInterpolateOp
29313108
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
@@ -2941,7 +3118,8 @@ class ConvertInterpolateOp
29413118
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
29423119
// op with the non-standard mode="bilinear_asymmetric".
29433120
matchPattern(op.getMode(), m_TorchConstantStr(mode));
2944-
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
3121+
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" &&
3122+
mode.substr(0, 5) != "cubic") {
29453123
return failure();
29463124
}
29473125

@@ -3030,6 +3208,10 @@ class ConvertInterpolateOp
30303208
retVal = BilinearInterpolate(
30313209
b, op, loc, outputSizeIntValues, input, inputSizes,
30323210
ScaleFactorFloatValues, mode.substr(8));
3211+
} else if (mode.substr(0, 5) == "cubic") {
3212+
retVal = BicubicInterpolate(
3213+
b, op, loc, outputSizeIntValues, input, inputSizes,
3214+
ScaleFactorFloatValues, mode.substr(5));
30333215
}
30343216
b.create<linalg::YieldOp>(loc, retVal);
30353217
})

test/Conversion/TorchToLinalg/resize.mlir

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
2121
// CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
2222
// CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32
2323
// CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32
24-
// CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %[[cst]] : f32
24+
// CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %cst_4 : f32
2525
// CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32
2626
// CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32
2727
// CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32
2828
// CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32
2929
// CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64
3030
// CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index
31-
// CHECK-DAG: %[[x35:.*]] = arith.minimumf %[[x31]], %[[x28]] : f32
31+
// CHECK-DAG: %[[x35:.*]] = arith.minimumf %44, %42 : f32
3232
// CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64
3333
// CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index
3434
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32>
@@ -304,4 +304,78 @@ func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtens
304304
return %5 : !torch.vtensor<[?,?,?],f32>
305305
}
306306

307+
// CHECK-LABEL: func.func @test_resize_sizes_cubic
308+
func.func @test_resize_sizes_cubic(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
309+
,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19
310+
: si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
311+
// CHECK-DAG: %[[x1:.*]] = math.ceil %42 : f32
312+
// CHECK-DAG: %[[x_1:.*]] = arith.subf %[[x1]], %cst : f32
313+
// CHECK-DAG: %[[x_2:.*]] = arith.subf %[[x_1]], %cst : f32
314+
// CHECK-DAG: %[[x2:.*]] = arith.addf %[[x1]], %cst : f32
315+
// CHECK-DAG: %[[y1:.*]] = math.ceil %31 : f32
316+
// CHECK-DAG: %[[y_1:.*]] = arith.subf %[[y1]], %cst : f32
317+
// CHECK-DAG: %[[y_2:.*]] = arith.subf %[[y_1]], %cst : f32
318+
// CHECK-DAG: %[[y2:.*]] = arith.addf %[[y1]], %cst : f32
319+
// CHECK-DAG: %[[y2D:.*]] = arith.subf %31, %[[y2]] : f32
320+
// CHECK-DAG: %[[y2Dist:.*]] = math.absf %[[y2D]] : f32
321+
// CHECK-DAG: %[[y1D:.*]] = arith.subf %31, %[[y1]] : f32
322+
// CHECK-DAG: %[[y1Dist:.*]] = math.absf %[[y1D]] : f32
323+
// CHECK-DAG: %[[y_1D:.*]] = arith.subf %31, %[[y_1]] : f32
324+
// CHECK-DAG: %[[y_1Dist:.*]] = math.absf %[[y_1D]] : f32
325+
// CHECK-DAG: %[[y_2D:.*]] = arith.subf %31, %[[y_2]] : f32
326+
// CHECK-DAG: %[[y_2Dist:.*]] = math.absf %[[y_2D]] : f32
327+
// CHECK-DAG: %[[x2D:.*]] = arith.subf %42, %[[x2]] : f32
328+
// CHECK-DAG: %[[x2Dist:.*]] = math.absf %[[x2D]] : f32
329+
// CHECK-DAG: %[[x1D:.*]] = arith.subf %42, %[[x1]] : f32
330+
// CHECK-DAG: %[[x1Dist:.*]] = math.absf %[[x1D]] : f32
331+
// CHECK-DAG: %[[x_1D:.*]] = arith.subf %42, %[[x_1]] : f32
332+
// CHECK-DAG: %[[x_1Dist:.*]] = math.absf %[[x_1D]] : f32
333+
// CHECK-DAG: %[[x_2D:.*]] = arith.subf %42, %[[x_2]] : f32
334+
// CHECK-DAG: %[[x_2Dist:.*]] = math.absf %[[x_2D]] : f32
335+
// CHECK-DAG: %[[cst_8:.*]] = arith.constant -7.500000e-01 : f32
336+
// CHECK-DAG: %[[cst_9:.*]] = arith.constant 0.000000e+00 : f32
337+
// CHECK-DAG: %[[cst_10:.*]] = arith.constant 1.000000e+00 : f32
338+
// CHECK-DAG: %[[cst_11:.*]] = arith.constant 2.000000e+00 : f32
339+
// CHECK-DAG: %[[cst_12:.*]] = arith.constant 3.000000e+00 : f32
340+
// CHECK-DAG: %[[cst_13:.*]] = arith.constant 4.000000e+00 : f32
341+
// CHECK-DAG: %[[cst_14:.*]] = arith.constant 5.000000e+00 : f32
342+
// CHECK-DAG: %[[cst_15:.*]] = arith.constant 8.000000e+00 : f32
343+
// CHECK-DAG: %[[distSQ:.*]] = arith.mulf %[[y_2Dist]], %[[y_2Dist]] : f32
344+
// CHECK-DAG: %[[distCubed:.*]] = arith.mulf %[[distSQ]], %[[y_2Dist]] : f32
345+
// CHECK-DAG: %[[x69:.*]] = arith.mulf %[[distCubed]], %[[cst_8]] : f32
346+
// CHECK-DAG: %[[x70:.*]] = arith.mulf %[[distSQ]], %[[cst_8]] : f32
347+
// CHECK-DAG: %[[x71:.*]] = arith.mulf %[[x70]], %[[cst_14]] : f32
348+
// CHECK-DAG: %[[x72:.*]] = arith.subf %[[x69]], %[[x71]] : f32
349+
// CHECK-DAG: %[[x73:.*]] = arith.mulf %[[cst_8]], %[[y_2Dist]] : f32
350+
// CHECK-DAG: %[[x74:.*]] = arith.mulf %[[x73]], %[[cst_15]] : f32
351+
// CHECK-DAG: %[[x75:.*]] = arith.addf %[[x74]], %[[x72]] : f32
352+
// CHECK-DAG: %[[x76:.*]] = arith.mulf %[[cst_8]], %[[cst_13]] : f32
353+
// CHECK-DAG: %[[x77:.*]] = arith.subf %[[x75]], %[[x76]] : f32
354+
// CHECK-DAG: %[[x78:.*]] = arith.addf %[[cst_8]], %[[cst_11]] : f32
355+
// CHECK-DAG: %[[x79:.*]] = arith.mulf %[[distCubed]], %[[x78]] : f32
356+
// CHECK-DAG: %[[x80:.*]] = arith.addf %[[cst_8]], %[[cst_12]] : f32
357+
// CHECK-DAG: %[[x81:.*]] = arith.mulf %[[distSQ]], %[[x80]] : f32
358+
// CHECK-DAG: %[[x82:.*]] = arith.subf %[[x79]], %[[x81]] : f32
359+
// CHECK-DAG: %[[x83:.*]] = arith.addf %82, %cst_10 : f32
360+
// CHECK-DAG: %[[x84:.*]] = arith.cmpf uge, %[[y_2Dist]], %[[cst_11]] : f32
361+
// CHECK-DAG: %[[x85:.*]] = arith.select %84, %cst_9, %77 : f32
362+
// CHECK-DAG: %[[x86:.*]] = arith.cmpf ule, %[[y_2Dist]], %cst_10 : f32
363+
// CHECK-DAG: %[[x87:.*]] = arith.select %86, %83, %85 : f32
364+
%none = torch.constant.none
365+
%none_0 = torch.constant.none
366+
%int0 = torch.constant.int 0
367+
%false = torch.constant.bool false
368+
%true = torch.constant.bool true
369+
%str = torch.constant.str "cubic"
370+
%int2 = torch.constant.int 2
371+
%0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
372+
%1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
373+
%int3 = torch.constant.int 3
374+
%2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
375+
%3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
376+
%4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
377+
%5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
378+
return %5 : !torch.vtensor<[?,?,?,?],f32>
379+
}
380+
307381
// -----

0 commit comments

Comments
 (0)