* Roll back '_repeat_kv_heads' change in Flash Attention
A recent PR removed _repeat_kv_heads from Flash Attention as a GQA optimization,
in the hope of reducing HBM usage. However, the actual HBM saving is limited in
the model-parallel setting, since the heads are already sharded across devices.
The change also introduced limitations that break some existing sharding
configurations.
For example, suppose num_heads = 8 and num_kv_heads = 4. When we repeat KV heads,
we can set the model-axis size to 8 so that each device holds exactly one Q, K, and V head.
Without repeating KV heads, the model-axis size can be at most 4, so each device ends up
with 2 Q heads as a result, increasing the actual HBM usage (see the sketch after the commit list).
* Repeat kv as necessary for sharding
* Unit tests
* Address comments.
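
For context, here is a minimal JAX sketch of the two behaviors discussed above: the unconditional repeat that was rolled back, and a "repeat only as necessary for sharding" variant along the lines of the follow-up commit. The function names, the [batch, seq_len, num_kv_heads, per_head_dim] layout, and the divisibility checks are illustrative assumptions, not the repo's actual code.

```python
import jax.numpy as jnp


def repeat_kv_heads(key_or_value: jnp.ndarray, num_heads: int) -> jnp.ndarray:
    """Unconditionally repeats KV heads up to num_heads (the rolled-back behavior)."""
    num_kv_heads = key_or_value.shape[2]
    if num_kv_heads == num_heads:
        return key_or_value
    assert num_heads % num_kv_heads == 0
    return jnp.repeat(key_or_value, num_heads // num_kv_heads, axis=2)


def maybe_repeat_kv_heads(key_or_value: jnp.ndarray, model_axis_size: int) -> jnp.ndarray:
    """Repeats KV heads only as far as the model-parallel axis requires.

    If num_kv_heads already divides evenly over the model axis, the tensor is
    returned unchanged and no extra HBM is spent on duplicated KV heads.
    """
    num_kv_heads = key_or_value.shape[2]
    if num_kv_heads % model_axis_size == 0:
        return key_or_value
    assert model_axis_size % num_kv_heads == 0
    return jnp.repeat(key_or_value, model_axis_size // num_kv_heads, axis=2)


# The example from the commit message: num_heads = 8, num_kv_heads = 4.
k = jnp.zeros((2, 128, 4, 64))  # [batch, seq_len, num_kv_heads, per_head_dim]
assert maybe_repeat_kv_heads(k, model_axis_size=8).shape[2] == 8  # one KV head per device
assert maybe_repeat_kv_heads(k, model_axis_size=4).shape[2] == 4  # no repeat, no extra HBM
```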