@@ -268,11 +268,11 @@ inline bool Mean(const T* input_data, const int* input_dims,
268
268
return true ;
269
269
}
270
270
271
- template <typename T>
272
271
inline void Mean (const tflite::MeanParams& op_params,
273
272
const RuntimeShape& unextended_input_shape,
274
- const T* input_data,
275
- const RuntimeShape& unextended_output_shape, T* output_data) {
273
+ const float * input_data,
274
+ const RuntimeShape& unextended_output_shape,
275
+ float * output_data) {
276
276
ruy::profiler::ScopeLabel label (" Mean4D" );
277
277
278
278
// Current implementation only supports dimension equals 4 and simultaneous
@@ -312,78 +312,21 @@ inline void Mean(const tflite::MeanParams& op_params,
312
312
}
313
313
}
314
314
315
- inline void Mean (const tflite::MeanParams& op_params,
316
- const RuntimeShape& unextended_input_shape,
317
- const uint8_t * input_data, int32_t input_zero_point,
318
- float input_scale, const RuntimeShape& unextended_output_shape,
319
- uint8_t * output_data, int32_t output_zero_point,
320
- float output_scale) {
321
- ruy::profiler::ScopeLabel label (" Mean4D/Uint8" );
322
-
323
- // Current implementation only supports dimension equals 4 and simultaneous
324
- // reduction over width and height.
325
- TFLITE_CHECK_EQ (unextended_input_shape.DimensionsCount (), 4 );
326
- TFLITE_CHECK_LE (unextended_output_shape.DimensionsCount (), 4 );
327
- const RuntimeShape input_shape =
328
- RuntimeShape::ExtendedShape (4 , unextended_input_shape);
329
- const RuntimeShape output_shape =
330
- RuntimeShape::ExtendedShape (4 , unextended_output_shape);
331
- const int output_batch = output_shape.Dims (0 );
332
- const int output_height = output_shape.Dims (1 );
333
- const int output_width = output_shape.Dims (2 );
334
- const int output_depth = output_shape.Dims (3 );
335
- const int input_height = input_shape.Dims (1 );
336
- const int input_width = input_shape.Dims (2 );
337
- const float num_elements_in_axis = input_width * input_height;
338
-
339
- TFLITE_CHECK_EQ (op_params.axis_count , 2 );
340
- TFLITE_CHECK ((op_params.axis [0 ] == 1 && op_params.axis [1 ] == 2 ) ||
341
- (op_params.axis [0 ] == 2 && op_params.axis [1 ] == 1 ));
342
- TFLITE_CHECK_EQ (output_height, 1 );
343
- TFLITE_CHECK_EQ (output_width, 1 );
344
-
345
- constexpr int32_t kMinValue = std::numeric_limits<uint8_t >::min ();
346
- constexpr int32_t kMaxValue = std::numeric_limits<uint8_t >::max ();
347
-
348
- float temp = input_zero_point * input_scale / output_scale;
349
- temp = temp > 0 ? temp + 0 .5f : temp - 0 .5f ;
350
- int32_t bias = output_zero_point - static_cast <int32_t >(temp);
351
- double real_scale =
352
- static_cast <double >(input_scale / (num_elements_in_axis * output_scale));
353
-
354
- int32_t multiplier;
355
- int shift;
356
- QuantizeMultiplier (real_scale, &multiplier, &shift);
357
- for (int out_b = 0 ; out_b < output_batch; ++out_b) {
358
- for (int out_d = 0 ; out_d < output_depth; ++out_d) {
359
- int32_t acc = 0 ;
360
- for (int in_h = 0 ; in_h < input_height; ++in_h) {
361
- for (int in_w = 0 ; in_w < input_width; ++in_w) {
362
- acc += input_data[Offset (input_shape, out_b, in_h, in_w, out_d)];
363
- }
364
- }
365
- acc = MultiplyByQuantizedMultiplier (acc, multiplier, shift);
366
- acc += bias;
367
- acc = std::min (std::max (acc, kMinValue ), kMaxValue );
368
- output_data[Offset (output_shape, out_b, 0 , 0 , out_d)] =
369
- static_cast <uint8_t >(acc);
370
- }
371
- }
372
- }
373
-
374
315
// Computes the mean of elements across dimensions given in axis.
375
316
// It does so in two stages, first calculates the sum of elements along the axis
376
317
// then divides it by the number of element in axis for quantized values.
377
318
template <typename T, typename U>
378
319
inline bool QuantizedMeanOrSum (const T* input_data, int32_t input_zero_point,
379
- float input_scale , const int * input_dims ,
380
- const int input_num_dims, T* output_data,
381
- int32_t output_zero_point, float output_scale ,
320
+ const int * input_dims , const int input_num_dims ,
321
+ T* output_data, int32_t output_multiplier ,
322
+ int output_shift, int32_t output_zero_point ,
382
323
const int * output_dims,
383
324
const int output_num_dims, const int * axis,
384
325
const int num_axis_dimensions, bool keep_dims,
385
326
int * temp_index, int * resolved_axis, U* temp_sum,
386
327
bool compute_sum) {
328
+ const int32_t kMinValue = std::numeric_limits<T>::min ();
329
+ const int32_t kMaxValue = std::numeric_limits<T>::max ();
387
330
const bool uint8_case = std::is_same<T, uint8_t >::value;
388
331
const bool int16_case = std::is_same<T, int16_t >::value;
389
332
if (uint8_case) {
@@ -430,40 +373,46 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
430
373
}
431
374
432
375
// Calculate mean by dividing output_data by num of aggregated element.
433
- size_t num_elements_in_axis = 1 ;
376
+ int64_t num_elements_in_axis = 1 ;
434
377
for (int idx = 0 ; idx < num_resolved_axis; ++idx) {
435
378
size_t current = static_cast <size_t >(input_dims[resolved_axis[idx]]);
436
379
// Overflow prevention.
437
- if (current > (std::numeric_limits<size_t >::max () / num_elements_in_axis)) {
380
+ if (current > static_cast <size_t >(std::numeric_limits<int64_t >::max () /
381
+ num_elements_in_axis)) {
438
382
return false ;
439
383
}
440
384
num_elements_in_axis *= current;
441
385
}
442
386
443
- if (num_elements_in_axis > 0 ) {
444
- const float scale = input_scale / output_scale;
445
- if (compute_sum) {
446
- // TODO(b/116341117): Eliminate float and do this completely in 8bit.
447
- const float bias = -input_zero_point * scale * num_elements_in_axis;
448
- for (size_t idx = 0 ; idx < num_outputs; ++idx) {
449
- const U value =
450
- static_cast <U>(TfLiteRound (temp_sum[idx] * scale + bias)) +
451
- output_zero_point;
452
- output_data[idx] = static_cast <T>(value);
453
- }
454
- } else {
455
- const float bias = -input_zero_point * scale;
456
- for (size_t idx = 0 ; idx < num_outputs; ++idx) {
457
- float float_mean = static_cast <float >(temp_sum[idx]) /
458
- static_cast <float >(num_elements_in_axis);
459
- float result = TfLiteMin (
460
- TfLiteRound (float_mean * scale + bias) + output_zero_point,
461
- static_cast <float >(std::numeric_limits<T>::max ()));
462
- result = TfLiteMax (result,
463
- static_cast <float >(std::numeric_limits<T>::min ()));
464
- output_data[idx] = static_cast <T>(result);
465
- }
466
- }
387
+ if (num_elements_in_axis == 0 ) {
388
+ return true ;
389
+ }
390
+
391
+ // Readapt output rescaling when calculating the mean to integrate a
392
+ // 1/num_elements_in_axis multiplier.
393
+ if (!compute_sum) {
394
+ TFLITE_DCHECK_GE (num_elements_in_axis, 0 );
395
+ int shift =
396
+ 63 - CountLeadingZeros (static_cast <uint64_t >(num_elements_in_axis));
397
+ // To avoid any overflow risk 'shift' should be <= 32 and to satisfy
398
+ // 'MultiplyByQuantizedMultiplier' pre-conditions 'output_shift - shift'
399
+ // should be >= -31. Clamp the value at the price of some precision loss.
400
+ shift = std::min (shift, 32 );
401
+ shift = std::min (shift, 31 + output_shift);
402
+ output_multiplier = static_cast <int32_t >(
403
+ (static_cast <int64_t >(output_multiplier) << shift) /
404
+ num_elements_in_axis);
405
+ output_shift = output_shift - shift;
406
+ }
407
+
408
+ for (size_t idx = 0 ; idx < num_outputs; ++idx) {
409
+ const U shifted_sum =
410
+ static_cast <U>(temp_sum[idx] - input_zero_point * num_elements_in_axis);
411
+ int32_t output = MultiplyByQuantizedMultiplier (
412
+ shifted_sum, output_multiplier, output_shift) +
413
+ output_zero_point;
414
+ output = std::min (std::max (output, kMinValue ), kMaxValue );
415
+ output_data[idx] = static_cast <T>(output);
467
416
}
468
417
return true ;
469
418
}
@@ -478,8 +427,8 @@ inline bool QuantizedMeanOrSumExtraArgs(
478
427
bool keep_dims, int * temp_index, int * resolved_axis, U* temp_sum,
479
428
bool compute_sum) {
480
429
return QuantizedMeanOrSum<T, U>(
481
- input_data, input_zero_point, input_scale, input_dims, input_num_dims,
482
- output_data, output_zero_point, output_scale , output_dims,
430
+ input_data, input_zero_point, input_dims, input_num_dims, output_data ,
431
+ output_multiplier, output_shift, output_zero_point , output_dims,
483
432
output_num_dims, axis, num_axis_dimensions, keep_dims, temp_index,
484
433
resolved_axis, temp_sum, compute_sum);
485
434
}
0 commit comments