Commit 72292a7

Update on "Reduce allocation overhead in quantized sdpa"
For small models dequantizing portions of v cache causes extra alloc overhead. Probably a better way to handle this is to dequantize entire v cache outside the model There isnt significant perf advantage from this yet but subsequent diffs will use caching allocator where this refactor help. Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/) [ghstack-poisoned]
2 parents 99902b8 + bf5abbf commit 72292a7
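As a rough illustration of the allocation overhead described in the commit message, here is a minimal, hypothetical C++ sketch (not the ExecuTorch code; `dequantize_block`, `kScale`, and the loop structure are made up for illustration) contrasting a scratch buffer allocated inside the per-block loop with one sized once up front and reused:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for dequantizing one block of the v cache into `out`.
static void dequantize_block(const int8_t* src, float* out, int64_t n) {
  constexpr float kScale = 0.05f;  // illustrative per-tensor scale
  for (int64_t i = 0; i < n; ++i) {
    out[i] = static_cast<float>(src[i]) * kScale;
  }
}

// Allocates scratch on every KV block: one heap alloc/free per iteration.
void attend_alloc_per_block(const int8_t* v, int64_t num_blocks, int64_t block_elems) {
  for (int64_t b = 0; b < num_blocks; ++b) {
    std::vector<float> scratch(block_elems);
    dequantize_block(v + b * block_elems, scratch.data(), block_elems);
    // ... attention matmul against scratch ...
  }
}

// Allocates scratch once for the whole call and reuses it per block.
void attend_prealloc(const int8_t* v, int64_t num_blocks, int64_t block_elems) {
  std::vector<float> scratch(block_elems);
  for (int64_t b = 0; b < num_blocks; ++b) {
    dequantize_block(v + b * block_elems, scratch.data(), block_elems);
    // ... attention matmul against scratch ...
  }
}
```

For small models the matmul work per block is comparatively cheap, so the per-block alloc/free in the first variant becomes a larger fraction of runtime; this is also why the caching allocator mentioned in the commit message benefits from the refactor.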

3 files changed: +14 −4 lines
extension/llm/custom_ops/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -60,5 +60,6 @@ runtime.python_test(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/extension/pybindings:portable_lib",
     ],
 )
```

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 7 additions & 4 deletions
```diff
@@ -779,11 +779,13 @@ void cpu_flash_attention(
   // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
   // by padding with right number of per thread elements
   constexpr int64_t kAlignment = 32;
-  size_per_thread_qdq_vec = (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
-  int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * query.element_size();
+  size_per_thread_qdq_vec =
+      (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
+  int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
   int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
   std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
-  accum_t* scratch_for_quant_dequant = reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
+  accum_t* scratch_for_quant_dequant =
+      reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());

   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
@@ -808,7 +810,8 @@ void cpu_flash_attention(
     scalar_t* qk_reduced_data = is_reduced_type
         ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
         : nullptr;
-    accum_t* buf_qdq_ptr = scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec;
+    accum_t* buf_qdq_ptr =
+        scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec;

     for (int64_t z = begin; z < end; z++) {
       int64_t m = k * qSplitSize;
```

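The hunk above rounds the per-thread scratch size up to a multiple of `kAlignment` elements, sizes the buffer in `accum_t` (hence `sizeof(accum_t)` rather than `query.element_size()`), and later hands each OpenMP thread its own slice via `ompIdx * size_per_thread_qdq_vec`. Below is a minimal, standalone C++ sketch of that pattern (illustrative values only, not the ExecuTorch code). Note that the sketch uses the conventional round-up mask `~(kAlignment - 1)` (equivalently `& -kAlignment`); the hunk writes `-(kAlignment - 1)`, which is a different bit pattern and may be worth double-checking:

```cpp
#include <cstdint>
#include <vector>

int main() {
  using accum_t = float;
  constexpr int64_t kAlignment = 32;  // alignment unit, in elements

  int64_t size_per_thread = 1000;  // illustrative, e.g. qSplitSize * headSize
  // Round up to a multiple of kAlignment (power-of-two round-up idiom).
  size_per_thread = (size_per_thread + kAlignment - 1) & ~(kAlignment - 1);

  const int64_t num_threads = 4;
  // One allocation for all threads, sized in accum_t since the scratch holds
  // dequantized accumulator values.
  std::vector<char> scratch(size_per_thread * sizeof(accum_t) * num_threads);
  accum_t* base = reinterpret_cast<accum_t*>(scratch.data());

  for (int64_t tid = 0; tid < num_threads; ++tid) {
    // Disjoint slice per thread; each thread dequantizes its v-cache block here.
    accum_t* per_thread = base + tid * size_per_thread;
    (void)per_thread;
  }
  return 0;
}
```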
extension/llm/custom_ops/test_quantized_sdpa.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
 import torch.nn.functional as F

 from executorch.extension.llm.custom_ops import custom_ops  # noqa
+from executorch.extension.pybindings.portable_lib import _unsafe_reset_threadpool


 def is_fbcode():
@@ -40,6 +41,11 @@ def setUp(self):
         self.q_shape = None
         self.kv_shape = None
         self.is_seq_at_dim_2 = True
+        # For some reason 4 threads doesnt work
+        # This setting is needed to make this test not flaky due to OMP
+        # error of "OMP: Error #131: Thread identifier invalid"
+        # Not clear why that happens but having smaller threadpool resolves it
+        _unsafe_reset_threadpool(3)

     def _scale_tensor(self, tensor, min_value, max_value, scale=True):
         normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
```
