perf(acc_op): add if condition for the element number small situations (metaopt#105)

Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
JieRen98 and XuehaiPan authored Nov 6, 2022
1 parent 824934b, commit 23253b1
Showing 2 changed files with 24 additions and 7 deletions.
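
This commit adds a `MIN_NUMEL_USE_OMP` threshold (1000 elements) to `adam_op_impl_cpu.cpp` and updates all seven OpenMP pragmas so that each loop is parallelized only `if (n > MIN_NUMEL_USE_OMP)`; smaller workloads stay on the calling thread instead of paying OpenMP's parallel-region overhead. When the loop does go parallel, the thread count is capped at `std::min(n / MIN_NUMEL_USE_OMP, static_cast<size_t>(omp_get_num_procs()))`, i.e. roughly one thread per thousand elements, up to the processor count. For example, on an 8-core machine a kernel with n = 500 runs serially, n = 4000 uses 4 threads, and n = 100000 uses all 8. A standalone sketch of the pattern follows the diffs below.
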
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

+- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
- Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101).
30 changes: 23 additions & 7 deletions src/adam_op/adam_op_impl_cpu.cpp
@@ -27,6 +27,8 @@ using std::size_t;

namespace adam_op {

+constexpr size_t MIN_NUMEL_USE_OMP = 1000;
+
template <typename scalar_t, typename other_t>
void adamForwardInplaceCPUKernel(const other_t b1,
const other_t inv_one_minus_pow_b1,
@@ -38,7 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1,
scalar_t *__restrict__ updates_ptr,
scalar_t *__restrict__ mu_ptr,
scalar_t *__restrict__ nu_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads( \
+    std::min(n / MIN_NUMEL_USE_OMP, \
+             static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t mu = mu_ptr[tid];
@@ -90,7 +94,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
const other_t b1,
const size_t n,
scalar_t *__restrict__ mu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads( \
+    std::min(n / MIN_NUMEL_USE_OMP, \
+             static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t mu = mu_ptr[tid];
@@ -122,7 +128,9 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
const other_t b2,
const size_t n,
scalar_t *__restrict__ nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads( \
+    std::min(n / MIN_NUMEL_USE_OMP, \
+             static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t nu = nu_ptr[tid];
@@ -158,7 +166,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
const other_t eps_root,
const size_t n,
scalar_t *__restrict__ updates_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads( \
+    std::min(n / MIN_NUMEL_USE_OMP, \
+             static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t new_mu = new_mu_ptr[tid];
const scalar_t new_nu = new_nu_ptr[tid];
@@ -205,7 +215,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
const size_t n,
scalar_t *__restrict__ dupdates_out_ptr,
scalar_t *__restrict__ dmu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads( \
+    std::min(n / MIN_NUMEL_USE_OMP, \
+             static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dmu = dmu_ptr[tid];

@@ -240,7 +252,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
const size_t n,
scalar_t *__restrict__ dupdates_out_ptr,
scalar_t *__restrict__ dnu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads( \
+    std::min(n / MIN_NUMEL_USE_OMP, \
+             static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dnu = dnu_ptr[tid];
const scalar_t updates = updates_ptr[tid];
@@ -279,7 +293,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
const size_t n,
scalar_t *__restrict__ dnew_mu_out_ptr,
scalar_t *__restrict__ dnew_nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads( \
+    std::min(n / MIN_NUMEL_USE_OMP, \
+             static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dupdates = dupdates_ptr[tid];
const scalar_t updates = updates_ptr[tid];
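
For reference, below is a minimal standalone sketch of the gating pattern introduced above, assuming an OpenMP-enabled compiler (e.g. `g++ -fopenmp sketch.cpp`). The `scale` kernel, its data, and `main` are illustrative placeholders, not TorchOpt code; only the pragma shape, the `MIN_NUMEL_USE_OMP` threshold, and `omp_get_num_procs()` come from the diff above.

```cpp
#include <omp.h>

#include <algorithm>
#include <cstddef>
#include <vector>

// Same threshold as the commit: only parallelize loops with > 1000 elements.
constexpr std::size_t MIN_NUMEL_USE_OMP = 1000;

// Placeholder kernel: scale every element by `alpha`, structured like the
// CPU kernels in adam_op_impl_cpu.cpp.
void scale(std::vector<double> &data, const double alpha) {
  const std::size_t n = data.size();
  // The `if` clause keeps small loops on the calling thread; when parallel,
  // use roughly one thread per MIN_NUMEL_USE_OMP elements, capped at the
  // processor count reported by OpenMP.
#pragma omp parallel for num_threads(                                \
    std::min(n / MIN_NUMEL_USE_OMP,                                  \
             static_cast<std::size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
  for (std::size_t tid = 0; tid < n; ++tid) {
    data[tid] *= alpha;
  }
}

int main() {
  std::vector<double> small(100, 2.0);     // below threshold: runs serially
  std::vector<double> large(100000, 2.0);  // above threshold: runs in parallel
  scale(small, 0.5);
  scale(large, 0.5);
  return 0;
}
```

When n <= MIN_NUMEL_USE_OMP the `if` clause evaluates to false and OpenMP runs the loop on a single thread, so the seven kernels above no longer pay fork/join overhead for small tensors; once n reaches MIN_NUMEL_USE_OMP * omp_get_num_procs() elements, the capped thread count matches the old `num_threads(omp_get_num_procs())` behaviour.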