Add OpenMP impl for TryParallelFor and modify Gelu to use FastGelu impl. #3667

Merged · 2 commits · Apr 24, 2020
58 changes: 43 additions & 15 deletions include/onnxruntime/core/platform/threadpool.h
@@ -166,11 +166,26 @@ class ThreadPool {

static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const TensorOpCost& cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
#ifdef _OPENMP
ORT_UNUSED_PARAMETER(cost_per_unit);
std::ptrdiff_t num_threads = concurrency::ThreadPool::NumThreads(tp);
if (total < num_threads) {
num_threads = total;
}
#pragma omp parallel for
for (std::ptrdiff_t i = 0; i < num_threads; i++) {
std::ptrdiff_t start, work_remaining;
PartitionWork(i, num_threads, total, &start, &work_remaining);
std::ptrdiff_t end = start + work_remaining;
fn(start, end);
}
#else
if (tp == nullptr) {
fn(0, total);
return;
}
tp->ParallelFor(total, cost_per_unit, fn);
#endif
}
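For readers skimming the diff, here is a minimal standalone sketch (not from the PR) of what the OpenMP branch above does: clamp the shard count to the amount of work, then issue one fn(first, last) call per shard under #pragma omp parallel for. The inline first/last arithmetic is a simplified stand-in for PartitionWork, whose body sits outside this hunk:

// Standalone sketch of the OpenMP branch above; compiles with or without -fopenmp.
// The first/last arithmetic is a simplified stand-in for PartitionWork.
#include <cstddef>
#include <cstdio>
#include <functional>

void ParallelForSketch(std::ptrdiff_t total, std::ptrdiff_t num_threads,
                       const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
  if (total < num_threads) {
    num_threads = total;  // never launch more shards than work items
  }
#pragma omp parallel for
  for (std::ptrdiff_t i = 0; i < num_threads; i++) {
    std::ptrdiff_t first = i * total / num_threads;       // shard start
    std::ptrdiff_t last = (i + 1) * total / num_threads;  // shard end (exclusive)
    fn(first, last);
  }
}

int main() {
  // Ten units of work over three shards: prints [0, 3), [3, 6), [6, 10).
  ParallelForSketch(10, 3, [](std::ptrdiff_t first, std::ptrdiff_t last) {
    std::printf("[%td, %td)\n", first, last);
  });
  return 0;
}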

// Similar to ParallelFor above, but takes the specified scheduling strategy
@@ -180,13 +195,28 @@ class ThreadPool {
const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);

static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const SchedulingParams& scheduling_params,
-const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
+const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
#ifdef _OPENMP
ORT_UNUSED_PARAMETER(scheduling_params);
std::ptrdiff_t num_threads = concurrency::ThreadPool::NumThreads(tp);
if (total < num_threads) {
num_threads = total;
}
#pragma omp parallel for
for (std::ptrdiff_t i = 0; i < num_threads; i++) {
std::ptrdiff_t start, work_remaining;
PartitionWork(i, num_threads, total, &start, &work_remaining);
std::ptrdiff_t end = start + work_remaining;
fn(start, end);
}
#else
if (tp == nullptr) {
fn(0, total);
return;
}
tp->ParallelFor(total, scheduling_params, fn);
-}
+#endif
+}
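A hypothetical caller of TryParallelFor, not part of the PR; the old Gelu code later in this diff uses the same shape, passing a scalar per-element cost estimate. Each invocation of fn receives one pre-partitioned contiguous range:

// Hypothetical usage sketch, assuming only the TryParallelFor API shown above
// plus the scalar-cost overload the old Gelu code below calls.
#include <cstddef>
#include <vector>

void ScaleAll(onnxruntime::concurrency::ThreadPool* tp, std::vector<float>& v, float k) {
  onnxruntime::concurrency::ThreadPool::TryParallelFor(
      tp, static_cast<std::ptrdiff_t>(v.size()),
      10.0,  // rough cost per element, mirroring the old Gelu code in this diff
      [&v, k](std::ptrdiff_t first, std::ptrdiff_t last) {
        for (std::ptrdiff_t i = first; i < last; i++) {
          v[i] *= k;  // process one contiguous shard
        }
      });
}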

// Prefer using this API to get the number of threads unless you know what you're doing.
// This API takes into account if openmp is enabled/disabled and if the thread pool ptr is nullptr.
@@ -208,16 +238,6 @@ class ThreadPool {
// cutting them by halves
void SimpleParallelFor(std::ptrdiff_t total, std::function<void(std::ptrdiff_t)> fn);

-#ifdef _OPENMP
-template <typename F>
-inline static void TryBatchParallelFor(ThreadPool*, std::ptrdiff_t total, F&& fn, std::ptrdiff_t /*num_batches*/) {
-#pragma omp parallel for
-for (std::ptrdiff_t i = 0; i < total; ++i) {
-fn(i);
-}
-}
-#else

/**
* Tries to call the given function in parallel, with calls split into (num_batches) batches.
*\param num_batches If it is zero, it will be replaced to the value of NumThreads().
@@ -230,6 +250,14 @@ class ThreadPool {
**/
template <typename F>
inline static void TryBatchParallelFor(ThreadPool* tp, std::ptrdiff_t total, F&& fn, std::ptrdiff_t num_batches) {
#ifdef _OPENMP
ORT_UNUSED_PARAMETER(tp);
ORT_UNUSED_PARAMETER(num_batches);
#pragma omp parallel for
for (std::ptrdiff_t i = 0; i < total; ++i) {
fn(i);
}
#else
if (tp == nullptr) {
for (std::ptrdiff_t i = 0; i < total; ++i) {
// In many cases, fn can be inlined here.
@@ -264,8 +292,8 @@ class ThreadPool {
fn(i);
}
});
-}
+#endif
+}
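A usage sketch for TryBatchParallelFor (hypothetical caller, not from the PR). Per the doc comment above, fn is invoked once per index, and passing num_batches == 0 lets the pool fall back to NumThreads() batches:

// Hypothetical usage sketch of TryBatchParallelFor, against the header above.
#include <cstddef>
#include <vector>

void SquareAll(onnxruntime::concurrency::ThreadPool* tp, std::vector<float>& v) {
  onnxruntime::concurrency::ThreadPool::TryBatchParallelFor(
      tp, static_cast<std::ptrdiff_t>(v.size()),
      [&v](std::ptrdiff_t i) { v[i] *= v[i]; },  // called once per element index
      0);  // num_batches == 0: use NumThreads() batches (see doc comment above)
}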

#ifndef _OPENMP
//Deprecated. Please avoid using Eigen Tensor because it will blow up binary size quickly.
@@ -291,7 +319,7 @@ class ThreadPool {
Eigen::ThreadPoolInterface* underlying_threadpool_;
// eigen_threadpool_ is instantiated and owned by thread::ThreadPool if
// user_threadpool is not in the constructor.
-std::unique_ptr<ThreadPoolTempl<Env>> eigen_threadpool_;
+std::unique_ptr<ThreadPoolTempl<Env> > eigen_threadpool_;
#ifndef _OPENMP
std::unique_ptr<Eigen::ThreadPoolDevice> threadpool_device_;
#endif
@@ -309,7 +337,7 @@ class ThreadPool {
*WorkRemaining = WorkPerThread;
}
}
};

} // namespace concurrency
} // namespace onnxruntime
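Only the tail of PartitionWork is visible in the last hunk (*WorkRemaining = WorkPerThread; inside one branch). A plausible reconstruction, assuming the common scheme in which the first total_work % num_batches shards each take one extra unit, looks like this; treat it as an illustration rather than the verbatim source:

// Hedged reconstruction of PartitionWork; only its tail appears in the diff.
// Assumed scheme: the first (total_work % num_batches) shards get one extra
// unit, so the shards tile [0, total_work) exactly once.
#include <cstddef>

static void PartitionWork(std::ptrdiff_t batch_idx, std::ptrdiff_t num_batches,
                          std::ptrdiff_t total_work, std::ptrdiff_t* start,
                          std::ptrdiff_t* work_remaining) {
  const std::ptrdiff_t work_per_batch = total_work / num_batches;
  const std::ptrdiff_t extra = total_work % num_batches;  // leftover units
  if (batch_idx < extra) {
    *start = batch_idx * (work_per_batch + 1);
    *work_remaining = work_per_batch + 1;  // early shards absorb the remainder
  } else {
    *start = batch_idx * work_per_batch + extra;
    *work_remaining = work_per_batch;  // matches the visible tail above
  }
}

With total_work = 10 and num_batches = 3 this yields [0, 4), [4, 7), [7, 10), consistent with the shard loop in TryParallelFor above.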
45 changes: 29 additions & 16 deletions onnxruntime/contrib_ops/cpu/activations.h
@@ -37,23 +37,36 @@ class Gelu : public OpKernel {
Gelu(const OpKernelInfo& info) : OpKernel(info) {}

Status Compute(OpKernelContext* context) const override {
-const auto* X = context->Input<Tensor>(0);
-Tensor* Y = context->Output(0, X->Shape());
+const Tensor* input = context->Input<Tensor>(0);
+const T* input_data = input->template Data<T>();
+
+Tensor* output = context->Output(0, input->Shape());
+T* output_data = output->template MutableData<T>();
+
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
-const int64_t input_size = X->Shape().Size();
-std::ptrdiff_t batch_size = static_cast<std::ptrdiff_t>(input_size);
-//The cost comes from microbenchmark(manual tunning).
-const double cost = 10.0;
-const T* data = X->template Data<T>();
-T* output = Y->template MutableData<T>();
-concurrency::ThreadPool::TryParallelFor(tp, batch_size, cost, [data, output](ptrdiff_t first, ptrdiff_t last) {
-ptrdiff_t len = last - first;
-onnxruntime::ConstEigenVectorArrayMap<T> xm(data + first, len);
-onnxruntime::EigenVectorArrayMap<T> ym(output + first, len);
-ym = xm * static_cast<float>(M_SQRT1_2);
-MlasComputeErf(output, output, len);
-ym = xm * 0.5f * (ym + 1.0f);
-});
+int64_t elem_count = input->Shape().Size();
+static const int64_t length_per_task = 4096; // this number comes from FastGelu.
+int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+concurrency::ThreadPool::TryBatchParallelFor(
+tp, static_cast<int32_t>(task_count),
+[&](ptrdiff_t task_idx) {
+const auto start = task_idx * length_per_task;
+const T* p_input = input_data + start;
+T* p_output = output_data + start;
+int64_t count = std::min(length_per_task, elem_count - start);
+
+for (int64_t i = 0; i < count; i++) {
+T value = p_input[i];
+p_output[i] = value * static_cast<T>(M_SQRT1_2);
+}
+
+MlasComputeErf(p_output, p_output, count);
+
+for (int64_t i = 0; i < count; i++) {
+p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
+}
+},
+0);
return Status::OK();
}
};
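For reference, the three steps inside the new lambda implement the exact erf-based GELU, GELU(x) = 0.5 · x · (1 + erf(x / √2)): the first loop scales by 1/√2 (M_SQRT1_2), MlasComputeErf applies erf in place over the chunk, and the second loop finishes the formula. A scalar sketch, with std::erf standing in for the vectorized MlasComputeErf:

// Scalar sketch of the lambda's math; std::erf stands in for MlasComputeErf.
#define _USE_MATH_DEFINES  // for M_SQRT1_2 on MSVC
#include <cmath>
#include <cstdio>

float GeluScalar(float x) {
  float y = x * static_cast<float>(M_SQRT1_2);  // step 1: scale by 1/sqrt(2)
  y = std::erf(y);                              // step 2: erf, done in bulk by MlasComputeErf
  return 0.5f * x * (y + 1.0f);                 // step 3: 0.5 * x * (erf(x/sqrt(2)) + 1)
}

int main() {
  // gelu(0) == 0; gelu(1) is roughly 0.8413.
  for (float x : {-2.0f, 0.0f, 1.0f}) {
    std::printf("gelu(%g) = %g\n", x, GeluScalar(x));
  }
  return 0;
}

The 4096-element chunking (length_per_task) is taken from FastGelu per the inline comment, presumably so each MlasComputeErf call works on a block that is large enough to vectorize well yet small enough to balance across batches.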