@@ -420,13 +420,15 @@ struct BilinearParams {
420
420
// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes
421
421
// the scale values for two of the dimensions are 1 (the outermost 2, or the outermost and the innermost).
422
422
// This is the common use-case where the 4-D input (batched multi-channel images)
423
- // is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
424
- static BilinearParams SetupUpsampleBilinear (int64_t input_height,
425
- int64_t input_width,
426
- int64_t output_height,
427
- int64_t output_width,
428
- float height_scale,
429
- float width_scale,
423
+ // usually has one of the following shapes:
424
+ // - [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
425
+ // - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0]
426
+ static BilinearParams SetupUpsampleBilinear (const int64_t input_height,
427
+ const int64_t input_width,
428
+ const int64_t output_height,
429
+ const int64_t output_width,
430
+ const float height_scale,
431
+ const float width_scale,
430
432
const std::vector<float >& roi,
431
433
AllocatorPtr& alloc,
432
434
const GetOriginalCoordinateFunc& get_original_coordinate) {
@@ -523,26 +525,25 @@ static BilinearParams SetupUpsampleBilinear(int64_t input_height,
523
525
}
524
526
525
527
template <typename T>
526
- void UpsampleBilinear (int64_t batch_size,
527
- int64_t num_channels,
528
- int64_t input_height,
529
- int64_t input_width,
530
- int64_t output_height,
531
- int64_t output_width,
532
- float height_scale,
533
- float width_scale,
528
+ void UpsampleBilinear (const int64_t batch_size,
529
+ const int64_t num_channels,
530
+ const int64_t input_height,
531
+ const int64_t input_width,
532
+ const int64_t output_height,
533
+ const int64_t output_width,
534
+ const float height_scale,
535
+ const float width_scale,
534
536
const std::vector<float >& roi,
535
- bool use_extrapolation,
536
- float extrapolation_value,
537
- const T* XdataBase,
538
- T* YdataBase,
537
+ const bool use_extrapolation,
538
+ const float extrapolation_value,
539
+ const T* const XdataBase,
540
+ T* const YdataBase,
539
541
AllocatorPtr& alloc,
540
542
const GetOriginalCoordinateFunc& get_original_coordinate,
541
543
concurrency::ThreadPool* tp) {
542
544
BilinearParams p = SetupUpsampleBilinear (input_height, input_width, output_height, output_width,
543
545
height_scale, width_scale, roi,
544
546
alloc, get_original_coordinate);
545
-
546
547
for (int64_t n = 0 ; n < batch_size; ++n) {
547
548
concurrency::ThreadPool::TrySimpleParallelFor (
548
549
tp, num_channels,
@@ -1065,22 +1066,65 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
1065
1066
case UpsampleMode::LINEAR: {
1066
1067
// Supports 'bilinear' and 'trilinear' sampling only
1067
1068
1068
- // 'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1
1069
+ // 'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 or
1070
+ // 4-D input with outermost and innermost scales as 1
1069
1071
if (dims.size () == 2 || dims.size () == 4 ) {
1070
1072
bool is_2D = dims.size () == 2 ;
1071
1073
1072
- const int64_t batch_size = is_2D ? 1 : dims[0 ];
1073
- const int64_t num_channels = is_2D ? 1 : dims[1 ];
1074
- const int64_t input_height = is_2D ? dims[0 ] : dims[2 ];
1075
- const int64_t input_width = is_2D ? dims[1 ] : dims[3 ];
1076
-
1077
- const int64_t output_height = is_2D ? output_dims[0 ] : output_dims[2 ];
1078
- const int64_t output_width = is_2D ? output_dims[1 ] : output_dims[3 ];
1074
+ int64_t batch_size;
1075
+ int64_t num_channels;
1076
+ int64_t input_height;
1077
+ int64_t input_width;
1078
+
1079
+ int64_t output_height;
1080
+ int64_t output_width;
1081
+
1082
+ float height_scale;
1083
+ float width_scale;
1084
+
1085
+ if (is_2D) {
1086
+ batch_size = 1 ;
1087
+ num_channels = 1 ;
1088
+ input_height = dims[0 ];
1089
+ input_width = dims[1 ];
1090
+
1091
+ output_height = output_dims[0 ];
1092
+ output_width = output_dims[1 ];
1093
+
1094
+ height_scale = scales[0 ];
1095
+ width_scale = scales[1 ];
1096
+ } else {
1097
+ if (scales[1 ] == 1 .0f ) {
1098
+ batch_size = dims[0 ];
1099
+ num_channels = dims[1 ];
1100
+ input_height = dims[2 ];
1101
+ input_width = dims[3 ];
1102
+
1103
+ output_height = output_dims[2 ];
1104
+ output_width = output_dims[3 ];
1105
+
1106
+ height_scale = scales[2 ];
1107
+ width_scale = scales[3 ];
1108
+ } else {
1109
+ ORT_ENFORCE (scales[3 ] == 1 .0f , " 4-D input with innermost scale (usually channel of NHWC) as 1." );
1110
+
1111
+ batch_size = dims[0 ];
1112
+ num_channels = dims[3 ];
1113
+ input_height = dims[1 ];
1114
+ input_width = dims[2 ];
1115
+
1116
+ output_height = output_dims[1 ];
1117
+ output_width = output_dims[2 ];
1118
+
1119
+ height_scale = scales[1 ];
1120
+ width_scale = scales[2 ];
1121
+ }
1122
+ }
1079
1123
1080
1124
AllocatorPtr alloc;
1081
1125
ORT_RETURN_IF_ERROR (context->GetTempSpaceAllocator (&alloc));
1082
1126
UpsampleBilinear (batch_size, num_channels, input_height, input_width, output_height, output_width,
1083
- is_2D ? scales[ 0 ] : scales[ 2 ], is_2D ? scales[ 1 ] : scales[ 3 ] , roi,
1127
+ height_scale, width_scale , roi,
1084
1128
use_extrapolation_, extrapolation_value_, X->Data <T>(),
1085
1129
Y->MutableData <T>(), alloc, get_original_coordinate_,
1086
1130
output_height * output_width > 64 ? context->GetOperatorThreadPool () : nullptr );
0 commit comments