From 815ec86d2c4e9c4c919e0671451a22d309accbe3 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 26 Jan 2024 15:43:36 -0600 Subject: [PATCH] address comments --- dpnp/dpnp_iface_linearalgebra.py | 17 +- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 147 +++++++++--------- .../cupy/linalg_tests/test_eigenvalue.py | 25 +-- .../cupy/math_tests/test_matmul.py | 55 +++++++ 4 files changed, 142 insertions(+), 102 deletions(-) diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py index 0213a23e265..ae7b7bf58c5 100644 --- a/dpnp/dpnp_iface_linearalgebra.py +++ b/dpnp/dpnp_iface_linearalgebra.py @@ -82,19 +82,14 @@ def dot(a, b, out=None): Returns the dot product of `a` and `b`. If `out` is given, then it is returned. - Limitations - ----------- - Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray` - or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time. - Keyword argument ``kwargs`` is currently unsupported. - Otherwise the functions will be executed sequentially on CPU. - Input array data types are limited by supported DPNP :ref:`Data types`. - See Also -------- :obj:`dpnp.ndarray.dot` : Equivalent method. :obj:`dpnp.tensordot` : Sum products over arbitrary axes. :obj:`dpnp.vdot` : Complex-conjugating dot product. + :obj:`dpnp.einsum` : Einstein summation convention. + :obj:`dpnp.matmul` : Matrix product of two arrays. + :obj:`dpnp.linalg.multi_dot` : Chained dot product. 
Examples -------- @@ -135,15 +130,19 @@ def dot(a, b, out=None): raise ValueError("Only C-contiguous array is acceptable.") if dpnp.isscalar(a) or dpnp.isscalar(b): + # TODO: investigate usage of axpy (axpy_batch) or scal + # functions from BLAS here instead of dpnp.multiply return dpnp.multiply(a, b, out=out) elif a.ndim == 0 or b.ndim == 0: + # TODO: investigate usage of axpy (axpy_batch) or scal + # functions from BLAS here instead of dpnp.multiply return dpnp.multiply(a, b, out=out) elif a.ndim == 1 and b.ndim == 1: return dpnp_dot(a, b, out=out) elif a.ndim == 2 and b.ndim == 2: # NumPy does not allow casting even if it is safe return dpnp.matmul(a, b, out=out, casting="no") - elif a.ndim > 1 and b.ndim == 1: + elif a.ndim == 1 or b.ndim == 1: # NumPy does not allow casting even if it is safe return dpnp.matmul(a, b, out=out, casting="no") else: diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 5dafff9481f..4efc9527e54 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -34,61 +34,34 @@ __all__ = ["dpnp_dot", "dpnp_matmul"] -def _op_res_dtype(*arrays, dtype, casting, sycl_queue): +def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): """ - _op_res_dtype(*arrays, dtype, casting, sycl_queue) - - Determines the output array data type and an intermediate data type - used in performing calculations related to a specific math function. - If dtype is ``None``, the output array data type of the operation is - determined based on the Promotion Type Rule and device capabilities. - Otherwise, `dtype` is used as output array dtype, if input arrays - can cast to it according to the casting rule determined. If casting - cannot be done, a ``TypeError`` is raised. - The intermediate data type is the data type used for performing the math - function calculations. 
If output array dtype is a floating-point data type, - it is also used for the intermediate data type. If output array dtype is an - integral data type, the default floating point data type of the device where - input arrays are allocated on are used for intermediate data type. - - Parameters - ---------- - arrays : {dpnp.ndarray, usm_ndarray} - Input arrays. - dtype : dtype - If not ``None``, data type of the output array. - casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional - Controls what kind of data casting may occur. - sycl_queue : {SyclQueue} - A SYCL queue to use for determining default floating point datat type. + Create a copy of the input array if needed. - Returns - ------- - op_dtype, res_dtype : - `op_dtype` is the data type used in performing math function calculations. - The input arrays of the math function are cast to `op_dtype` and then - the calculations are performed. - `res_dtype` is the output data type. When the result is obtained, it is cast - to `res_dtype`. + If `contig_copy` is ``True``, a C-contiguous copy of the input array is returned. + In this case, the copy has the input array data type unless `dtype` is + specified. + If `contig_copy` is ``False`` and the input array data type differs from `dtype`, + a C-contiguous copy of the input array with the specified `dtype` is returned.
""" - res_dtype = dpnp.result_type(*arrays) - default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue) - - if dtype is not None: - if dpnp.can_cast(res_dtype, dtype, casting=casting): - res_dtype = dtype - else: - raise TypeError( - f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}" - ) - - op_dtype = ( - res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype - ) + if contig_copy: + copy = contig_copy + else: + copy = x.dtype != dtype if dtype is not None else False - return op_dtype, res_dtype + if copy: + x_copy = dpnp.empty_like(x, dtype=dtype, order="C") + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=dpnp.get_usm_ndarray(x), + dst=x_copy.get_array(), + sycl_queue=x.sycl_queue, + ) + dep_events.append(copy_ev) + host_events.append(ht_copy_ev) + return x_copy + return x def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list): @@ -153,34 +126,61 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list): return ht_blas_ev, ht_tasks_list, res -def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): +def _op_res_dtype(*arrays, dtype, casting, sycl_queue): """ - Creating a copy of input array if needed. + _op_res_dtype(*arrays, dtype, casting, sycl_queue) - If `contig_copy` is ``True``, a C-contiguous copy of input array is returned. - In this case, the copy array has the input array data type unless `dtype` is - determined. - If `contig_copy` is ``False`` and input array data type is different than `dtype`, - a C-contiguous copy of input array with specified `dtype` is returned. + Determines the output array data type and an intermediate data type + used in performing calculations related to a specific math function. + If dtype is ``None``, the output array data type of the operation is + determined based on the Promotion Type Rule and device capabilities. 
+ Otherwise, `dtype` is used as output array dtype, if input arrays + can cast to it according to the casting rule determined. If casting + cannot be done, a ``TypeError`` is raised. + The intermediate data type is the data type used for performing the math + function calculations. If output array dtype is a floating-point data type, + it is also used for the intermediate data type. If output array dtype is an + integral data type, the default floating point data type of the device on which + the input arrays are allocated is used for the intermediate data type. + + Parameters + ---------- + arrays : {dpnp.ndarray, usm_ndarray} + Input arrays. + dtype : dtype + If not ``None``, data type of the output array. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. + sycl_queue : {SyclQueue} + A SYCL queue to use for determining default floating point data type. + + Returns + ------- + op_dtype, res_dtype : + `op_dtype` is the data type used in performing math function calculations. + The input arrays of the math function are cast to `op_dtype` and then + the calculations are performed. + `res_dtype` is the output data type. When the result is obtained, it is cast + to `res_dtype`.
""" - if contig_copy: - copy = contig_copy - else: - copy = x.dtype != dtype if dtype is not None else False + res_dtype = dpnp.result_type(*arrays) + default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue) - if copy: - x_copy = dpnp.empty_like(x, dtype=dtype, order="C") - ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=dpnp.get_usm_ndarray(x), - dst=x_copy.get_array(), - sycl_queue=x.sycl_queue, - ) - dep_events.append(copy_ev) - host_events.append(ht_copy_ev) - return x_copy - return x + if dtype is not None: + if dpnp.can_cast(res_dtype, dtype, casting=casting): + res_dtype = dtype + else: + raise TypeError( + f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}" + ) + + op_dtype = ( + res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype + ) + + return op_dtype, res_dtype def dpnp_dot( @@ -394,6 +394,11 @@ def dpnp_matmul( dtype=gemm_dtype, ) + # TODO: investigate usage of gemv (gemv_batch) function + # from BLAS when one of the inputs is a vector to + # gain performance. + # TODO: investigate usage of syrk function from BLAS in + # case of a.T @ a and a @ a.T to gain performance. 
if x1_is_2D and x2_is_2D: ht_blas_ev, _ = bi._gemm( exec_q, diff --git a/tests/third_party/cupy/linalg_tests/test_eigenvalue.py b/tests/third_party/cupy/linalg_tests/test_eigenvalue.py index 99dcfb2127c..b620bd39e98 100644 --- a/tests/third_party/cupy/linalg_tests/test_eigenvalue.py +++ b/tests/third_party/cupy/linalg_tests/test_eigenvalue.py @@ -15,12 +15,6 @@ def _get_hermitian(xp, a, UPLO): return xp.tril(a) + xp.tril(a, k=-1).swapaxes(-2, -1).conj() -# TODO: -# remove once dpnp.dot and dpnp.matmul support complex types -def _wrap_as_numpy_array(xp, a): - return a.asnumpy() if xp is cupy else a - - @testing.parameterize( *testing.product( { @@ -57,20 +51,12 @@ def test_eigh(self, xp, dtype): else: tol = 1e-5 - # TODO: remove _wrap_as_numpy_array() once @ support complex types - testing.assert_allclose( - _wrap_as_numpy_array(xp, A) @ _wrap_as_numpy_array(xp, v), - _wrap_as_numpy_array(xp, v) - @ numpy.diag(_wrap_as_numpy_array(xp, w)), - atol=tol, - rtol=tol, - ) + testing.assert_allclose(A @ v, v @ xp.diag(w), atol=tol, rtol=tol) # Check if v @ vt is an identity matrix testing.assert_allclose( - _wrap_as_numpy_array(xp, v) - @ _wrap_as_numpy_array(xp, v).swapaxes(-2, -1).conj(), - numpy.identity(_wrap_as_numpy_array(xp, A).shape[-1], _dtype), + v @ v.swapaxes(-2, -1).conj(), + xp.identity(A.shape[-1], _dtype), atol=tol, rtol=tol, ) @@ -121,11 +107,6 @@ def test_eigh_complex_batched(self, xp, dtype): # them through the eigen equation A*v=w*v. 
A = _get_hermitian(xp, a, self.UPLO) - # TODO: remove _wrap_as_numpy_array() once dpnp.dot() support complex types - A = _wrap_as_numpy_array(xp, A) - v = _wrap_as_numpy_array(xp, v) - w = _wrap_as_numpy_array(xp, w) - for i in range(a.shape[0]): testing.assert_allclose( A[i].dot(v[i]), w[i] * v[i], rtol=1e-5, atol=1e-5 diff --git a/tests/third_party/cupy/math_tests/test_matmul.py b/tests/third_party/cupy/math_tests/test_matmul.py index d21ec7a2d68..887ed9ae1b9 100644 --- a/tests/third_party/cupy/math_tests/test_matmul.py +++ b/tests/third_party/cupy/math_tests/test_matmul.py @@ -73,6 +73,61 @@ def test_cupy_matmul(self, xp, dtype1): return xp.matmul(x1, x2) +@testing.parameterize( + *testing.product( + { + "shape_pair": [ + # dot test + ((2, 3), (3, 4), (2, 4)), + # ((0,), (0,), (0,)), + # matmul test + ((5, 3, 2), (5, 2, 4), (5, 3, 4)), + ((0, 3, 2), (0, 2, 4), (0, 3, 4)), + ], + } + ) +) +class TestMatmulOut(unittest.TestCase): + @testing.for_all_dtypes(name="dtype1") + @testing.for_all_dtypes(name="dtype2") + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, accept_error=TypeError # required for uint8 + ) + def test_cupy_matmul_noncontiguous(self, xp, dtype1, dtype2): + x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1) + x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2) + out = xp.zeros(self.shape_pair[2], dtype=dtype1)[::-1] + ret = xp.matmul(x1, x2, out=out) + assert ret is out + return ret + + @testing.for_all_dtypes(name="dtype1") + @testing.for_all_dtypes(name="dtype2") + @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 + def test_cupy_matmul_out_cast(self, xp, dtype1, dtype2): + x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1) + x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2) + out = xp.zeros(self.shape_pair[2], dtype=bool) + ret = xp.matmul(x1, x2, out=out, casting="unsafe") + assert ret is out + return ret + + +class TestMatmulOutOverlap: + @pytest.mark.parametrize( + "shape", + [ + (900, 
900), + (2, 600, 600), + ], + ) + @testing.for_dtypes([numpy.int32, numpy.float64]) + @testing.numpy_cupy_allclose(rtol=1e-5, atol=1e-5) + def test_overlap_both(self, xp, dtype, shape): + a = xp.ones(shape, dtype=dtype) + return xp.matmul(a, a, out=a) + + @testing.parameterize( *testing.product( {