
Commit bd08f11

Upsample support NHWC (microsoft#10554)
Implement bilinear interpolation for Upsample (Resize) on 4-D input whose outermost and innermost scales (the innermost usually being the channel dimension of NHWC) are 1. In addition, revert HandleResize to the original implementation for the TransposeOptimizerTests.TestResize* tests.
1 parent e0d1d69 commit bd08f11
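
For context, the sketch below illustrates the operation this commit enables: bilinear upsampling of a 4-D NHWC tensor where only the two spatial scales differ from 1. It is a simplified, standalone example and not the ONNX Runtime kernel; it assumes a basic out_coord / scale ("asymmetric") coordinate mapping and omits ROI, extrapolation, and the other coordinate-transformation modes the real implementation supports. The function name is hypothetical.

// Standalone sketch (illustrative only): bilinear upsampling of an NHWC float
// tensor with scales [1.0, height_scale, width_scale, 1.0].
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<float> UpsampleBilinearNHWC(const std::vector<float>& X,
                                        int64_t N, int64_t H, int64_t W, int64_t C,
                                        float height_scale, float width_scale) {
  const int64_t out_H = static_cast<int64_t>(H * height_scale);
  const int64_t out_W = static_cast<int64_t>(W * width_scale);
  std::vector<float> Y(static_cast<size_t>(N * out_H * out_W * C));
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t oy = 0; oy < out_H; ++oy) {
      // Simplified "asymmetric" coordinate mapping: in = out / scale.
      const float in_y = std::min(oy / height_scale, static_cast<float>(H - 1));
      const int64_t y0 = static_cast<int64_t>(in_y);
      const int64_t y1 = std::min(y0 + 1, H - 1);
      const float dy = in_y - static_cast<float>(y0);
      for (int64_t ox = 0; ox < out_W; ++ox) {
        const float in_x = std::min(ox / width_scale, static_cast<float>(W - 1));
        const int64_t x0 = static_cast<int64_t>(in_x);
        const int64_t x1 = std::min(x0 + 1, W - 1);
        const float dx = in_x - static_cast<float>(x0);
        for (int64_t c = 0; c < C; ++c) {
          // NHWC indexing: ((n * H + y) * W + x) * C + c, so the channel stride is 1.
          auto at = [&](int64_t y, int64_t x) { return X[((n * H + y) * W + x) * C + c]; };
          const float top = at(y0, x0) * (1.0f - dx) + at(y0, x1) * dx;
          const float bottom = at(y1, x0) * (1.0f - dx) + at(y1, x1) * dx;
          Y[((n * out_H + oy) * out_W + ox) * C + c] = top * (1.0f - dy) + bottom * dy;
        }
      }
    }
  }
  return Y;
}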

File tree

4 files changed, +342 -275 lines changed

onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc

Lines changed: 29 additions & 37 deletions
@@ -967,41 +967,35 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con
   node.SetInput(i, gather_output);
 }
 
-// static bool HandleResize(HandlerArgs& args) {
-//   auto inputs = args.node.Inputs();
-//   int64_t rank_int = gsl::narrow_cast<int64_t>(args.perm.size());
-//
-//   auto p = ChannelFirstToLastPerm(rank_int);
-//   auto& perm = p == args.perm ? args.perm : args.perm_inv;
-//   auto& perm_inv = p == args.perm ? args.perm_inv : args.perm;
-//
-//   if (args.ctx.opset < 11) {
-//     PermuteInput(args.ctx.graph, args.node, 1, perm);
-//   } else {
-//     if (inputs[1] != "") {
-//       std::vector<int64_t> double_perm_inv = perm;
-//       double_perm_inv.reserve(2 * args.perm.size());
-//       for (int64_t p1 : perm) {
-//         double_perm_inv.push_back(p1 + rank_int);
-//       }
-//       PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv);
-//     }
-//     for (size_t i = 2; i < inputs.size(); ++i) {
-//       if (inputs[i] != "") {
-//         PermuteInput(args.ctx.graph, args.node, i, perm);
-//       }
-//     }
-//   }
-//
-//   TransposeFirstInput(args.ctx, args.node, perm);
-//   TransposeOutputs(args.ctx, args.node, perm_inv);
-//
-//   SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, args.node.OpType(), "com.microsoft.nhwc");
-//
-//   return true;
-// }
+static bool HandleResize(HandlerArgs& args) {
+  auto inputs = args.node.Inputs();
+  int64_t rank_int = gsl::narrow_cast<int64_t>(args.perm.size());
+
+  if (args.ctx.opset < 11) {
+    PermuteInput(args.ctx.graph, args.node, 1, args.perm_inv);
+  } else {
+    if (inputs[1] != "") {
+      std::vector<int64_t> double_perm_inv = args.perm_inv;
+      double_perm_inv.reserve(2 * args.perm_inv.size());
+      for (int64_t p : args.perm_inv) {
+        double_perm_inv.push_back(p + rank_int);
+      }
+      PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv);
+    }
+    for (size_t i = 2; i < inputs.size(); ++i) {
+      if (inputs[i] != "") {
+        PermuteInput(args.ctx.graph, args.node, i, args.perm_inv);
+      }
+    }
+  }
+
+  TransposeFirstInput(args.ctx, args.node, args.perm_inv);
+  TransposeOutputs(args.ctx, args.node, args.perm);
+
+  return true;
+}
 
-// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize};
+constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize};
 
 static bool HandlePad(HandlerArgs& args) {
   size_t rank = args.perm.size();
@@ -1640,9 +1634,7 @@ static const std::unordered_map<std::string_view, const HandlerInfo&> handler_ma
     {"Split", split_handler},
     {"Shape", shape_handler},
     {"Pad", pad_handler},
-    // Todo: renable resize handler after adding NHWC support in upsample op on cpu
-    // https://github.com/microsoft/onnxruntime/issues/9857
-    // {"Resize", resize_handler},
+    {"Resize", resize_handler},
     {"ReduceSum", reduce_sum_handler},
 
     {"ReduceLogSum", reduce_op_handler},

onnxruntime/core/providers/cpu/tensor/upsample.cc

Lines changed: 73 additions & 29 deletions
@@ -420,13 +420,15 @@ struct BilinearParams {
 // that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes
 // the scale values for the outermost 2 dimensions are 1.
 // This is the common use-case where the 4-D input (batched multi-channel images)
-// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
-static BilinearParams SetupUpsampleBilinear(int64_t input_height,
-                                            int64_t input_width,
-                                            int64_t output_height,
-                                            int64_t output_width,
-                                            float height_scale,
-                                            float width_scale,
+// is usually of shapes:
+// - [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
+// - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0]
+static BilinearParams SetupUpsampleBilinear(const int64_t input_height,
+                                            const int64_t input_width,
+                                            const int64_t output_height,
+                                            const int64_t output_width,
+                                            const float height_scale,
+                                            const float width_scale,
                                             const std::vector<float>& roi,
                                             AllocatorPtr& alloc,
                                             const GetOriginalCoordinateFunc& get_original_coordinate) {
@@ -523,26 +525,25 @@ static BilinearParams SetupUpsampleBilinear(int64_t input_height,
 }
 
 template <typename T>
-void UpsampleBilinear(int64_t batch_size,
-                      int64_t num_channels,
-                      int64_t input_height,
-                      int64_t input_width,
-                      int64_t output_height,
-                      int64_t output_width,
-                      float height_scale,
-                      float width_scale,
+void UpsampleBilinear(const int64_t batch_size,
+                      const int64_t num_channels,
+                      const int64_t input_height,
+                      const int64_t input_width,
+                      const int64_t output_height,
+                      const int64_t output_width,
+                      const float height_scale,
+                      const float width_scale,
                       const std::vector<float>& roi,
-                      bool use_extrapolation,
-                      float extrapolation_value,
-                      const T* XdataBase,
-                      T* YdataBase,
+                      const bool use_extrapolation,
+                      const float extrapolation_value,
+                      const T* const XdataBase,
+                      T* const YdataBase,
                       AllocatorPtr& alloc,
                       const GetOriginalCoordinateFunc& get_original_coordinate,
                       concurrency::ThreadPool* tp) {
   BilinearParams p = SetupUpsampleBilinear(input_height, input_width, output_height, output_width,
                                            height_scale, width_scale, roi,
                                            alloc, get_original_coordinate);
-
   for (int64_t n = 0; n < batch_size; ++n) {
     concurrency::ThreadPool::TrySimpleParallelFor(
         tp, num_channels,
@@ -1065,22 +1066,65 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
     case UpsampleMode::LINEAR: {
       // Supports 'bilinear' and 'trilinear' sampling only
 
-      //'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1
+      //'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 or
+      // 4-D input with outermost and innermost scales as 1
       if (dims.size() == 2 || dims.size() == 4) {
         bool is_2D = dims.size() == 2;
 
-        const int64_t batch_size = is_2D ? 1 : dims[0];
-        const int64_t num_channels = is_2D ? 1 : dims[1];
-        const int64_t input_height = is_2D ? dims[0] : dims[2];
-        const int64_t input_width = is_2D ? dims[1] : dims[3];
-
-        const int64_t output_height = is_2D ? output_dims[0] : output_dims[2];
-        const int64_t output_width = is_2D ? output_dims[1] : output_dims[3];
+        int64_t batch_size;
+        int64_t num_channels;
+        int64_t input_height;
+        int64_t input_width;
+
+        int64_t output_height;
+        int64_t output_width;
+
+        float height_scale;
+        float width_scale;
+
+        if (is_2D) {
+          batch_size = 1;
+          num_channels = 1;
+          input_height = dims[0];
+          input_width = dims[1];
+
+          output_height = output_dims[0];
+          output_width = output_dims[1];
+
+          height_scale = scales[0];
+          width_scale = scales[1];
+        } else {
+          if (scales[1] == 1.0f) {
+            batch_size = dims[0];
+            num_channels = dims[1];
+            input_height = dims[2];
+            input_width = dims[3];
+
+            output_height = output_dims[2];
+            output_width = output_dims[3];
+
+            height_scale = scales[2];
+            width_scale = scales[3];
+          } else {
+            ORT_ENFORCE(scales[3] == 1.0f, "4-D input with innermost scale (usually channel of NHWC) as 1.");
+
+            batch_size = dims[0];
+            num_channels = dims[3];
+            input_height = dims[1];
+            input_width = dims[2];
+
+            output_height = output_dims[1];
+            output_width = output_dims[2];
+
+            height_scale = scales[1];
+            width_scale = scales[2];
+          }
+        }
 
         AllocatorPtr alloc;
         ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
         UpsampleBilinear(batch_size, num_channels, input_height, input_width, output_height, output_width,
-                         is_2D ? scales[0] : scales[2], is_2D ? scales[1] : scales[3], roi,
+                         height_scale, width_scale, roi,
                          use_extrapolation_, extrapolation_value_, X->Data<T>(),
                          Y->MutableData<T>(), alloc, get_original_coordinate_,
                          output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr);
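
As a usage note for the dispatch in BaseCompute above: for a 4-D input, the layout is inferred from the scales. A second scale of 1 means NCHW (scales [1, 1, h, w]); otherwise the last scale must be 1 and the input is treated as NHWC (scales [1, h, w, 1]). The sketch below restates that rule in isolation; the helper name and enum are hypothetical and not part of ONNX Runtime.

#include <array>
#include <cassert>

enum class Layout { NCHW, NHWC };

// Hypothetical helper mirroring the scale-based layout check shown in the diff above.
inline Layout SelectBilinearLayout(const std::array<float, 4>& scales) {
  if (scales[1] == 1.0f) {
    return Layout::NCHW;  // scales look like [1, 1, height_scale, width_scale]
  }
  assert(scales[3] == 1.0f && "innermost (channel) scale must be 1 for NHWC");
  return Layout::NHWC;    // scales look like [1, height_scale, width_scale, 1]
}

// Example: SelectBilinearLayout({1.0f, 1.0f, 2.0f, 2.0f}) == Layout::NCHW
//          SelectBilinearLayout({1.0f, 2.0f, 2.0f, 1.0f}) == Layout::NHWC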
