@@ -872,14 +872,12 @@ mha_bwd(
872872 const int num_heads_k = k.size (2 );
873873 int num_heads_mask = has_mask ? mask.size (1 ) : 1 ;
874874 int num_heads_bias = has_bias ? bias.size (1 ) : 1 ;
875- int batch_size_mask = has_mask ? mask.size (0 ) : batch_size;
876- int batch_size_bias = has_bias ? bias.size (0 ) : batch_size;
877- int seqlen_q_mask = has_mask ? mask.size (2 ) : seqlen_q;
878- int seqlen_q_bias = has_bias ? bias.size (2 ) : seqlen_q;
879875 auto round_multiple = [](int x, int m) { return (x + m - 1 ) / m * m; };
880876 const int head_size_rounded = round_multiple (head_size, head_size <= 128 ? 32 : 64 );
881877 const int seqlen_q_rounded = round_multiple (seqlen_q, 128 );
882878 const int seqlen_k_rounded = round_multiple (seqlen_k, 128 );
879+ int batch_size_dbias = has_bias ? bias.size (0 ) : batch_size;
880+ int seqlen_q_dbias = has_bias ? bias.size (2 ) : seqlen_q;
883881
884882 TORCH_CHECK (batch_size > 0 , " batch size must be positive" );
885883 TORCH_CHECK (head_size % 8 == 0 , " head_size should be a multiple of 8" );
@@ -945,7 +943,7 @@ mha_bwd(
945943 TORCH_CHECK (dbias.size (2 ) == seqlen_q || dbias.size (2 ) == 1 , " Query length dimension in dbias must be 1 or equal to seqlen_q" );
946944 TORCH_CHECK (dbias.size (3 ) == seqlen_k_rounded, " Key length dimension in dbias must be seqlen_k_rounded" );
947945 } else {
948- dbias = torch::empty ({batch_size_bias , num_heads_bias, seqlen_q_bias , seqlen_k_rounded}, opts);
946+ dbias = torch::empty ({batch_size_dbias , num_heads_bias, seqlen_q_dbias , seqlen_k_rounded}, opts);
949947 }
950948 } else {
951949 dbias = torch::empty ({0 }, opts);
@@ -977,8 +975,8 @@ mha_bwd(
977975 ? torch::empty ({batch_size, seqlen_k, num_heads, head_size}, opts)
978976 : dv;
979977 dbias_expanded = has_bias
980- ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1 ) // MQA / GQA or dbias has different batch size or seqlen_q
981- ? (seqlen_q_bias == 1 )
978+ ? (num_heads_bias != num_heads || batch_size_dbias == 1 || seqlen_q_dbias == 1 ) // MQA / GQA or dbias has different batch size or seqlen_q
979+ ? (seqlen_q_dbias == 1 )
982980 ? torch::zeros ({batch_size, num_heads, 1 , seqlen_k_rounded}, opts)
983981 : torch::zeros ({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
984982 : dbias
@@ -1033,15 +1031,15 @@ mha_bwd(
10331031 }
10341032 // For MQA/GQA or dbias has different batch size or seqlen_q, we need to sum dbias across the groups, batch and seqlen_q
10351033 if (has_bias) {
1036- if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
1034+ if (num_heads_bias != num_heads && batch_size_dbias == batch_size && seqlen_q_dbias == seqlen_q) {
10371035 at::sum_out (dbias, at::reshape (dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2 });
10381036 } else {
1039- if (seqlen_q_bias == 1 ) {
1037+ if (seqlen_q_dbias == 1 ) {
10401038 dbias_expanded = at::sum (at::reshape (dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1 , seqlen_k_rounded}), {2 });
10411039 } else {
10421040 dbias_expanded = at::sum (at::reshape (dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2 });
10431041 }
1044- if (batch_size_bias == 1 ) {
1042+ if (batch_size_dbias == 1 ) {
10451043 dbias_expanded = at::sum (dbias_expanded, {0 }, true );
10461044 }
10471045 dbias.copy_ (dbias_expanded);
0 commit comments