Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Oct 24, 2024
1 parent db3a867 commit 5a38a0c
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 16 deletions.
8 changes: 4 additions & 4 deletions docs/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -29046,8 +29046,8 @@ This version of the operator has been available since version 23 of the default

# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
if rotary_embedding_dim == 0:
# If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
rotary_embedding_dim = cos_cache.shape[1] * 2
# If rotary_embedding_dim not provided, perform full rotation by using head_size
rotary_embedding_dim = head_size
x_rotate = input[:, :, :, :rotary_embedding_dim]
x_not_rotate = input[:, :, :, rotary_embedding_dim:]
rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
Expand Down Expand Up @@ -29090,9 +29090,9 @@ This version of the operator has been available since version 23 of the default
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>num_heads</tt> : int</dt>
<dd>Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`. </dd>
<dd>Number of attention heads. Must be used with `rotary_embedding_dim`.</dd>
<dt><tt>rotary_embedding_dim</tt> : int</dt>
<dd>Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0. </dd>
<dd>Rotary embedding dimension used to apply partial rotary embeddings.</dd>
</dl>

#### Inputs
Expand Down
8 changes: 4 additions & 4 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -27331,8 +27331,8 @@ expect(

# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
if rotary_embedding_dim == 0:
# If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
rotary_embedding_dim = cos_cache.shape[1] * 2
# If rotary_embedding_dim not provided, perform full rotation by using head_size
rotary_embedding_dim = head_size
x_rotate = input[:, :, :, :rotary_embedding_dim]
x_not_rotate = input[:, :, :, rotary_embedding_dim:]
rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
Expand Down Expand Up @@ -27375,9 +27375,9 @@ This version of the operator has been available since version 23 of the default
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>num_heads</tt> : int</dt>
<dd>Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`. </dd>
<dd>Number of attention heads. Must be used with `rotary_embedding_dim`.</dd>
<dt><tt>rotary_embedding_dim</tt> : int</dt>
<dd>Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0. </dd>
<dd>Rotary embedding dimension used to apply partial rotary embeddings.</dd>
</dl>

#### Inputs
Expand Down
4 changes: 2 additions & 2 deletions onnx/backend/test/case/node/rotaryembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def compute_rotary_embedding(

# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
if rotary_embedding_dim == 0:
# If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
rotary_embedding_dim = cos_cache.shape[1] * 2
# If rotary_embedding_dim not provided, perform full rotation by using head_size
rotary_embedding_dim = head_size
x_rotate = input[:, :, :, :rotary_embedding_dim]
x_not_rotate = input[:, :, :, rotary_embedding_dim:]
rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
21 changes: 15 additions & 6 deletions onnx/defs/nn/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2852,8 +2852,8 @@ Rotary embeddings are defined using the following algorithm:
# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
if rotary_embedding_dim == 0:
# If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
rotary_embedding_dim = cos_cache.shape[1] * 2
# If rotary_embedding_dim not provided, perform full rotation by using head_size
rotary_embedding_dim = head_size
x_rotate = input[:, :, :, :rotary_embedding_dim]
x_not_rotate = input[:, :, :, rotary_embedding_dim:]
rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
Expand Down Expand Up @@ -2897,11 +2897,11 @@ ONNX_OPERATOR_SET_SCHEMA(
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("rotary_embedding_dim",
"Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0. ",
"Rotary embedding dimension used to apply partial rotary embeddings.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("num_heads",
"Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`. ",
"Number of attention heads. Must use with `rotary_embedding_dim`. ",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
Expand Down Expand Up @@ -2934,6 +2934,16 @@ ONNX_OPERATOR_SET_SCHEMA(
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
auto input_shape = ctx.getInputType(0)->tensor_type().shape();
if (input_shape.dim_size() < 3) {
return; // Input tensor should have at least three dimensions.
}

auto* num_heads_attr = ctx.getAttribute("num_heads");
if ((input_shape.dim_size() == 3) && (num_heads_attr == nullptr)) {
fail_shape_inference("Input shape is 3D, num_heads attribute must be provided");
}

propagateElemTypeFromInputToOutput(ctx, 0, 0);
propagateShapeFromInputToOutput(ctx, 0, 0);
})
Expand Down Expand Up @@ -2963,9 +2973,8 @@ ONNX_OPERATOR_SET_SCHEMA(
// There are two cases for the rotary embedding dimension:
// 1. Complete rotation: rotary embedding dimension defaults to head_size, rotary_embedding_dim = cos.shape[3] * 2 or head_size
// 2. Partial rotation: rotary embedding dimension is provided, rotary_embedding_dim = rotary_embedding_dim
builder.Add("HeadSizeHalf = Shape <start = 1, end = 2> (cos_cache)") // cos_cache.shape[1] or head_size // 2
builder.Add("HeadSize = Shape <start = 3, end = 4> (XReshaped)") // head_size
.Const1D("Two1D", (int64_t)2)
.Add("HeadSize = Mul(HeadSizeHalf, Two1D)") // cos.shape[1] * 2 or head_size
.Const1D("RotaryEmbedDimParam", rotary_embedding_dim)
.Const1D("Zero1D", (int64_t)0)
.Add("RotaryDimCond = Greater(RotaryEmbedDimParam, Zero1D)")
Expand Down

0 comments on commit 5a38a0c

Please sign in to comment.