
add expanded doc info
shubhambhokare1 committed Oct 28, 2024
1 parent 15bc2d2 commit 982823f
Showing 209 changed files with 144 additions and 590 deletions.
40 changes: 36 additions & 4 deletions docs/Changelog.md
@@ -29034,7 +29034,39 @@ This version of the operator has been available since version 23 of the default

### <a name="ScalarDotProductAttention-23"></a>**ScalarDotProductAttention-23**</a>

Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.

This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V.
For self attention, kv_sequence_length equals q_sequence_length.
For cross attention, query and key might have different lengths.

This operator also covers the 3 following variants based on the number of heads (a head-layout sketch follows this list):
1) Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, q_num_heads = kv_num_heads.
2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, q_num_heads > kv_num_heads.
3) Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, q_num_heads > kv_num_heads, kv_num_heads=1.

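A minimal NumPy sketch of how the three variants differ only in their head counts. The shapes and the `expand_kv` helper are illustrative assumptions, not part of the spec; for GQA and MQA, KV heads are repeated so each query head has a matching KV head:

```python
import numpy as np

batch, q_seq, kv_seq, head_size = 2, 8, 8, 64  # hypothetical sizes

def make_qkv(q_num_heads, kv_num_heads):
    """Random Q/K/V in the 4D layout (batch, num_heads, seq, head_size)."""
    rng = np.random.default_rng(0)
    Q = rng.standard_normal((batch, q_num_heads, q_seq, head_size))
    K = rng.standard_normal((batch, kv_num_heads, kv_seq, head_size))
    V = rng.standard_normal((batch, kv_num_heads, kv_seq, head_size))
    return Q, K, V

def expand_kv(x, q_num_heads):
    """Repeat each KV head so GQA/MQA K and V align with q_num_heads."""
    return np.repeat(x, q_num_heads // x.shape[1], axis=1)

Q, K, V = make_qkv(q_num_heads=8, kv_num_heads=8)  # MHA: q_num_heads == kv_num_heads
Q, K, V = make_qkv(q_num_heads=8, kv_num_heads=2)  # GQA: q_num_heads > kv_num_heads
Q, K, V = make_qkv(q_num_heads=8, kv_num_heads=1)  # MQA: a single shared KV head
assert expand_kv(K, 8).shape == (batch, 8, kv_seq, head_size)
```
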
The attention bias to be added is computed from the attn_mask input and the is_causal attribute, only one of which may be provided (both mask forms are sketched below).
1) If is_causal is set to 1, the attention mask is a lower-triangular matrix when the mask is a square matrix; due to the alignment, it takes the form of an upper-left causal bias.
2) attn_mask: either a boolean mask, where a value of True indicates that the element should take part in attention, or a float mask of the same type as query, key and value that is added to the attention score.

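A sketch of both mask forms expressed as additive biases, assuming a large negative float stands in for -inf (sizes are illustrative):

```python
import numpy as np

q_len, kv_len = 4, 6  # hypothetical lengths
NEG_INF = np.float32(np.finfo(np.float32).min)

# is_causal == 1: lower-triangular, upper-left aligned causal bias.
causal = np.tril(np.ones((q_len, kv_len), dtype=bool))
causal_bias = np.where(causal, np.float32(0), NEG_INF)

# Boolean attn_mask: True means "attend"; convert to an additive bias.
bool_mask = np.ones((q_len, kv_len), dtype=bool)
bool_bias = np.where(bool_mask, np.float32(0), NEG_INF)

# Float attn_mask: already an additive bias, used as-is.
float_bias = np.zeros((q_len, kv_len), dtype=np.float32)
```
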
Both the past and present state key/values are optional. They shall be used together; providing only one of them is not allowed.
After the K and V inputs are reshaped according to the provided sequence lengths and numbers of heads, the following pattern is applied to the Q, K and V inputs (a NumPy sketch follows the diagram):

          Q          K          V
          |          |          |
          |       Transpose     |
          |          |          |
          ---MatMul---          |
                |               |
       scale---Mul              |
                |               |
     at_bias---Add              |
                |               |
             Softmax            |
                |               |
                -----MatMul------
                      |
                      Y
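
A minimal NumPy sketch of the pattern above. This is an illustrative reading of the diagram, not the normative implementation: `attn_bias` stands for the combined mask/causal bias, the default scale of 1/sqrt(head_size) is an assumption, and GQA/MQA head broadcasting and KV caches are omitted:

```python
import numpy as np

def sdpa(Q, K, V, attn_bias=None, scale=None):
    """Pattern from the diagram over (batch, num_heads, seq, head_size) tensors."""
    if scale is None:
        scale = 1.0 / np.sqrt(Q.shape[-1])          # assumed default scaling
    scores = (Q @ K.transpose(0, 1, 3, 2)) * scale  # Transpose, MatMul, Mul by scale
    if attn_bias is not None:
        scores = scores + attn_bias                 # Add at_bias
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)              # Softmax over the kv axis
    return w @ V                                    # final MatMul -> Y

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 8, 4, 64)).astype(np.float32)
K = rng.standard_normal((2, 8, 4, 64)).astype(np.float32)
V = rng.standard_normal((2, 8, 4, 64)).astype(np.float32)
Y = sdpa(Q, K, V)  # shape (2, 8, 4, 64)
```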


#### Version
@@ -29064,7 +29096,7 @@ This version of the operator has been available since version 23 of the default
<dt><tt>V</tt> : T</dt>
<dd>Value tensor. 4D tensor with shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size) or 3D tensor with shape (batch_size, kv_sequence_length, v_hidden_size). For cases with a 3D input tensor, v_hidden_size = kv_num_heads * v_head_size</dd>
<dt><tt>attn_mask</tt> (optional) : U</dt>
<dd>Attention mask. Shape must be broadcastable to 4D tensor with shape (batch_size, q_num_heads, q_sequence_length, total_sequence_length). total_sequence_length is past_sequence_length + kv_sequence_length. Two types of masks are supported. A boolean mask where a value of True indicates that the element should take part in attention. Also supports a float mask of the same type as query, key, value that is added to the attention score.</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state cache for key with shape (batch_size, kv_num_heads, past_sequence_length, head_size)</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
@@ -29077,9 +29109,9 @@ This version of the operator has been available since version 23 of the default
<dt><tt>Y</tt> : T</dt>
<dd>The output tensor. 4D tensor with shape (batch_size, q_num_heads, q_sequence_length, v_head_size) or 3D tensor with shape (batch_size, q_sequence_length, hidden_size). For cases with a 3D input tensor, hidden_size = q_num_heads * v_head_size</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dd>Updated key cache with shape (batch_size, kv_num_heads, total_sequence_length, head_size). total_sequence_length is past_sequence_length + kv_sequence_length.</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>Updated value cache with shape (batch_size, kv_num_heads, total_sequence_length, v_head_size). total_sequence_length is past_sequence_length + kv_sequence_length.</dd>
</dl>

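A sketch of how the present caches would relate to the past caches and the new K/V, assuming concatenation along the sequence axis (shapes are illustrative):

```python
import numpy as np

# past_sequence_length = 3, kv_sequence_length = 1 (one decode step),
# so total_sequence_length = 4.
past_key   = np.zeros((2, 8, 3, 64), dtype=np.float32)
past_value = np.zeros((2, 8, 3, 64), dtype=np.float32)
K = np.ones((2, 8, 1, 64), dtype=np.float32)
V = np.ones((2, 8, 1, 64), dtype=np.float32)

present_key   = np.concatenate([past_key, K], axis=2)    # (2, 8, 4, 64)
present_value = np.concatenate([past_value, V], axis=2)  # (2, 8, 4, 64)
```
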
#### Type Constraints
