@@ -189,7 +189,7 @@ def _define_contig_flag(x):
189
189
190
190
def _define_dim_flags (x , axis ):
191
191
"""
192
- Define useful flags for the main calculation in dpnp_matmul.
192
+ Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot .
193
193
x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
194
194
except for one of them), for instance, if x.shape = (1, 1, 1, 2),
195
195
then x_is_1D = True
@@ -220,7 +220,7 @@ def _define_dim_flags(x, axis):
220
220
return x_is_2D , x_is_1D , x_base_is_1D
221
221
222
222
223
- def _get_result_shape (x1 , x2 , out , func , np_flag ):
223
+ def _get_result_shape (x1 , x2 , out , _get_result_shape_fn , np_flag ):
224
224
"""
225
225
Three task are completed in this function:
226
226
- Get the shape of the result array.
@@ -239,15 +239,7 @@ def _get_result_shape(x1, x2, out, func, np_flag):
239
239
"The second input array does not have enough dimensions (has 0, but requires at least 1)"
240
240
)
241
241
242
- if func == "matmul" :
243
- x1 , x2 , result_shape = _get_result_shape_matmul (
244
- x1 , x2 , x1_ndim , x2_ndim
245
- )
246
- else : # func == "vecdot"
247
- assert func == "vecdot"
248
- x1 , x2 , result_shape = _get_result_shape_vecdot (
249
- x1 , x2 , x1_ndim , x2_ndim
250
- )
242
+ x1 , x2 , result_shape = _get_result_shape_fn (x1 , x2 , x1_ndim , x2_ndim )
251
243
252
244
if out is not None :
253
245
out_shape = out .shape
@@ -474,7 +466,7 @@ def _shape_error(shape1, shape2, func, err_msg):
474
466
elif func == "vecdot" :
475
467
signature = "(n?,),(n?,)->()"
476
468
else :
477
- # applicable when err_msg == 3
469
+ # applicable when err_msg == 2
478
470
assert func is None
479
471
480
472
if err_msg == 0 :
@@ -655,7 +647,7 @@ def dpnp_cross(a, b, cp):
655
647
return cp
656
648
657
649
658
- def dpnp_dot (a , b , / , out = None , * , conjugate = False ):
650
+ def dpnp_dot (a , b , / , out = None , * , casting = "same_kind" , conjugate = False ):
659
651
"""
660
652
Return the dot product of two arrays.
661
653
@@ -717,8 +709,7 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
717
709
if dot_dtype != res_dtype :
718
710
result = result .astype (res_dtype , copy = False )
719
711
720
- # numpy.dot does not allow casting even if it is safe
721
- return dpnp .get_result_array (result , out , casting = "no" )
712
+ return dpnp .get_result_array (result , out , casting = casting )
722
713
723
714
724
715
def dpnp_kron (a , b , a_ndim , b_ndim ):
@@ -773,8 +764,10 @@ def dpnp_matmul(
773
764
order = "F"
774
765
else :
775
766
order = "C"
776
-
777
- if order in "kK" :
767
+ elif order in "kK" :
768
+ # For order="K", we return order="C" to align with NumPy behavior
769
+ # It is different than logic used in dpnp_vecdot because NumPy
770
+ # behaves differently for matmul and vecdot
778
771
order = "C"
779
772
780
773
x1_ndim = x1 .ndim
@@ -806,7 +799,7 @@ def dpnp_matmul(
806
799
)
807
800
808
801
x1 , x2 , result_shape = _get_result_shape (
809
- x1 , x2 , out , "matmul" , NumPy_special_behavior
802
+ x1 , x2 , out , _get_result_shape_matmul , NumPy_special_behavior
810
803
)
811
804
812
805
# Determine the appropriate data types
@@ -1000,6 +993,9 @@ def dpnp_vecdot(
1000
993
_validate_out_array (out , exec_q )
1001
994
1002
995
if order in "aAkK" :
996
+ # This logic is also used for order="K" to align with NumPy behavior.
997
+ # It is different than logic used in dpnp_matmul because NumPy
998
+ # behaves differently for matmul and vecdot
1003
999
if x1 .flags .fnc and x2 .flags .fnc :
1004
1000
order = "F"
1005
1001
else :
@@ -1035,7 +1031,7 @@ def dpnp_vecdot(
1035
1031
)
1036
1032
1037
1033
x1 , x2 , result_shape = _get_result_shape (
1038
- x1 , x2 , out , "vecdot" , NumPy_special_behavior
1034
+ x1 , x2 , out , _get_result_shape_vecdot , NumPy_special_behavior
1039
1035
)
1040
1036
1041
1037
# Determine the appropriate data types
@@ -1047,21 +1043,7 @@ def dpnp_vecdot(
1047
1043
_ , x2_is_1D , _ = _define_dim_flags (x2 , axis = - 1 )
1048
1044
1049
1045
if x1 .size == 0 or x2 .size == 0 :
1050
- order = "C" if order in "kK" else order
1051
- result = _create_result_array (
1052
- x1 ,
1053
- x2 ,
1054
- out ,
1055
- shape = result_shape ,
1056
- dtype = res_dtype ,
1057
- usm_type = res_usm_type ,
1058
- sycl_queue = exec_q ,
1059
- order = order ,
1060
- )
1061
- if numpy .prod (result_shape ) == 0 :
1062
- return result
1063
- result .fill (0 )
1064
- return result
1046
+ call_flag = "trivial"
1065
1047
elif x1_is_1D and x2_is_1D :
1066
1048
call_flag = "dot"
1067
1049
# arrays are inehrently 1D, make them 1D
@@ -1072,7 +1054,20 @@ def dpnp_vecdot(
1072
1054
call_flag = "vecdot"
1073
1055
1074
1056
# dispatch to proper function call
1075
- if call_flag == "dot" :
1057
+ if call_flag == "trivial" :
1058
+ result = _create_result_array (
1059
+ x1 ,
1060
+ x2 ,
1061
+ out ,
1062
+ shape = result_shape ,
1063
+ dtype = res_dtype ,
1064
+ usm_type = res_usm_type ,
1065
+ sycl_queue = exec_q ,
1066
+ order = order ,
1067
+ )
1068
+ if numpy .prod (result_shape ) != 0 :
1069
+ result .fill (0 )
1070
+ elif call_flag == "dot" :
1076
1071
if out is not None and out .shape != ():
1077
1072
result = dpnp_dot (x1 , x2 , out = None , conjugate = True )
1078
1073
else :
0 commit comments