File tree Expand file tree Collapse file tree 2 files changed +43
-98
lines changed
backends/vulkan/runtime/graph/ops/glsl Expand file tree Collapse file tree 2 files changed +43
-98
lines changed Original file line number Diff line number Diff line change @@ -119,55 +119,27 @@ void main() {
119119 }
120120 // Otherwise, need to actually compute output tile
121121 else {
122- const bool dont_check_bounds = (S - s) >= TILE_M &&
123- (context_len - c) >= TILE_N;
124-
125- if (dont_check_bounds) {
126- for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
127- load_q_projected_tile_no_checks(
128- q_tile,
129- d4,
130- s,
131- q_h,
132- D4,
133- Q_H,
134- S);
135-
136- load_k_cache_tile_no_checks(
137- w_tile,
138- d4,
139- c,
140- kv_h,
141- D4,
142- context_len,
143- C,
144- KV_H);
145-
146- fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
147- }
148- } else {
149- for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
150- load_q_projected_tile_with_checks(
151- q_tile,
152- d4,
153- s,
154- q_h,
155- D4,
156- Q_H,
157- S);
158-
159- load_k_cache_tile_with_checks(
160- w_tile,
161- d4,
162- c,
163- kv_h,
164- D4,
165- context_len,
166- C,
167- KV_H);
168-
169- fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
170- }
122+ for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
123+ load_q_projected_tile_with_checks(
124+ q_tile,
125+ d4,
126+ s,
127+ q_h,
128+ D4,
129+ Q_H,
130+ S);
131+
132+ load_k_cache_tile_with_checks(
133+ w_tile,
134+ d4,
135+ c,
136+ kv_h,
137+ D4,
138+ context_len,
139+ C,
140+ KV_H);
141+
142+ fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
171143 }
172144 }
173145
Original file line number Diff line number Diff line change @@ -130,55 +130,28 @@ void main() {
130130 }
131131 // Otherwise, need to actually compute output tile
132132 else {
133- const bool dont_check_bounds = (S - s) >= TILE_M &&
134- (context_len - c) >= TILE_N;
135-
136- if (dont_check_bounds) {
137- for (int d4 = 0 ; d4 < D4; d4++ ) {
138- load_q_projected_tile_no_checks(
139- q_tile,
140- d4,
141- s,
142- q_h,
143- D4,
144- Q_H,
145- S);
146-
147- load_k_cache_tile_no_checks(
148- w_tile,
149- d4,
150- c,
151- kv_h,
152- D4,
153- context_len,
154- C,
155- KV_H);
156-
157- fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
158- }
159- } else {
160- for (int d4 = 0 ; d4 < D4; d4++ ) {
161- load_q_projected_tile_with_checks(
162- q_tile,
163- d4,
164- s,
165- q_h,
166- D4,
167- Q_H,
168- S);
169-
170- load_k_cache_tile_with_checks(
171- w_tile,
172- d4,
173- c,
174- kv_h,
175- D4,
176- context_len,
177- C,
178- KV_H);
179-
180- fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
181- }
133+ for (int d4 = 0 ; d4 < D4; d4++ ) {
134+ load_q_projected_tile_with_checks(
135+ q_tile,
136+ d4,
137+ s,
138+ q_h,
139+ D4,
140+ Q_H,
141+ S);
142+
143+ load_k_cache_tile_with_checks(
144+ w_tile,
145+ d4,
146+ c,
147+ kv_h,
148+ D4,
149+ context_len,
150+ C,
151+ KV_H);
152+
153+
154+ fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
182155 }
183156
184157 // Apply scale and mask
You can’t perform that action at this time.
0 commit comments