
Commit 424b733

Merge pull request #199 from SmallDoges/fix-bug
Fix attention bias calculation and dbias handling
2 parents 071ab90 + a5b4cb7

2 files changed: +10 −11


csrc/flash_dmattn/flash_api.cpp
Lines changed: 8 additions & 10 deletions

@@ -872,14 +872,12 @@ mha_bwd(
     const int num_heads_k = k.size(2);
     int num_heads_mask = has_mask ? mask.size(1) : 1;
     int num_heads_bias = has_bias ? bias.size(1) : 1;
-    int batch_size_mask = has_mask ? mask.size(0) : batch_size;
-    int batch_size_bias = has_bias ? bias.size(0) : batch_size;
-    int seqlen_q_mask = has_mask ? mask.size(2) : seqlen_q;
-    int seqlen_q_bias = has_bias ? bias.size(2) : seqlen_q;
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+    int batch_size_dbias = has_bias ? bias.size(0) : batch_size;
+    int seqlen_q_dbias = has_bias ? bias.size(2) : seqlen_q;
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");

@@ -945,7 +943,7 @@ mha_bwd(
         TORCH_CHECK(dbias.size(2) == seqlen_q || dbias.size(2) == 1, "Query length dimension in dbias must be 1 or equal to seqlen_q");
         TORCH_CHECK(dbias.size(3) == seqlen_k_rounded, "Key length dimension in dbias must be seqlen_k_rounded");
     } else {
-        dbias = torch::empty({batch_size_bias, num_heads_bias, seqlen_q_bias, seqlen_k_rounded}, opts);
+        dbias = torch::empty({batch_size_dbias, num_heads_bias, seqlen_q_dbias, seqlen_k_rounded}, opts);
     }
 } else {
     dbias = torch::empty({0}, opts);

@@ -977,8 +975,8 @@ mha_bwd(
         ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts)
         : dv;
     dbias_expanded = has_bias
-        ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)  // MQA / GQA or dbias has different batch size or seqlen_q
-            ? (seqlen_q_bias == 1)
+        ? (num_heads_bias != num_heads || batch_size_dbias == 1 || seqlen_q_dbias == 1)  // MQA / GQA or dbias has different batch size or seqlen_q
+            ? (seqlen_q_dbias == 1)
                 ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
                 : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
             : dbias

@@ -1033,15 +1031,15 @@ mha_bwd(
     }
     // For MQA/GQA or dbias has different batch size or seqlen_q, we need to sum dbias across the groups, batch and seqlen_q
     if (has_bias) {
-        if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
+        if (num_heads_bias != num_heads && batch_size_dbias == batch_size && seqlen_q_dbias == seqlen_q) {
             at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
         } else {
-            if (seqlen_q_bias == 1) {
+            if (seqlen_q_dbias == 1) {
                 dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
             } else {
                 dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
             }
-            if (batch_size_bias == 1) {
+            if (batch_size_dbias == 1) {
                 dbias_expanded = at::sum(dbias_expanded, {0}, true);
             }
             dbias.copy_(dbias_expanded);
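
The rename from batch_size_bias / seqlen_q_bias to batch_size_dbias / seqlen_q_dbias (plus dropping the unused mask counterparts) makes it explicit that these sizes describe the layout of the bias gradient, not the forward mask. The reduction at the end of the hunk then reads: if only the head dimension was broadcast (MQA/GQA), sum the per-query-head gradient over the head groups straight into dbias; otherwise also collapse the batch and query-length axes that the forward bias broadcast over. A minimal PyTorch sketch of that reduction, with illustrative shapes (the sizes and the dbias_expanded tensor below are assumptions for demonstration, not values taken from the kernel):

import torch

# Illustrative shapes; in mha_bwd these come from q, k, and the forward bias.
batch_size, num_heads, seqlen_q, seqlen_k_rounded = 2, 8, 16, 128
num_heads_bias = 4       # bias supplied per KV head (GQA), so num_heads_bias != num_heads
batch_size_dbias = 1     # forward bias was broadcast over the batch dimension
seqlen_q_dbias = 1       # forward bias was broadcast over the query length

# Per-query-head gradient as accumulated by the backward kernel (stand-in for dbias_expanded).
q_len = 1 if seqlen_q_dbias == 1 else seqlen_q
dbias_expanded = torch.randn(batch_size, num_heads, q_len, seqlen_k_rounded)

if num_heads_bias != num_heads and batch_size_dbias == batch_size and seqlen_q_dbias == seqlen_q:
    # Only the head dimension was broadcast: fold the query-head groups back together.
    dbias = dbias_expanded.reshape(
        batch_size, num_heads_bias, num_heads // num_heads_bias, seqlen_q, seqlen_k_rounded
    ).sum(dim=2)
else:
    # Collapse the head groups first...
    dbias = dbias_expanded.reshape(
        batch_size, num_heads_bias, num_heads // num_heads_bias, q_len, seqlen_k_rounded
    ).sum(dim=2)
    # ...then the batch axis, if the forward bias was broadcast over it.
    if batch_size_dbias == 1:
        dbias = dbias.sum(dim=0, keepdim=True)

print(dbias.shape)  # torch.Size([1, 4, 1, 128]) for the shapes chosen above

The resulting gradient keeps the broadcast layout of the forward bias, padded along the key dimension to seqlen_k_rounded, which matches how dbias is allocated earlier in the hunk.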

examples/modeling/modeling_doge.py
Lines changed: 2 additions & 1 deletion

@@ -218,7 +218,7 @@ def forward(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
         # original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)
-        attn_bias = self.A * F.softplus(dt_states).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)
+        attn_bias = (self.A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)
 
         attention_interface: Callable = flash_dynamic_mask_attention_forward

@@ -230,6 +230,7 @@ def forward(
             attention_mask=attention_mask,
             attention_bias=attn_bias,
             scale=self.scaling,
+            window_size=self.window_size,
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
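
The one-line bias fix in modeling_doge.py is an operator-precedence correction: in the old expression the .transpose(-1, -2).unsqueeze(-2).to(...) chain bound only to F.softplus(dt_states), so self.A was multiplied in after the head axis had already been moved, lining its elements up with the wrong axis. A small sketch of the difference, assuming dt_states is laid out as (batch, seq_len, num_heads) and A is a per-head parameter (these shapes and names are illustrative assumptions, not taken from the model code):

import torch
import torch.nn.functional as F

batch, seq_len, num_heads = 2, 5, 3      # illustrative shapes
dt_states = torch.randn(batch, seq_len, num_heads)
A = torch.randn(num_heads)               # per-head scale, standing in for self.A

# Old precedence: transpose/unsqueeze bind to F.softplus(dt_states) alone, so A's
# head axis is multiplied against the sequence axis (a shape error here, or a
# silently wrong broadcast whenever num_heads happens to equal seq_len).
try:
    wrong = A * F.softplus(dt_states).transpose(-1, -2).unsqueeze(-2)
except RuntimeError as err:
    print("old precedence:", err)

# Fixed precedence: scale each head first, then move heads forward and add the
# broadcast query axis, giving (batch, num_heads, 1, seq_len).
attn_bias = (A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2)
print(attn_bias.shape)  # torch.Size([2, 3, 1, 5])

The second hunk simply threads the module's configured self.window_size into the attention call, so the sliding-window setting is passed along to the flash_dynamic_mask_attention_forward interface selected a few lines above.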
