Commit 30d55b9

Merge pull request #185 from SmallDoges/fix-183: Implement variable-length attention with mask and bias support

2 parents: b6d2ea7 + e3ff84c
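
For orientation, the headline change is variable-length attention with mask and bias support. The sketch below shows the packed (varlen) layout such kernels typically consume: tokens from all sequences concatenated into one flat tensor plus a cu_seqlens prefix-sum tensor. This is illustrative background only, not code from this commit; the actual flash_dmattn varlen entry point and its signature live in the sources changed below.

import torch

batch_size, max_seq_len, num_heads, head_dim = 2, 8, 2, 64
device, dtype = torch.device("cuda"), torch.bfloat16

# Per-sequence valid lengths; the second sequence is full length
seqlens = torch.tensor([5, 8], device=device, dtype=torch.int32)

# Boolean padding mask: True where a token is real, False where it is padding
padding_mask = torch.arange(max_seq_len, device=device)[None, :] < seqlens[:, None]

# A padded batch of queries (keys/values are packed the same way)
q = torch.randn(batch_size, max_seq_len, num_heads, head_dim, device=device, dtype=dtype)

# cu_seqlens: cumulative sequence lengths, shape (batch_size + 1,)
cu_seqlens = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

# Pack: drop padded positions so every real token sits in one flat tensor
q_packed = q[padding_mask]  # (total_tokens, num_heads, head_dim)
assert q_packed.shape[0] == int(cu_seqlens[-1])

# q_packed and cu_seqlens are the varlen-layout inputs such kernels consume;
# how this commit threads the mask and bias through that layout is defined in
# csrc/flash_dmattn/flash_api.cpp and the Python wrappers changed here.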

File tree

11 files changed: +1251 / -489 lines


README.md

Lines changed: 3 additions & 3 deletions

@@ -157,7 +157,7 @@ import math
 
 # Setup
 batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
-keep_window_size = 128
+window_size = 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 min_dtype = torch.finfo(dtype).min  # dtype minimum value
@@ -172,10 +172,10 @@ attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=d
 attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
 
 # Generate sparse mask based on bias
-if seq_len > keep_window_size:
+if seq_len > window_size:
     # Select top-k most important keys for each query
     topk_values, topk_indices = torch.topk(
-        attention_bias, keep_window_size, dim=-1,
+        attention_bias, window_size, dim=-1,
         largest=True, sorted=False
     )
     # Generate valid top-k mask
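
The diffed snippet stops at the comment that builds the valid top-k mask. For readers following along, here is a self-contained sketch of the full pattern with the renamed window_size variable; everything after the torch.topk call is an assumption about how the README completes the mask (scattering ones at the top-k positions), not text from this diff.

import torch

# Setup, mirroring the README snippet above
batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
window_size = 128
device = torch.device('cuda')
dtype = torch.bfloat16

attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)

# Generate sparse mask based on bias
if seq_len > window_size:
    # Select top-k most important keys for each query
    topk_values, topk_indices = torch.topk(
        attention_bias, window_size, dim=-1,
        largest=True, sorted=False
    )
    # Assumed completion: scatter ones at the top-k positions so each query
    # can attend to at most window_size keys; all other positions stay 0.
    attention_mask = torch.zeros_like(attention_bias)
    attention_mask.scatter_(-1, topk_indices, 1.0)
else:
    # Short sequences keep dense attention
    attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)

With seq_len = 256 and window_size = 128, each query position ends up with exactly 128 keys marked 1 in attention_mask, which is the kind of bias-driven sparsity pattern the kernel is fed.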

README_zh.md

Lines changed: 3 additions & 3 deletions

The same rename applied to the Chinese README; the code is identical and only the comments are in Chinese.

@@ -157,7 +157,7 @@ import math
 
 # 设置
 batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
-keep_window_size = 128
+window_size = 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 min_dtype = torch.finfo(dtype).min  # dtype 的最小值
@@ -172,10 +172,10 @@ attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=d
 attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
 
 # 基于 bias 生成稀疏 mask
-if seq_len > keep_window_size:
+if seq_len > window_size:
     # 为每个查询选择 top-k 最重要的键
     topk_values, topk_indices = torch.topk(
-        attention_bias, keep_window_size, dim=-1,
+        attention_bias, window_size, dim=-1,
         largest=True, sorted=False
     )
     # 生成有效的 top-k mask

csrc/flash_dmattn/flash_api.cpp

Lines changed: 200 additions & 204 deletions
Large diffs are not rendered by default.
