Skip to content

[Executorch][SDPA] Remove slice creation #9911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 13 additions & 67 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ Tensor& flash_attention_kernel_out(
Format [n_layers, batch size, max_seq_len, num heads, head dim]
....
@param[in] start_pos: sequence position
@param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
*/
Tensor& custom_sdpa_out(
RuntimeContext& ctx,
Expand Down Expand Up @@ -306,63 +305,7 @@ Tensor& custom_sdpa_out(
const int64_t seq_len = q.size(1);
auto q_seq_len = q.size(1);

// Refactor the following into create_view util perhaps using
// TensorPtr
std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
sliced_key_dim_order{0, 1, 2, 3};
std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
sliced_key_sizes;
sliced_key_sizes[0] = k.size(0);
sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
sliced_key_sizes[2] = k.size(2);
sliced_key_sizes[3] = k.size(3);
std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
sliced_key_strides;
dim_order_to_stride_nocheck(
sliced_key_sizes.data(),
sliced_key_dim_order.data(),
sdpa::impl::kKVDim,
sliced_key_strides.data());
// since the cache is sliced, the batch stride needs to stay the same.
sliced_key_strides[0] = k.strides()[0];
void* key_cache_data = k.mutable_data_ptr();
TensorImpl k_impl = TensorImpl(
k.scalar_type(),
sdpa::impl::kKVDim,
sliced_key_sizes.data(),
key_cache_data,
sliced_key_dim_order.data(),
sliced_key_strides.data(),
TensorShapeDynamism::STATIC);
Tensor sliced_key_cache(&k_impl);

std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
sliced_value_dim_order{0, 1, 2, 3};
std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
sliced_value_sizes;
sliced_value_sizes[0] = v.size(0);
sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
sliced_value_sizes[2] = v.size(2);
sliced_value_sizes[3] = v.size(3);
std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
sliced_value_strides;
dim_order_to_stride_nocheck(
sliced_value_sizes.data(),
sliced_value_dim_order.data(),
sdpa::impl::kKVDim,
sliced_value_strides.data());
// since the cache is sliced, the batch stride needs to stay the same.
sliced_value_strides[0] = v.strides()[0];
void* value_cache_data = v.mutable_data_ptr();
TensorImpl value_impl = TensorImpl(
v.scalar_type(),
sdpa::impl::kKVDim,
sliced_value_sizes.data(),
value_cache_data,
sliced_value_dim_order.data(),
sliced_value_strides.data(),
TensorShapeDynamism::STATIC);
Tensor sliced_value_cache(&value_impl);
const int64_t num_keys_for_causal_attention = start_pos + seq_len;

ET_KERNEL_CHECK(
ctx,
Expand All @@ -380,38 +323,41 @@ Tensor& custom_sdpa_out(
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
output,
q,
sliced_key_cache,
sliced_value_cache,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
true, /* is_seq_at_dim_1 */
start_pos);
start_pos,
num_keys_for_causal_attention);
} else if (q_seq_len >= 192) {
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
output,
q,
sliced_key_cache,
sliced_value_cache,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
true, /* is_seq_at_dim_1 */
start_pos);
start_pos,
num_keys_for_causal_attention);
} else {
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
output,
q,
sliced_key_cache,
sliced_value_cache,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
true, /* is_seq_at_dim_1 */
start_pos);
start_pos,
num_keys_for_causal_attention);
}
});
return output;
Expand Down
10 changes: 9 additions & 1 deletion extension/llm/custom_ops/op_sdpa_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ void cpu_flash_attention(
const optional<Tensor>& attn_mask,
const optional<double>& scale,
bool is_seq_at_dim_1 = false,
const int64_t start_pos = 0) {
const int64_t start_pos = 0,
const int64_t num_keys_for_causal_attention = -1) {
(void)dropout_p;
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
Expand Down Expand Up @@ -258,6 +259,13 @@ void cpu_flash_attention(
kvSize = value.size(1);
}

if (num_keys_for_causal_attention > 0) {
ET_CHECK_MSG(
num_keys_for_causal_attention <= kvSize,
"num_keys_for_causal_attention must be <= kvSize");
kvSize = num_keys_for_causal_attention;
}

ET_CHECK_MSG(
num_heads_kv <= num_head,
"FlashAttention does not support num kv heads > num query heads.Got num query heads=%" PRId64
Expand Down
Loading