
Commit fdadfd8

Fix flash decoding in GPU. (apple#999)
target_positions used to be time_step, but after PR apple#995 it now represents the actual target positions with shape [batch, step_len]. This commit updates the GPU decoding code to align with that change. CI did not cover GPU unit tests.

TEST=test_extend_step10 of axlearn/common/flash_attention/layer_test.py on GPU
1 parent 9e64388 commit fdadfd8

File tree

1 file changed: 1 addition, 1 deletion


axlearn/common/flash_attention/utils.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
     if mask is None or mask.target_positions is None:
         raise RuntimeError("Cannot retrive MaskFnAttentionBias or target_positions.")
     mask_fn = mask.mask
-    kv_seq_len = mask.target_positions + 1
+    kv_seq_len = mask.target_positions[:, -1] + 1
     logging.info("Using mask_fn=%s for FlashDecoding.", mask_fn)

     bias = explicit_bias.value()
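A minimal sketch of why the indexing changed, using plain numpy rather than the actual axlearn tensor types (the array values here are hypothetical): before apple#995, target_positions was a scalar time_step, so adding 1 gave the KV sequence length directly; afterwards it is a [batch, step_len] array of absolute positions, so the per-batch KV length must come from the last position in each row.

```python
import numpy as np

# Hypothetical example values, not axlearn code: a batch of 2
# sequences, each decoding step_len=2 positions.
target_positions = np.array([[3, 4], [7, 8]])  # shape [batch=2, step_len=2]

# Old expression: with the new shape this adds 1 elementwise and
# keeps a 2-D result, which is not a per-batch sequence length.
old_result = target_positions + 1              # shape [2, 2]

# Fixed expression: take the last decoded position per batch entry,
# then add 1 to get the KV sequence length for each sequence.
kv_seq_len = target_positions[:, -1] + 1       # shape [2] -> [5, 9]

print(kv_seq_len.tolist())  # [5, 9]
```

With the scalar time_step both expressions would have agreed; the `[:, -1]` indexing is what restores that behavior for the batched shape.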
