Thanks for your effort in making this great project.
In standard attention, the input to the softmax is matmul(Q, K_T), with shape (batch, num_heads, q_len, k_len).
The attention mask is lower-triangular (its full shape is q_len x k_len),
so matmul(Q, K_T) is masked with that attention mask before the softmax.
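For concreteness, here is a minimal JAX sketch of the standard flow I am describing (function and variable names are just illustrative, not from this repo):

```python
import jax
import jax.numpy as jnp

def standard_causal_attention(q, k, v):
    # q, k, v: (batch, num_heads, seq_len, dim_head)
    scale = q.shape[-1] ** -0.5
    sim = jnp.einsum('b h i d, b h j d -> b h i j', q, k) * scale  # (batch, heads, q_len, k_len)

    # lower-triangular causal mask of shape (q_len, k_len)
    q_len, k_len = sim.shape[-2], sim.shape[-1]
    causal_mask = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
    sim = jnp.where(causal_mask, sim, -jnp.inf)  # mask before the softmax

    attn = jax.nn.softmax(sim, axis=-1)
    return jnp.einsum('b h i j, b h j d -> b h i d', attn, v)
```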
However, I don't understand how matmul(q_chunk, transposed k_chunk) ends up producing the correctly masked softmax input, compared with the original attention flow, at the code lines below.
flash-attention-jax/flash_attention_jax/flash_attention.py
Lines 34 to 37 in 5727815
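My current guess is that each block of scores is masked using the absolute offsets of the query and key chunks, roughly like the sketch below. This is purely my own guess for illustration, not the code from this repo, and the names (chunk_causal_mask, q_chunk_idx, k_chunk_idx, chunk sizes) are hypothetical:

```python
import jax.numpy as jnp

def chunk_causal_mask(q_chunk_idx, k_chunk_idx, q_chunk_size, k_chunk_size):
    # absolute positions of the queries / keys covered by this block
    q_pos = q_chunk_idx * q_chunk_size + jnp.arange(q_chunk_size)  # (q_chunk_size,)
    k_pos = k_chunk_idx * k_chunk_size + jnp.arange(k_chunk_size)  # (k_chunk_size,)
    # query i may attend to key j only if j <= i (causal)
    return q_pos[:, None] >= k_pos[None, :]                        # (q_chunk_size, k_chunk_size)

def masked_chunk_scores(q_chunk, k_chunk, q_chunk_idx, k_chunk_idx):
    # q_chunk: (..., q_chunk_size, dim), k_chunk: (..., k_chunk_size, dim)
    scale = q_chunk.shape[-1] ** -0.5
    sim = jnp.einsum('... i d, ... j d -> ... i j', q_chunk, k_chunk) * scale
    mask = chunk_causal_mask(q_chunk_idx, k_chunk_idx,
                             q_chunk.shape[-2], k_chunk.shape[-2])
    # use a large finite negative value, since a fully masked block would
    # otherwise produce NaNs when the chunked softmax normalizes it
    return jnp.where(mask, sim, -1e10)
```

Is this the right mental model for how the chunked matmul stays consistent with the full q_len x k_len causal mask?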
Can you explain this in detail?