diff --git a/docs/Changelog.md b/docs/Changelog.md
index c538f2b216c..61bbeabea20 100644
--- a/docs/Changelog.md
+++ b/docs/Changelog.md
@@ -29046,8 +29046,8 @@ This version of the operator has been available since version 23 of the default
# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
if rotary_embedding_dim == 0:
- # If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
- rotary_embedding_dim = cos_cache.shape[1] * 2
+ # If rotary_embedding_dim not provided, perform full rotation by using head_size
+ rotary_embedding_dim = head_size
x_rotate = input[:, :, :, :rotary_embedding_dim]
x_not_rotate = input[:, :, :, rotary_embedding_dim:]
rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
@@ -29090,9 +29090,9 @@ This version of the operator has been available since version 23 of the default
interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
num_heads : int
-Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`.
+Number of attention heads. Must be used with `rotary_embedding_dim`.
rotary_embedding_dim : int
-Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0.
+Rotary embedding dimension used to apply partial rotary embeddings.
#### Inputs
diff --git a/docs/Operators.md b/docs/Operators.md
index bafb5a1a1fb..d41407665fc 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -27331,8 +27331,8 @@ expect(
# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
if rotary_embedding_dim == 0:
- # If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
- rotary_embedding_dim = cos_cache.shape[1] * 2
+ # If rotary_embedding_dim not provided, perform full rotation by using head_size
+ rotary_embedding_dim = head_size
x_rotate = input[:, :, :, :rotary_embedding_dim]
x_not_rotate = input[:, :, :, rotary_embedding_dim:]
rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
@@ -27375,9 +27375,9 @@ This version of the operator has been available since version 23 of the default
interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
num_heads : int
-Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`.
+Number of attention heads. Must be used with `rotary_embedding_dim`.
rotary_embedding_dim : int
-Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0.
+Rotary embedding dimension used to apply partial rotary embeddings.
#### Inputs
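
The revised attribute text above pairs `num_heads` with `rotary_embedding_dim`: `num_heads` lets the op split a 3D input into heads, and `rotary_embedding_dim` limits the rotation to a prefix of each head. As a purely illustrative sketch (tensor names and sizes below are invented for the example, not taken from the test data), a node using both attributes for partial rotation could be built like this:

```python
from onnx import helper

# Hypothetical partial-rotation setup: a 3D input with hidden size 32 split
# into 4 heads of size 8, rotating only the first 4 dims of each head.
node = helper.make_node(
    "RotaryEmbedding",
    inputs=["x", "cos_cache", "sin_cache", "position_ids"],
    outputs=["y"],
    num_heads=4,             # needed here because the input is assumed 3D
    rotary_embedding_dim=4,  # rotate 4 of the 8 dims per head
)
```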
diff --git a/onnx/backend/test/case/node/rotaryembedding.py b/onnx/backend/test/case/node/rotaryembedding.py
index 797092b0ee9..81d8e3b0d4f 100644
--- a/onnx/backend/test/case/node/rotaryembedding.py
+++ b/onnx/backend/test/case/node/rotaryembedding.py
@@ -34,8 +34,8 @@ def compute_rotary_embedding(
# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
if rotary_embedding_dim == 0:
- # If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
- rotary_embedding_dim = cos_cache.shape[1] * 2
+ # If rotary_embedding_dim not provided, perform full rotation by using head_size
+ rotary_embedding_dim = head_size
x_rotate = input[:, :, :, :rotary_embedding_dim]
x_not_rotate = input[:, :, :, rotary_embedding_dim:]
rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
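
With this change, a missing `rotary_embedding_dim` falls back to `head_size` instead of `cos_cache.shape[1] * 2`, so the rotated slice spans the whole head while the cos/sin caches still only carry `head_size // 2` angles per position. The following standalone numpy sketch paraphrases the non-interleaved path (it is not the reference implementation itself; the caches are assumed to be already gathered per position) to show the default in action:

```python
import numpy as np

def rotate_heads(x, cos, sin, rotary_embedding_dim=0):
    # x: (batch, num_heads, seq_len, head_size); cos/sin: (seq_len, dim // 2)
    head_size = x.shape[-1]
    if rotary_embedding_dim == 0:
        rotary_embedding_dim = head_size  # full rotation covers the whole head
    half = rotary_embedding_dim // 2
    x_rotate = x[..., :rotary_embedding_dim]
    x_not_rotate = x[..., rotary_embedding_dim:]
    x1, x2 = x_rotate[..., :half], x_rotate[..., half:]
    real = cos * x1 - sin * x2
    imag = sin * x1 + cos * x2
    return np.concatenate([real, imag, x_not_rotate], axis=-1)

# Smoke test: with the default, nothing is left un-rotated.
x = np.random.rand(1, 2, 3, 8).astype(np.float32)
angles = np.random.rand(3, 4).astype(np.float32)
assert rotate_heads(x, np.cos(angles), np.sin(angles)).shape == x.shape
```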
diff --git a/onnx/backend/test/data/node/test_rotary_embedding_3d_input_expanded/model.onnx b/onnx/backend/test/data/node/test_rotary_embedding_3d_input_expanded/model.onnx
index 0de50cf20bf..317957cd22f 100644
Binary files a/onnx/backend/test/data/node/test_rotary_embedding_3d_input_expanded/model.onnx and b/onnx/backend/test/data/node/test_rotary_embedding_3d_input_expanded/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_rotary_embedding_expanded/model.onnx b/onnx/backend/test/data/node/test_rotary_embedding_expanded/model.onnx
index c7a2d657462..4f0344ec20e 100644
Binary files a/onnx/backend/test/data/node/test_rotary_embedding_expanded/model.onnx and b/onnx/backend/test/data/node/test_rotary_embedding_expanded/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_rotary_embedding_interleaved_expanded/model.onnx b/onnx/backend/test/data/node/test_rotary_embedding_interleaved_expanded/model.onnx
index cf87c1430eb..f9649ea13b0 100644
Binary files a/onnx/backend/test/data/node/test_rotary_embedding_interleaved_expanded/model.onnx and b/onnx/backend/test/data/node/test_rotary_embedding_interleaved_expanded/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx b/onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx
index 0fdbe9696aa..c29ab944987 100644
Binary files a/onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx and b/onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx differ
diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc
index 6ce399f664c..075055817c5 100644
--- a/onnx/defs/nn/defs.cc
+++ b/onnx/defs/nn/defs.cc
@@ -2852,8 +2852,8 @@ Rotary embeddings are defined using the following algorithm:
# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
if rotary_embedding_dim == 0:
- # If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
- rotary_embedding_dim = cos_cache.shape[1] * 2
+ # If rotary_embedding_dim not provided, perform full rotation by using head_size
+ rotary_embedding_dim = head_size
x_rotate = input[:, :, :, :rotary_embedding_dim]
x_not_rotate = input[:, :, :, rotary_embedding_dim:]
rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
@@ -2897,11 +2897,11 @@ ONNX_OPERATOR_SET_SCHEMA(
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("rotary_embedding_dim",
- "Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0. ",
+ "Rotary embedding dimension used to apply partial rotary embeddings.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("num_heads",
- "Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`. ",
+ "Number of attention heads. Must use with `rotary_embedding_dim`. ",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
@@ -2934,6 +2934,16 @@ ONNX_OPERATOR_SET_SCHEMA(
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ // The checks below need a known input shape with at least three dimensions (3D or 4D input).
+ if (!hasInputShape(ctx, 0) || getInputShape(ctx, 0).dim_size() < 3) {
+ return;
+ }
+
+ auto* num_heads_attr = ctx.getAttribute("num_heads");
+ if ((getInputShape(ctx, 0).dim_size() == 3) && (num_heads_attr == nullptr)) {
+ fail_shape_inference("num_heads attribute must be provided when the input tensor is 3D");
+ }
+
propagateElemTypeFromInputToOutput(ctx, 0, 0);
propagateShapeFromInputToOutput(ctx, 0, 0);
})
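
In user-facing terms, the new check means a 3D input without `num_heads` should now be rejected during strict shape inference rather than silently passed through. A rough sketch of how that would surface from Python, assuming an onnx build that carries opset 23 and this change (model shapes and names are invented for the example):

```python
import onnx
from onnx import TensorProto, helper

node = helper.make_node(
    "RotaryEmbedding",
    inputs=["x", "cos_cache", "sin_cache", "position_ids"],
    outputs=["y"],
    # num_heads intentionally omitted although the input is 3D
)
graph = helper.make_graph(
    [node],
    "rotary_3d_missing_num_heads",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 32]),
        helper.make_tensor_value_info("cos_cache", TensorProto.FLOAT, [16, 4]),
        helper.make_tensor_value_info("sin_cache", TensorProto.FLOAT, [16, 4]),
        helper.make_tensor_value_info("position_ids", TensorProto.INT64, [2, 4]),
    ],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 32])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)])

try:
    onnx.shape_inference.infer_shapes(model, strict_mode=True)
except onnx.shape_inference.InferenceError as e:
    print("rejected as expected:", e)
```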
@@ -2963,9 +2973,8 @@ ONNX_OPERATOR_SET_SCHEMA(
// There are two cases for the rotary embedding dimension:
// 1. Complete rotation: rotary embedding dimension defaults to head_size, rotary_embedding_dim = cos.shape[3] * 2 or head_size
// 2. Partial rotation: rotary embedding dimension is provided, rotary_embedding_dim = rotary_embedding_dim
- builder.Add("HeadSizeHalf = Shape (cos_cache)") // cos_cache.shape[1] or head_size // 2
+ builder.Add("HeadSize = Shape (XReshaped)") // head_size
.Const1D("Two1D", (int64_t)2)
- .Add("HeadSize = Mul(HeadSizeHalf, Two1D)") // cos.shape[1] * 2 or head_size
.Const1D("RotaryEmbedDimParam", rotary_embedding_dim)
.Const1D("Zero1D", (int64_t)0)
.Add("RotaryDimCond = Greater(RotaryEmbedDimParam, Zero1D)")