@@ -170,8 +170,8 @@ def quaternion_weights(
170170      quaternions in its last dimension. 
171171    quaternion2: A tensor of shape `[A1, ... , An, 4]` storing normalized 
172172      quaternions in its last dimension. 
173-     percent: A `float` or a  tensor with a  shape broadcastable to the shape `[A1,  
174-       ... , An]` . 
173+     percent: A `float` or tensor with shape broadcastable to the shape of input  
174+       vectors . 
175175    eps: A `float` used to make operations safe. When left as None, the function 
176176      automatically picks the best epsilon based on the dtype and the operation. 
177177    name: A name for this op. Defaults to "quaternion_weights". 
@@ -198,7 +198,7 @@ def quaternion_weights(
198198        tensor = quaternion2 , tensor_name = "quaternion2" , has_dim_equals = (- 1 , 4 ))
199199    shape .compare_batch_dimensions (
200200        tensors = (quaternion1 , quaternion2 , percent ),
201-         last_axes = ( - 2 ,  - 2 ,  - 1 ) ,
201+         last_axes = - 1 ,
202202        broadcast_compatible = True ,
203203        tensor_names = ("quaternion1" , "quaternion2" , "percent" ))
204204    quaternion1  =  asserts .assert_normalized (quaternion1 )
@@ -266,7 +266,7 @@ def vector_weights(vector1: type_alias.TensorLike,
266266        tensor_names = ("vector1" , "vector2" ))
267267    shape .compare_batch_dimensions (
268268        tensors = (vector1 , vector2 , percent ),
269-         last_axes = ( - 2 ,  - 2 ,  - 1 ) ,
269+         last_axes = - 1 ,
270270        broadcast_compatible = True ,
271271        tensor_names = ("vector1" , "vector2" , "percent" ))
272272    normalized1  =  tf .nn .l2_normalize (vector1 , axis = - 1 )
0 commit comments