Optimize GatherElements and Scaler. (microsoft#5543)
* Optimize GatherElements and Scaler.

* Address PR comments

* Fix build
pranavsharma authored Oct 20, 2020
1 parent 2f4fc83 commit 1038f9c
Showing 2 changed files with 85 additions and 61 deletions.
27 changes: 21 additions & 6 deletions onnxruntime/core/providers/cpu/ml/scaler.cc
@@ -83,15 +83,30 @@ common::Status ScalerOp<T>::Compute(OpKernelContext* context) const {

   size_t x_size = x_shape.Size();
   int64_t stride = x_dims.size() == 1 ? x_dims[0] : x_dims[1];
+  auto* ttp = context->GetOperatorThreadPool();
+  auto num_threads = std::min<int>(concurrency::ThreadPool::DegreeOfParallelism(ttp), static_cast<int>(x_size));
 
   if (static_cast<int64_t>(offset_.size()) == stride &&
       static_cast<int64_t>(scale_.size()) == stride) {
-    for (size_t i = 0; i < x_size; i++) {
-      y_data[i] = static_cast<float>((x_data[i] - offset_[i % stride]) * scale_[i % stride]);
-    }
+    concurrency::ThreadPool::TrySimpleParallelFor(
+        ttp,
+        num_threads,
+        [this, num_threads, y_data, x_data, stride, x_size](ptrdiff_t batch_num) {
+          auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, x_size);
+          for (auto i = work.start; i < work.end; ++i) {
+            y_data[i] = static_cast<float>((x_data[i] - offset_[i % stride]) * scale_[i % stride]);
+          }
+        });
   } else if (offset_.size() == 1 && scale_.size() == 1) {
-    for (size_t i = 0; i < x_size; i++) {
-      y_data[i] = static_cast<float>((x_data[i] - offset_[0]) * scale_[0]);
-    }
+    concurrency::ThreadPool::TrySimpleParallelFor(
+        ttp,
+        num_threads,
+        [this, num_threads, y_data, x_data, x_size](ptrdiff_t batch_num) {
+          auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, x_size);
+          for (auto i = work.start; i < work.end; ++i) {
+            y_data[i] = static_cast<float>((x_data[i] - offset_[0]) * scale_[0]);
+          }
+        });
   } else {
     std::ostringstream err_msg;
     err_msg << "Either both scale and offset can be of feature size (" << stride << ") or 1";
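Note on the scaler.cc change: the single-threaded element loop is replaced by ThreadPool::TrySimpleParallelFor, which splits the flat output range into num_threads contiguous chunks via PartitionWork; each chunk applies the same y = (x - offset) * scale formula. The standalone sketch below illustrates that chunked pattern with plain std::thread — the chunk-splitting arithmetic is an assumption for illustration only and is not the onnxruntime ThreadPool implementation.

// Standalone sketch (not the onnxruntime ThreadPool API): partition a flat
// range [0, total) into num_batches contiguous chunks and process each chunk
// on its own worker, as the new Scaler code does.
#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

struct WorkRange { std::size_t start; std::size_t end; };

// Split [0, total) into num_batches contiguous ranges; earlier batches absorb the remainder.
static WorkRange PartitionRange(std::size_t batch, std::size_t num_batches, std::size_t total) {
  std::size_t base = total / num_batches;
  std::size_t rem = total % num_batches;
  std::size_t start = batch * base + std::min(batch, rem);
  std::size_t len = base + (batch < rem ? 1 : 0);
  return {start, start + len};
}

int main() {
  const std::vector<float> x(1000, 3.0f);
  std::vector<float> y(x.size());
  const float offset = 1.0f, scale = 2.0f;

  // One worker per chunk, same formula as the Scaler kernel: y = (x - offset) * scale.
  std::size_t num_batches = std::min<std::size_t>(std::thread::hardware_concurrency(), x.size());
  if (num_batches == 0) num_batches = 1;
  std::vector<std::thread> workers;
  for (std::size_t b = 0; b < num_batches; ++b) {
    workers.emplace_back([&, b] {
      WorkRange work = PartitionRange(b, num_batches, x.size());
      for (std::size_t i = work.start; i < work.end; ++i)
        y[i] = (x[i] - offset) * scale;
    });
  }
  for (auto& t : workers) t.join();
  return 0;
}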
119 changes: 64 additions & 55 deletions onnxruntime/core/providers/cpu/tensor/gather_elements.cc
@@ -86,42 +86,15 @@ static inline void increment_over_inner_dim(std::vector<int64_t>& current_dims,
   }
 }
 
-// parse indices_tensor and along the way validate its shape and contents
-static std::vector<int64_t> parse_and_validate_indices_tensor(const Tensor* indices_tensor,
-                                                              int64_t axis, const TensorShape& input_shape) {
-  // first parse 'indices' data
-  auto num_elements = indices_tensor->Shape().Size();
-  std::vector<int64_t> indices_data;
-  // reserving memory ahead as we know the size of the container
-  indices_data.reserve(num_elements);
-  if (indices_tensor->IsDataType<int32_t>()) {
-    const auto* data = indices_tensor->Data<int32_t>();
-    for (int64_t i = 0; i < num_elements; ++i)
-      indices_data.push_back(data[i]);
-  } else if (indices_tensor->IsDataType<int64_t>()) {
-    const auto* data = indices_tensor->Data<int64_t>();
-    for (int64_t i = 0; i < num_elements; ++i)
-      indices_data.push_back(data[i]);
+template <typename Tin>
+static inline int64_t GetNegativeIndexAdjustedValue(const Tin* indices_data, Tin index, int64_t axis, const TensorShape& input_shape) {
+  int64_t retval = -1;
+  if (indices_data[index] < 0) {
+    retval = static_cast<int64_t>(indices_data[index] + input_shape[axis]);
   } else {
-    ORT_THROW("GatherElements op: Data type for 'indices' tensor must be 'int32_t' and 'int64_t'");
+    retval = static_cast<int64_t>(indices_data[index]);
   }
-
-  // validate 'indices' data
-  // along the way 'fix' negative index values if within bounds
-  int64_t lower_index_limit = -input_shape[axis];
-  int64_t upper_index_limit = input_shape[axis] - 1;
-
-  for (int64_t i = 0; i < num_elements; ++i) {
-    auto indices_val = indices_data[i];
-    if (indices_val < lower_index_limit || indices_val > upper_index_limit)
-      ORT_THROW("GatherElements op: Value in indices must be within bounds [",
-                lower_index_limit, " , ", upper_index_limit, "]. Actual value is ", indices_val);
-
-    if (indices_val < 0)
-      indices_data[i] += input_shape[axis];
-  }
-
-  return indices_data;
+  return retval;
 }
 
 #ifdef __GNUC__
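The new GetNegativeIndexAdjustedValue helper replaces the old copy-and-fix pass: indices are now read in place and negative values are adjusted on the fly, while the bounds check moves into core_impl (see the validation loop further down). A minimal restatement of the index rule, with the Tensor types stripped out as an illustrative assumption:

// An index i along an axis of length dim_len is valid in [-dim_len, dim_len - 1],
// and a negative i addresses element i + dim_len.
#include <cassert>
#include <cstdint>

static inline int64_t AdjustNegativeIndex(int64_t i, int64_t dim_len) {
  return i < 0 ? i + dim_len : i;
}

int main() {
  const int64_t dim_len = 4;                        // e.g. input_shape[axis] == 4
  assert(AdjustNegativeIndex(-1, dim_len) == 3);    // last element
  assert(AdjustNegativeIndex(-4, dim_len) == 0);    // first element
  assert(AdjustNegativeIndex(2, dim_len) == 2);     // non-negative indices pass through
  // -5 or 4 would fall outside [-4, 3] and be rejected by the validation loop in core_impl.
  return 0;
}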
@@ -130,7 +103,7 @@ static std::vector<int64_t> parse_and_validate_indices_tensor(const Tensor* indi
 #pragma GCC diagnostic ignored "-Wclass-memaccess"
 #endif
 #endif
-template <bool is_string, typename T>
+template <bool is_string, typename T, typename Tin>
 static void core_impl(const Tensor* input_tensor, const Tensor* indices_tensor,
                       Tensor* output_tensor, int64_t axis) {
   // get pointer to input data
@@ -154,15 +127,27 @@ static void core_impl(const Tensor* input_tensor, const Tensor* indices_tensor,
   const int64_t input_rank = static_cast<int64_t>(input_tensor->Shape().NumDimensions());
   const TensorPitches input_shape_pitches(*input_tensor);
 
-  const std::vector<int64_t>& indices_data = parse_and_validate_indices_tensor(indices_tensor, axis, input_tensor->Shape());
+  const auto& input_shape = input_tensor->Shape();
   const TensorShape& indices_shape = indices_tensor->Shape();
+  const Tin* indices_data = indices_tensor->Data<Tin>();
+
+  // validate indices
+  auto num_elements = indices_tensor->Shape().Size();
+  int64_t lower_index_limit = -input_shape[axis];
+  int64_t upper_index_limit = input_shape[axis] - 1;
+  for (int64_t i = 0; i < num_elements; ++i) {
+    auto indices_val = indices_data[i];
+    if (indices_val < lower_index_limit || indices_val > upper_index_limit)
+      ORT_THROW("GatherElements op: Value in indices must be within bounds [",
+                lower_index_limit, " , ", upper_index_limit, "]. Actual value is ", indices_val);
+  }
 
   int64_t num_inner_dim = calculate_num_inner_dim(indices_shape);
   int64_t inner_dim_size = indices_shape[input_rank - 1];
   bool processing_inner_dim = (axis == input_rank - 1) ? true : false;
 
   int64_t base_offset = 0;
-  int64_t indices_counter = -1;
+  Tin indices_counter = -1;
   int64_t output_counter = -1;
   size_t element_size = input_tensor->DataType()->Size();

@@ -173,17 +158,24 @@ static void core_impl(const Tensor* input_tensor, const Tensor* indices_tensor,
       base_offset = compute_base_offset(process_dims, input_shape_pitches, axis);
 
       // process 1 chunk of 'inner dimension' length
-      for (int64_t i = 0; i < inner_dim_size; ++i) {
-        // optimizer will remove the redundant if/else block based on 'is_string' template parameter
-        if (is_string) {
-          output_data[++output_counter] = input_data[base_offset + (indices_data[++indices_counter] * input_shape_pitches[axis]) + i];
-        } else {
+      // optimizer will remove the redundant if/else block based on 'is_string' template parameter
+      if (is_string) {
+        for (int64_t i = 0; i < inner_dim_size; ++i) {
+          output_data[++output_counter] =
+              input_data[base_offset +
+                         (GetNegativeIndexAdjustedValue<Tin>(indices_data, ++indices_counter, axis, input_shape) *
+                          input_shape_pitches[axis]) +
+                         i];
+        }
+      } else {
+        for (int64_t i = 0; i < inner_dim_size; ++i) {
+          // optimizer will remove the redundant if/else block based on 'is_string' template parameter
           memcpy(output_data,
-                 input_data + (base_offset + (indices_data[++indices_counter] * input_shape_pitches[axis]) + i) * element_size, element_size);
+                 input_data + (base_offset + (GetNegativeIndexAdjustedValue<Tin>(indices_data, ++indices_counter, axis, input_shape) * input_shape_pitches[axis]) + i) * element_size,
+                 element_size);
           output_data += element_size;
         }
       }
 
       increment_over_inner_dim(process_dims, indices_shape);
     }
   }
@@ -193,13 +185,23 @@ static void core_impl(const Tensor* input_tensor, const Tensor* indices_tensor,
       base_offset = compute_base_offset(process_dims, input_shape_pitches, axis);
 
       // process 1 chunk of 'inner dimension' length
-      for (int64_t i = 0; i < inner_dim_size; ++i) {
-        // for innermost axis, input_shape_pitches[axis] = 1 (so no need to multiply)
-        // optimizer will remove the redundant if/else block based on 'is_string' template parameter
-        if (is_string) {
-          output_data[++output_counter] = input_data[base_offset + indices_data[++indices_counter]];
-        } else {
-          memcpy(output_data, input_data + (base_offset + indices_data[++indices_counter]) * element_size, element_size);
+      // optimizer will remove the redundant if/else block based on 'is_string' template parameter
+      if (is_string) {
+        for (int64_t i = 0; i < inner_dim_size; ++i) {
+          // for innermost axis, input_shape_pitches[axis] = 1 (so no need to multiply)
+          output_data[++output_counter] =
+              input_data[base_offset +
+                         GetNegativeIndexAdjustedValue<Tin>(indices_data, ++indices_counter, axis, input_shape)];
+        }
+      } else {
+        for (int64_t i = 0; i < inner_dim_size; ++i) {
+          // for innermost axis, input_shape_pitches[axis] = 1 (so no need to multiply)
+          // optimizer will remove the redundant if/else block based on 'is_string' template parameter
+          memcpy(output_data,
+                 input_data + (base_offset +
+                               GetNegativeIndexAdjustedValue<Tin>(indices_data, ++indices_counter, axis, input_shape)) *
+                     element_size,
+                 element_size);
           output_data += element_size;
         }
       }
@@ -270,10 +272,17 @@ Status GatherElements::Compute(OpKernelContext* context) const {
   if (indices_shape.Size() == 0)
     return Status::OK();
 
-  if (input_tensor->IsDataTypeString())
-    core_impl<true, std::string>(input_tensor, indices_tensor, output_tensor, axis);
-  else
-    core_impl<false, int8_t>(input_tensor, indices_tensor, output_tensor, axis);
+  if (input_tensor->IsDataTypeString()) {
+    if (indices_tensor->IsDataType<int32_t>())
+      core_impl<true, std::string, int32_t>(input_tensor, indices_tensor, output_tensor, axis);
+    else
+      core_impl<true, std::string, int64_t>(input_tensor, indices_tensor, output_tensor, axis);
+  } else {
+    if (indices_tensor->IsDataType<int32_t>())
+      core_impl<false, int8_t, int32_t>(input_tensor, indices_tensor, output_tensor, axis);
+    else
+      core_impl<false, int8_t, int64_t>(input_tensor, indices_tensor, output_tensor, axis);
+  }
 
   return Status::OK();
 }
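For readers unfamiliar with the op, the sketch below restates what core_impl computes for the simple case of a 2-D float input with axis == 1. The shapes and nested loops here are illustrative assumptions; the real kernel handles arbitrary rank via TensorPitches, copies elements with memcpy, and is now templated on the indices type (int32_t or int64_t) so the indices buffer is read directly instead of being copied into a std::vector<int64_t>.

// Reference-style restatement of GatherElements for a 2-D float input, axis == 1:
// output[r][c] = input[r][indices[r][c]], with negative indices adjusted by the
// column count and out-of-range indices rejected.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<float> GatherElementsAxis1(const std::vector<float>& input,
                                              int64_t rows, int64_t cols,
                                              const std::vector<int64_t>& indices,
                                              int64_t idx_cols) {
  std::vector<float> output(rows * idx_cols);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < idx_cols; ++c) {
      int64_t idx = indices[r * idx_cols + c];
      if (idx < 0) idx += cols;            // negative-index adjustment
      assert(idx >= 0 && idx < cols);      // bounds check, as in the validation loop above
      output[r * idx_cols + c] = input[r * cols + idx];
    }
  }
  return output;
}

int main() {
  // input = [[1, 2], [3, 4]], indices = [[0, 0], [1, 0]] -> output = [[1, 1], [4, 3]]
  std::vector<float> input = {1, 2, 3, 4};
  std::vector<int64_t> indices = {0, 0, 1, 0};
  std::vector<float> out = GatherElementsAxis1(input, 2, 2, indices, 2);
  assert(out == (std::vector<float>{1, 1, 4, 3}));
  return 0;
}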
