Skip to content

Commit c66ea27

Browse files
authored
[ET-VK][ez] SDPA don't branch based on whether bounds check needed (#15600)
Title says it all! Why? The branching path causes incorrect output on the Samsung S24. The exact underlying issue is unclear, but the problem is not reproducible on other GPUs and appears to be specific to the Adreno 750 architecture. To be safe, always use bounds checking. Differential Revision: [D86226136](https://our.internmc.facebook.com/intern/diff/D86226136/)
1 parent a13dc4c commit c66ea27

File tree

2 files changed

+43
-98
lines changed

2 files changed

+43
-98
lines changed

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl

Lines changed: 21 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -119,55 +119,27 @@ void main() {
119119
}
120120
// Otherwise, need to actually compute output tile
121121
else {
122-
const bool dont_check_bounds = (S - s) >= TILE_M &&
123-
(context_len - c) >= TILE_N;
124-
125-
if (dont_check_bounds) {
126-
for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
127-
load_q_projected_tile_no_checks(
128-
q_tile,
129-
d4,
130-
s,
131-
q_h,
132-
D4,
133-
Q_H,
134-
S);
135-
136-
load_k_cache_tile_no_checks(
137-
w_tile,
138-
d4,
139-
c,
140-
kv_h,
141-
D4,
142-
context_len,
143-
C,
144-
KV_H);
145-
146-
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
147-
}
148-
} else {
149-
for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
150-
load_q_projected_tile_with_checks(
151-
q_tile,
152-
d4,
153-
s,
154-
q_h,
155-
D4,
156-
Q_H,
157-
S);
158-
159-
load_k_cache_tile_with_checks(
160-
w_tile,
161-
d4,
162-
c,
163-
kv_h,
164-
D4,
165-
context_len,
166-
C,
167-
KV_H);
168-
169-
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
170-
}
122+
for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
123+
load_q_projected_tile_with_checks(
124+
q_tile,
125+
d4,
126+
s,
127+
q_h,
128+
D4,
129+
Q_H,
130+
S);
131+
132+
load_k_cache_tile_with_checks(
133+
w_tile,
134+
d4,
135+
c,
136+
kv_h,
137+
D4,
138+
context_len,
139+
C,
140+
KV_H);
141+
142+
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
171143
}
172144
}
173145

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl

Lines changed: 22 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -130,55 +130,28 @@ void main() {
130130
}
131131
// Otherwise, need to actually compute output tile
132132
else {
133-
const bool dont_check_bounds = (S - s) >= TILE_M &&
134-
(context_len - c) >= TILE_N;
135-
136-
if (dont_check_bounds) {
137-
for (int d4 = 0; d4 < D4; d4++) {
138-
load_q_projected_tile_no_checks(
139-
q_tile,
140-
d4,
141-
s,
142-
q_h,
143-
D4,
144-
Q_H,
145-
S);
146-
147-
load_k_cache_tile_no_checks(
148-
w_tile,
149-
d4,
150-
c,
151-
kv_h,
152-
D4,
153-
context_len,
154-
C,
155-
KV_H);
156-
157-
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
158-
}
159-
} else {
160-
for (int d4 = 0; d4 < D4; d4++) {
161-
load_q_projected_tile_with_checks(
162-
q_tile,
163-
d4,
164-
s,
165-
q_h,
166-
D4,
167-
Q_H,
168-
S);
169-
170-
load_k_cache_tile_with_checks(
171-
w_tile,
172-
d4,
173-
c,
174-
kv_h,
175-
D4,
176-
context_len,
177-
C,
178-
KV_H);
179-
180-
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
181-
}
133+
for (int d4 = 0; d4 < D4; d4++) {
134+
load_q_projected_tile_with_checks(
135+
q_tile,
136+
d4,
137+
s,
138+
q_h,
139+
D4,
140+
Q_H,
141+
S);
142+
143+
load_k_cache_tile_with_checks(
144+
w_tile,
145+
d4,
146+
c,
147+
kv_h,
148+
D4,
149+
context_len,
150+
C,
151+
KV_H);
152+
153+
154+
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
182155
}
183156

184157
// Apply scale and mask

0 commit comments

Comments
 (0)