diff --git a/docs/Changelog.md b/docs/Changelog.md
index c538f2b216c..61bbeabea20 100644
--- a/docs/Changelog.md
+++ b/docs/Changelog.md
@@ -29046,8 +29046,8 @@ This version of the operator has been available since version 23 of the default
 # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
 if rotary_embedding_dim == 0:
-    # If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
-    rotary_embedding_dim = cos_cache.shape[1] * 2
+    # If rotary_embedding_dim not provided, perform full rotation by using head_size
+    rotary_embedding_dim = head_size
 x_rotate = input[:, :, :, :rotary_embedding_dim]
 x_not_rotate = input[:, :, :, rotary_embedding_dim:]
 rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
@@ -29090,9 +29090,9 @@ This version of the operator has been available since version 23 of the default
 interleaved : int
 Rotate using interleaved pattern. Default value is 0 (False).
 num_heads : int
-Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`.
+Number of attention heads. Must use with `rotary_embedding_dim`.
 rotary_embedding_dim : int
-Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0.
+Rotary embedding dimension used to apply partial rotary embeddings.
 #### Inputs
diff --git a/docs/Operators.md b/docs/Operators.md
index bafb5a1a1fb..d41407665fc 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -27331,8 +27331,8 @@ expect(
 # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
 if rotary_embedding_dim == 0:
-    # If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
-    rotary_embedding_dim = cos_cache.shape[1] * 2
+    # If rotary_embedding_dim not provided, perform full rotation by using head_size
+    rotary_embedding_dim = head_size
 x_rotate = input[:, :, :, :rotary_embedding_dim]
 x_not_rotate = input[:, :, :, rotary_embedding_dim:]
 rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
@@ -27375,9 +27375,9 @@ This version of the operator has been available since version 23 of the default
 interleaved : int
 Rotate using interleaved pattern. Default value is 0 (False).
 num_heads : int
-Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`.
+Number of attention heads. Must use with `rotary_embedding_dim`.
 rotary_embedding_dim : int
-Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0.
+Rotary embedding dimension used to apply partial rotary embeddings.
 #### Inputs
diff --git a/onnx/backend/test/case/node/rotaryembedding.py b/onnx/backend/test/case/node/rotaryembedding.py
index 797092b0ee9..81d8e3b0d4f 100644
--- a/onnx/backend/test/case/node/rotaryembedding.py
+++ b/onnx/backend/test/case/node/rotaryembedding.py
@@ -34,8 +34,8 @@ def compute_rotary_embedding(
     # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
     if rotary_embedding_dim == 0:
-        # If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
-        rotary_embedding_dim = cos_cache.shape[1] * 2
+        # If rotary_embedding_dim not provided, perform full rotation by using head_size
+        rotary_embedding_dim = head_size
     x_rotate = input[:, :, :, :rotary_embedding_dim]
     x_not_rotate = input[:, :, :, rotary_embedding_dim:]
     rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
diff --git a/onnx/backend/test/data/node/test_rotary_embedding_3d_input_expanded/model.onnx b/onnx/backend/test/data/node/test_rotary_embedding_3d_input_expanded/model.onnx
index 0de50cf20bf..317957cd22f 100644
Binary files a/onnx/backend/test/data/node/test_rotary_embedding_3d_input_expanded/model.onnx and b/onnx/backend/test/data/node/test_rotary_embedding_3d_input_expanded/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_rotary_embedding_expanded/model.onnx b/onnx/backend/test/data/node/test_rotary_embedding_expanded/model.onnx
index c7a2d657462..4f0344ec20e 100644
Binary files a/onnx/backend/test/data/node/test_rotary_embedding_expanded/model.onnx and b/onnx/backend/test/data/node/test_rotary_embedding_expanded/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_rotary_embedding_interleaved_expanded/model.onnx b/onnx/backend/test/data/node/test_rotary_embedding_interleaved_expanded/model.onnx
index cf87c1430eb..f9649ea13b0 100644
Binary files a/onnx/backend/test/data/node/test_rotary_embedding_interleaved_expanded/model.onnx and b/onnx/backend/test/data/node/test_rotary_embedding_interleaved_expanded/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx b/onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx
index 0fdbe9696aa..c29ab944987 100644
Binary files a/onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx and b/onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx differ
diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc
index 6ce399f664c..075055817c5 100644
--- a/onnx/defs/nn/defs.cc
+++ b/onnx/defs/nn/defs.cc
@@ -2852,8 +2852,8 @@ Rotary embeddings are defined using the following algorithm:
 # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
 if rotary_embedding_dim == 0:
-    # If rotary_embedding_dim not provided, perform full rotation by using head_size * 2
-    rotary_embedding_dim = cos_cache.shape[1] * 2
+    # If rotary_embedding_dim not provided, perform full rotation by using head_size
+    rotary_embedding_dim = head_size
 x_rotate = input[:, :, :, :rotary_embedding_dim]
 x_not_rotate = input[:, :, :, rotary_embedding_dim:]
 rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
@@ -2897,11 +2897,11 @@ ONNX_OPERATOR_SET_SCHEMA(
           AttributeProto::INT,
           OPTIONAL_VALUE)
       .Attr("num_heads",
-          "Number of attention heads. Default value is 0. Must use with `rotary_embedding_dim`. ",
+          "Number of attention heads. Must use with `rotary_embedding_dim`. ",
           AttributeProto::INT,
           OPTIONAL_VALUE)
       .Attr("rotary_embedding_dim",
-          "Rotary embedding dimension used to apply partial rotary embeddings. Default value is 0. ",
+          "Rotary embedding dimension used to apply partial rotary embeddings.",
           AttributeProto::INT,
           OPTIONAL_VALUE)
       .Input(0,
@@ -2934,6 +2934,16 @@ ONNX_OPERATOR_SET_SCHEMA(
       .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
       .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
       .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        auto input_shape = ctx.getInputType(0)->tensor_type().shape();
+        if (input_shape.dim_size() < 3) {
+          return; // Input tensor should have at least three dimensions.
+        }
+
+        auto* num_heads_attr = ctx.getAttribute("num_heads");
+        if ((input_shape.dim_size() == 3) && (num_heads_attr == nullptr)) {
+          fail_shape_inference("Input shape is 3D, num_heads attribute must be provided");
+        }
+
        propagateElemTypeFromInputToOutput(ctx, 0, 0);
        propagateShapeFromInputToOutput(ctx, 0, 0);
       })
@@ -2963,9 +2973,8 @@ ONNX_OPERATOR_SET_SCHEMA(
       // There are two cases for the rotary embedding dimension:
      // 1. Complete rotation: rotary embedding dimension defaults to head_size, rotary_embedding_dim = cos.shape[3] * 2 or head_size
      // 2. Partial rotation: rotary embedding dimension is provided, rotary_embedding_dim = rotary_embedding_dim
-      builder.Add("HeadSizeHalf = Shape (cos_cache)") // cos_cache.shape[1] or head_size // 2
+      builder.Add("HeadSize = Shape (XReshaped)") // head_size
           .Const1D("Two1D", (int64_t)2)
-          .Add("HeadSize = Mul(HeadSizeHalf, Two1D)") // cos.shape[1] * 2 or head_size
           .Const1D("RotaryEmbedDimParam", rotary_embedding_dim)
           .Const1D("Zero1D", (int64_t)0)
           .Add("RotaryDimCond = Greater(RotaryEmbedDimParam, Zero1D)")