Skip to content

Commit 09a5e54

Browse files
authored
torchao::parallel_for backends
Differential Revision: D60867909 Pull Request resolved: #774
1 parent cfabc13 commit 09a5e54

9 files changed

+266
-138
lines changed

torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params(
1818
int n,
1919
int target_panels_per_thread) {
2020
TORCHAO_CHECK(n >= 1, "n must be >= 1");
21-
TORCHAO_CHECK(target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1");
21+
TORCHAO_CHECK(
22+
target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1");
2223

2324
PackWeightDataTilingParams tiling_params;
2425
int nr = ukernel_config.nr;
@@ -57,6 +58,10 @@ void pack_weight_data_operator(
5758
int num_nc_panels = (n + nc - 1) / nc;
5859

5960
torchao::parallel_for(0, num_nc_panels, 1, [&](int64_t begin, int64_t end) {
61+
// TODO(T200106949): decide how to handle at::parallel_for not respecting
62+
// user-supplied grain_size
63+
assert(end == begin + 1);
64+
6065
int nc_tile_idx = begin;
6166
int n_idx = nc_tile_idx * nc;
6267
int nc_tile_size = std::min(nc, n - n_idx);
@@ -85,7 +90,8 @@ LinearTilingParams get_default_linear_tiling_params(
8590
int target_tiles_per_thread) {
8691
TORCHAO_CHECK(m >= 1, "m must be >= 1");
8792
TORCHAO_CHECK(n >= 1, "n must be >= 1");
88-
TORCHAO_CHECK(target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1");
93+
TORCHAO_CHECK(
94+
target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1");
8995

9096
LinearTilingParams tiling_params;
9197
auto num_threads = torchao::get_num_threads();
@@ -159,6 +165,7 @@ void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
159165
int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr);
160166
int num_mc_panels = (m + mc - 1) / mc;
161167
int num_nc_panels = (n + nc - 1) / nc;
168+
int weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
162169

163170
for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) {
164171
int m_idx = mc_tile_idx * mc;
@@ -172,13 +179,16 @@ void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
172179
activations + activations_offset);
173180

174181
torchao::parallel_for(0, num_nc_panels, 1, [&](int64_t begin, int64_t end) {
182+
// TODO(T200106949): decide how to handle at::parallel_for not respecting
183+
// user-supplied grain_size
184+
assert(end == begin + 1);
185+
175186
int nc_tile_idx = begin;
176187
int n_idx = nc_tile_idx * nc;
177188
int nc_tile_size = std::min(nc, n - n_idx);
178189

179190
int output_offset = m_idx * n + n_idx;
180-
int weight_data_offset =
181-
(n_idx / nr) * ukernel_config.weight_data_size_fn(nr, k, group_size);
191+
int weight_data_offset = (n_idx / nr) * weight_data_size;
182192
int bias_offset = m_idx;
183193

184194
ukernel_config.kernel_fn(
@@ -220,13 +230,16 @@ void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
220230
int num_mc_panels = (m + mc - 1) / mc;
221231
int num_nc_panels = (n + nc - 1) / nc;
222232

233+
int weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
234+
int activation_data_size =
235+
ukernel_config.activation_data_size_fn(mr, k, group_size);
236+
223237
torchao::parallel_for(0, num_mc_panels, 1, [&](int64_t begin, int64_t end) {
224238
int mc_tile_idx = begin;
225239
int m_idx = mc_tile_idx * mc;
226240
int mc_tile_size = std::min(mc, m - m_idx);
227241
int activations_offset = m_idx * k;
228-
int activation_data_offset = (m_idx / mr) *
229-
ukernel_config.activation_data_size_fn(mr, k, group_size);
242+
int activation_data_offset = (m_idx / mr) * activation_data_size;
230243

231244
ukernel_config.prepare_activation_data_fn(
232245
activation_data_buffer + activation_data_offset,
@@ -246,11 +259,9 @@ void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
246259
int n_idx = nc_tile_idx * nc;
247260
int nc_tile_size = std::min(nc, n - n_idx);
248261

249-
int activation_data_offset = (m_idx / mr) *
250-
ukernel_config.activation_data_size_fn(mr, k, group_size);
262+
int activation_data_offset = (m_idx / mr) * activation_data_size;
251263
int output_offset = m_idx * n + n_idx;
252-
int weight_data_offset = (n_idx / nr) *
253-
ukernel_config.weight_data_size_fn(nr, k, group_size);
264+
int weight_data_offset = (n_idx / nr) * weight_data_size;
254265
int bias_offset = m_idx;
255266

256267
ukernel_config.kernel_fn(
@@ -283,7 +294,6 @@ void linear_operator(
283294
int group_size,
284295
const void* weight_data,
285296
const float* activations,
286-
// const void* activation_data,
287297
// Not applied if nullptr
288298
const float* bias,
289299
// Ignored if has_clamp = false
@@ -371,12 +381,12 @@ UKernelConfig get_ukernel_config() {
371381
config.nr = 8;
372382
config.activation_data_size_fn =
373383
&ukernel::activation_data_size<has_weight_zeros>;
374-
config.activation_data_alignment = alignof(char*);
384+
config.activation_data_alignment = 16; // size of neon register
375385
config.prepare_activation_data_fn =
376386
&ukernel::prepare_activation_data<has_weight_zeros>;
377387
config.weight_data_size_fn =
378388
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
379-
config.weight_data_alignment = alignof(char*);
389+
config.weight_data_alignment = 16; // size of neon register
380390
config.prepare_weight_data_fn =
381391
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
382392
config.kernel_fn =

torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,14 @@ namespace torchao::operators::cpu::linear::
1111
channelwise_8bit_activation_groupwise_lowbit_weight {
1212

1313
struct UKernelConfig {
14-
using activation_data_size_fn_type =
15-
int (*)(int m, int k, int group_size);
14+
using activation_data_size_fn_type = int (*)(int m, int k, int group_size);
1615
using prepare_activation_data_fn_type = void (*)(
1716
void* activation_data,
1817
int m,
1918
int k,
2019
int group_size,
2120
const float* activations);
22-
using weight_data_size_fn_type =
23-
int (*)(int n, int k, int group_size);
21+
using weight_data_size_fn_type = int (*)(int n, int k, int group_size);
2422
using prepare_weight_data_fn_type = void (*)(
2523
void* weight_data,
2624
int n,
@@ -43,10 +41,18 @@ struct UKernelConfig {
4341
float clamp_max);
4442

4543
activation_data_size_fn_type activation_data_size_fn{nullptr};
44+
// activation_data_alignment is only a preferred alignment for
45+
// performance reasons. Integration surfaces are not required to
46+
// respect this alignment, and the ukernel must behave correctly no matter
47+
// how the prepared_activation_data byte-array is aligned
4648
int activation_data_alignment{0};
4749
prepare_activation_data_fn_type prepare_activation_data_fn{nullptr};
4850

4951
weight_data_size_fn_type weight_data_size_fn{nullptr};
52+
// weight_data_alignment is only a preferred alignment for
53+
// performance reasons. Integration surfaces are not required to
54+
// respect this alignment, and the ukernel must behave correctly no matter
55+
// how the prepared_weight_data byte-array is aligned
5056
int weight_data_alignment{0};
5157
prepare_weight_data_fn_type prepare_weight_data_fn{nullptr};
5258

torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator {
1919
torchao::aligned_byte_ptr packed_weight_data_{
2020
nullptr,
2121
nullptr};
22+
int packed_weight_data_size_{0};
23+
int packed_weight_data_alignment_{0};
2224

2325
torchao::aligned_byte_ptr activation_data_buffer_{
2426
nullptr,
@@ -112,6 +114,9 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator {
112114
get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_);
113115
auto packed_weight_data_alignment =
114116
get_packed_weight_data_alignment(ukernel_config_);
117+
118+
packed_weight_data_size_ = packed_weight_data_size;
119+
packed_weight_data_alignment_ = packed_weight_data_alignment;
115120
packed_weight_data_ = torchao::make_aligned_byte_ptr(
116121
packed_weight_data_alignment, packed_weight_data_size);
117122

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#ifdef TORCHAO_PARALLEL_ATEN
10+
#pragma message("TORCHAO_PARALLEL_ATEN is set. Using ATen parallel backend.")
11+
12+
// TODO(T200106949): reconcile at::parallel_for's grain_size with what is needed
13+
// in torchao::parallel_for
14+
#error "TORCHAO_PARALLEL_ATEN is not implemented yet"
15+
16+
#else
17+
#ifdef TORCHAO_PARALLEL_EXECUTORCH
18+
#pragma message( \
19+
"TORCHAO_PARALLEL_EXECUTORCH is set. Using ExecuTorch parallel backend.")
20+
21+
#error "TORCHAO_PARALLEL_EXECUTORCH is not implemented yet"
22+
23+
#else
24+
#ifdef TORCHAO_PARALLEL_PTHREADPOOL
25+
#pragma message( \
26+
"TORCHAO_PARALLEL_PTHREADPOOL is set. Using pthreadpool parallel backend.")
27+
#include <torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h>
28+
29+
#else
30+
#ifdef TORCHAO_PARALLEL_OMP
31+
#pragma message("TORCHAO_PARALLEL_OMP is set. Using OMP parallel backend.")
32+
#include <torchao/experimental/kernels/cpu/parallel-omp-impl.h>
33+
34+
#else
35+
#if defined TORCHAO_PARALLEL_SINGLE_THREADED
36+
#pragma message( \
37+
"TORCHAO_PARALLEL_SINGLE_THREADED is set. Using single-threaded parallel backend.")
38+
#include <torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h>
39+
40+
#else
41+
#if defined TORCHAO_PARALLEL_TEST_DUMMY
42+
#pragma message( \
43+
"TORCHAO_PARALLEL_TEST_DUMMY is set. Using test dummy parallel backend.")
44+
#include <torchao/experimental/kernels/cpu/parallel-test_dummy-impl.h>
45+
46+
#else
47+
#error \
48+
"Set parallel backend by defining one of the following: \
49+
TORCHAO_PARALLEL_ATEN, \
50+
TORCHAO_PARALLEL_EXECUTORCH, \
51+
TORCHAO_PARALLEL_PTHREADPOOL, \
52+
TORCHAO_PARALLEL_OMP, \
53+
TORCHAO_PARALLEL_SINGLE_THREADED, \
54+
TORCHAO_PARALLEL_TEST_DUMMY"
55+
56+
#endif // TORCHAO_PARALLEL_TEST_DUMMY
57+
#endif // TORCHAO_PARALLEL_SINGLE_THREADED
58+
#endif // TORCHAO_PARALLEL_OMP
59+
#endif // TORCHAO_PARALLEL_PTHREADPOOL
60+
#endif // TORCHAO_PARALLEL_EXECUTORCH
61+
#endif // TORCHAO_PARALLEL_ATEN
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
#include <omp.h>
9+
10+
template <typename F>
11+
void torchao::parallel_for(
12+
const int64_t begin,
13+
const int64_t end,
14+
const int64_t grain_size,
15+
const F& f) {
16+
#pragma omp parallel
17+
{
18+
#pragma omp for
19+
for (int i = begin; i < end; i += grain_size) {
20+
f(i, i + grain_size);
21+
}
22+
}
23+
}
24+
25+
void torchao::set_num_threads(int num_threads) {
26+
omp_set_num_threads(num_threads);
27+
}
28+
int torchao::get_num_threads() {
29+
// omp_get_num_threads returns the number of threads
30+
// in the current code section, which will be 1 in the routines
31+
// that select tiling params
32+
return omp_get_max_threads();
33+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <pthreadpool.h>

#include <algorithm>
#include <cstdint>
#include <stdexcept>

namespace torchao::parallel::internal {

// RAII owner of a pthreadpool handle.
class Threadpool {
 private:
  pthreadpool_t pthreadpool_{nullptr};

 public:
  // num_threads == 0 lets pthreadpool pick (one thread per core).
  explicit Threadpool(size_t num_threads = 0) {
    pthreadpool_ = pthreadpool_create(num_threads);
    if (pthreadpool_ == nullptr) {
      throw std::runtime_error("Failed to create pthreadpool.");
    }
  }
  // Non-copyable / non-movable: the raw handle must have a single owner,
  // otherwise the pool would be destroyed twice.
  Threadpool(const Threadpool&) = delete;
  Threadpool& operator=(const Threadpool&) = delete;
  ~Threadpool() {
    pthreadpool_destroy(pthreadpool_);
    pthreadpool_ = nullptr;
  }
  pthreadpool_t get() {
    return pthreadpool_;
  }
  size_t get_num_threads() {
    if (pthreadpool_ == nullptr) {
      return 0;
    }
    return pthreadpool_get_threads_count(pthreadpool_);
  }
  // pthreadpool has no resize API, so changing the thread count requires
  // destroying and re-creating the pool.
  void set_num_threads(size_t num_threads) {
    if (num_threads == get_num_threads()) {
      return;
    }
    pthreadpool_destroy(pthreadpool_);
    pthreadpool_ = pthreadpool_create(num_threads);
    if (pthreadpool_ == nullptr) {
      // Previously a failed re-create silently left a null pool behind;
      // fail loudly instead.
      throw std::runtime_error("Failed to re-create pthreadpool.");
    }
  }
};

// Everything one task invocation needs: the callable plus the range bounds
// used to locate and clamp each grain.
template <typename F>
struct Context {
  const F& f;
  int64_t begin;
  int64_t end;
  int64_t grain_size;
  Context(const F& f, int64_t begin, int64_t end, int64_t grain_size)
      : f{f}, begin{begin}, end{end}, grain_size{grain_size} {}
};

template <typename F>
static void task(Context<F>* context, size_t grain_idx) {
  // Offset from `begin` (the old code implicitly assumed begin == 0) and
  // clamp the last chunk so f never sees an interval past `end`.
  int64_t i =
      context->begin + static_cast<int64_t>(grain_idx) * context->grain_size;
  context->f(i, std::min(i + context->grain_size, context->end));
}

// One process-wide pool shared by every translation unit. `inline` (C++17)
// guarantees a single instance; the previous `static` created a separate
// pool in each TU that included this header.
inline Threadpool threadpool;

} // namespace torchao::parallel::internal

namespace torchao {

inline int get_num_threads() {
  return parallel::internal::threadpool.get_num_threads();
}

inline void set_num_threads(int num_threads) {
  parallel::internal::threadpool.set_num_threads(num_threads);
}

// pthreadpool backend for torchao::parallel_for.
//
// Splits [begin, end) into ceil((end - begin) / grain_size) grains and runs
// f(chunk_begin, chunk_end) for each on the shared pool. The previous
// version computed the grain count as end / grain_size, which ignored
// `begin` and rounded DOWN — silently dropping the tail elements whenever
// (end - begin) was not a multiple of grain_size.
template <typename F>
void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const F& f) {
  if (begin >= end) {
    return; // empty range: nothing to schedule
  }
  size_t num_grains =
      static_cast<size_t>((end - begin + grain_size - 1) / grain_size);
  auto context = parallel::internal::Context<F>(f, begin, end, grain_size);
  pthreadpool_parallelize_1d(
      parallel::internal::threadpool.get(),
      (pthreadpool_task_1d_t)parallel::internal::task<F>,
      static_cast<void*>(&context),
      num_grains,
      0 /* flags */);
}

} // namespace torchao
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <algorithm>
#include <cstdint>

namespace torchao {

// Single-threaded backend for torchao::parallel_for: runs every chunk
// sequentially on the calling thread.
//
// Fixes over the previous version:
//  - the loop index is int64_t, so large ranges are no longer truncated to
//    int;
//  - the final chunk is clamped to `end`, so f is never handed an interval
//    extending past the requested range when (end - begin) is not a
//    multiple of grain_size.
template <typename F>
void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const F& f) {
  for (int64_t i = begin; i < end; i += grain_size) {
    f(i, std::min(i + grain_size, end));
  }
}

// No-op: this backend always runs on the calling thread.
// `inline` because this non-template function lives in a header (ODR).
inline void set_num_threads(int num_threads) {
  (void)num_threads;
}

// This backend always reports exactly one thread.
inline int get_num_threads() {
  return 1;
}

} // namespace torchao

0 commit comments

Comments
 (0)