
Commit e99ada1

ssjia committed
[ET-VK][ez] Don't copy zeros for cache tensors
Currently, cache tensors for SDPA are prepacked even though the mutable buffer data just contains zeros. For fused SDPA, this step can be skipped.

Differential Revision: [D86226137](https://our.internmc.facebook.com/intern/diff/D86226137/)

[ghstack-poisoned]
1 parent 77ed617 commit e99ada1

File tree

2 files changed: +14, -6 lines changed

backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl
backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl

Lines changed: 12 additions & 4 deletions
```diff
@@ -14,6 +14,10 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
+#define DEBUG_MODE
+
+#extension GL_EXT_debug_printf : enable
+
 #include "common.glslh"
 
 ${layout_declare_tensor(B, "w", "t_cache", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
@@ -80,13 +84,17 @@ void main() {
   const int S = projected_sizes.z;
   const int H = projected_sizes.y;
 
-  if (d4 >= D4 || s >= S || h >= H) {
+  const int c = s + input_pos; // idx along max_context_len dim
+  const int C = cache_sizes.z;
+
+  if (d4 >= D4 || c >= C || h >= H) {
     return;
   }
 
-  const int c = s + input_pos; // idx along max_context_len dim
-  const int C = cache_sizes.y;
+  IN_VEC4_T in_texel = IN_VEC4_T(0.0);
+  if (s < S) {
+    in_texel = read_projected_d4(d4, h, s, D4, H, S);
+  }
 
-  IN_VEC4_T in_texel = read_projected_d4(d4, h, s, D4, H, S);
   write_cache_d4(in_texel, d4, c, h, D4, C, H);
 }
```
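The substance of this change is in the second hunk: the early-return guard now bounds-checks the write position `c` against the cache's context extent `C` (now read from `cache_sizes.z` rather than `cache_sizes.y`) instead of checking `s` against the input length `S`, and invocations whose `s` falls past the projected input write a zero texel. The update shader thus zero-fills the cache region itself, which is what lets the graph stop prepacking a zero-filled buffer. (The `DEBUG_MODE` / `GL_EXT_debug_printf` lines in the first hunk just enable debug printing.) Below is a minimal CPU mirror of the new logic; the flat row-major indexing and the `dispatch_S` parameter are assumptions for illustration, not the backend's actual texel layout or dispatch size.

```cpp
#include <cstdio>
#include <vector>

// CPU sketch of the updated sdpa_kv_cache_update behavior. Each "invocation"
// (s, h, d) writes one cache element at context position c = s + input_pos.
// Positions past the projected input (s >= S) now write zeros instead of
// returning early, so the cache never needs a prepacked zero buffer.
void update_kv_cache(
    const std::vector<float>& projected, // [S][H][D]: new K or V slice
    std::vector<float>& cache,           // [C][H][D]: mutable cache tensor
    int input_pos, int S, int C, int H, int D,
    int dispatch_S) {                    // extent of the dispatch along s
  for (int s = 0; s < dispatch_S; ++s) {
    const int c = s + input_pos;
    if (c >= C) continue; // guard now checks the cache extent, not S
    for (int h = 0; h < H; ++h) {
      for (int d = 0; d < D; ++d) {
        // Mirrors: in_texel = IN_VEC4_T(0.0); if (s < S) in_texel = read(...)
        const float v = (s < S) ? projected[(s * H + h) * D + d] : 0.0f;
        cache[(c * H + h) * D + d] = v;
      }
    }
  }
}

int main() {
  const int S = 2, C = 4, H = 1, D = 2;
  const std::vector<float> projected = {1, 2, 3, 4};
  std::vector<float> cache(C * H * D, -1.0f); // stand-in for unseeded memory
  // Dispatch over the remaining cache range so the tail gets zero-filled.
  update_kv_cache(projected, cache, /*input_pos=*/0, S, C, H, D, /*dispatch_S=*/C);
  for (const float v : cache) std::printf("%g ", v); // 1 2 3 4 0 0 0 0
  std::printf("\n");
}
```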

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -575,9 +575,9 @@ void compute_attn_weight_with_kv_cache_impl(
 
   utils::StorageType cache_storage = graph.storage_type_of(q_projected);
   const ValueRef k_cache =
-      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
-      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked);
 
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
```
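On the graph side, `prepack_standard` records a staging transfer that copies the source buffer (here `k_cache_data` / `v_cache_data`, which contain only zeros) into the device tensor. `graph.add_tensor_like` presumably creates a tensor with matching sizes, dtype, storage type, and packing without issuing that copy; the `update_cache_impl` dispatches that follow, together with the zero-writing shader above, are what populate the cache. A toy sketch of the distinction, using hypothetical stand-ins rather than the real `ComputeGraph` API:

```cpp
#include <cstring>
#include <vector>

// Hypothetical stand-in for a device tensor: allocation only, with no
// contents promised until a later dispatch writes into it.
struct Tensor {
  std::vector<float> storage;
  explicit Tensor(size_t numel) : storage(numel) {}
};

// Before: prepacking copies the all-zero source buffer into the new tensor.
Tensor prepack_cache(const std::vector<float>& zeros) {
  Tensor t(zeros.size());
  std::memcpy(t.storage.data(), zeros.data(), zeros.size() * sizeof(float));
  return t; // one full staging copy of nothing but zeros
}

// After: only the shape/size metadata is reused; no copy is recorded. The
// cache-update shader fills the tensor at inference time.
Tensor add_cache_like(const std::vector<float>& zeros) {
  return Tensor(zeros.size());
}

int main() {
  std::vector<float> zeros(4096, 0.0f); // the mutable-buffer data: all zeros
  Tensor before = prepack_cache(zeros); // pays for a pointless copy
  Tensor after = add_cache_like(zeros); // allocation only
  (void)before;
  (void)after;
}
```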
