Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Jan 29, 2024
1 parent d3ba4f3 commit 815ec86
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 102 deletions.
17 changes: 8 additions & 9 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,14 @@ def dot(a, b, out=None):
Returns the dot product of `a` and `b`.
If `out` is given, then it is returned.
Limitations
-----------
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
Keyword argument ``kwargs`` is currently unsupported.
Otherwise the functions will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
See Also
--------
:obj:`dpnp.ndarray.dot` : Equivalent method.
:obj:`dpnp.tensordot` : Sum products over arbitrary axes.
:obj:`dpnp.vdot` : Complex-conjugating dot product.
:obj:`dpnp.einsum` : Einstein summation convention.
:obj:`dpnp.matmul` : Matrix product of two arrays.
:obj:`dpnp.linalg.multi_dot` : Chained dot product.
Examples
--------
Expand Down Expand Up @@ -135,15 +130,19 @@ def dot(a, b, out=None):
raise ValueError("Only C-contiguous array is acceptable.")

if dpnp.isscalar(a) or dpnp.isscalar(b):
# TODO: investigate usage of axpy (axpy_batch) or scal
# functions from BLAS here instead of dpnp.multiply
return dpnp.multiply(a, b, out=out)
elif a.ndim == 0 or b.ndim == 0:
# TODO: investigate usage of axpy (axpy_batch) or scal
# functions from BLAS here instead of dpnp.multiply
return dpnp.multiply(a, b, out=out)
elif a.ndim == 1 and b.ndim == 1:
return dpnp_dot(a, b, out=out)
elif a.ndim == 2 and b.ndim == 2:
# NumPy does not allow casting even if it is safe
return dpnp.matmul(a, b, out=out, casting="no")
elif a.ndim > 1 and b.ndim == 1:
elif a.ndim == 1 or b.ndim == 1:
# NumPy does not allow casting even if it is safe
return dpnp.matmul(a, b, out=out, casting="no")
else:
Expand Down
147 changes: 76 additions & 71 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,61 +34,34 @@
__all__ = ["dpnp_dot", "dpnp_matmul"]


def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
"""
_op_res_dtype(*arrays, dtype, casting, sycl_queue)
Determines the output array data type and an intermediate data type
used in performing calculations related to a specific math function.
If dtype is ``None``, the output array data type of the operation is
determined based on the Promotion Type Rule and device capabilities.
Otherwise, `dtype` is used as output array dtype, if input arrays
can cast to it according to the casting rule determined. If casting
cannot be done, a ``TypeError`` is raised.
The intermediate data type is the data type used for performing the math
function calculations. If output array dtype is a floating-point data type,
it is also used for the intermediate data type. If output array dtype is an
integral data type, the default floating point data type of the device where
input arrays are allocated on are used for intermediate data type.
Parameters
----------
arrays : {dpnp.ndarray, usm_ndarray}
Input arrays.
dtype : dtype
If not ``None``, data type of the output array.
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
Controls what kind of data casting may occur.
sycl_queue : {SyclQueue}
A SYCL queue to use for determining default floating point datat type.
Creating a copy of input array if needed.
Returns
-------
op_dtype, res_dtype :
`op_dtype` is the data type used in performing math function calculations.
The input arrays of the math function are cast to `op_dtype` and then
the calculations are performed.
`res_dtype` is the output data type. When the result is obtained, it is cast
to `res_dtype`.
If `contig_copy` is ``True``, a C-contiguous copy of input array is returned.
In this case, the copy array has the input array data type unless `dtype` is
determined.
If `contig_copy` is ``False`` and input array data type is different than `dtype`,
a C-contiguous copy of input array with specified `dtype` is returned.
"""

res_dtype = dpnp.result_type(*arrays)
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)

if dtype is not None:
if dpnp.can_cast(res_dtype, dtype, casting=casting):
res_dtype = dtype
else:
raise TypeError(
f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
)

op_dtype = (
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
)
if contig_copy:
copy = contig_copy
else:
copy = x.dtype != dtype if dtype is not None else False

return op_dtype, res_dtype
if copy:
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=dpnp.get_usm_ndarray(x),
dst=x_copy.get_array(),
sycl_queue=x.sycl_queue,
)
dep_events.append(copy_ev)
host_events.append(ht_copy_ev)
return x_copy
return x


def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
Expand Down Expand Up @@ -153,34 +126,61 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
return ht_blas_ev, ht_tasks_list, res


def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
"""
Creating a copy of input array if needed.
_op_res_dtype(*arrays, dtype, casting, sycl_queue)
If `contig_copy` is ``True``, a C-contiguous copy of input array is returned.
In this case, the copy array has the input array data type unless `dtype` is
determined.
If `contig_copy` is ``False`` and input array data type is different than `dtype`,
a C-contiguous copy of input array with specified `dtype` is returned.
Determines the output array data type and an intermediate data type
used in performing calculations related to a specific math function.
If dtype is ``None``, the output array data type of the operation is
determined based on the Promotion Type Rule and device capabilities.
Otherwise, `dtype` is used as output array dtype, if input arrays
can cast to it according to the casting rule determined. If casting
cannot be done, a ``TypeError`` is raised.
The intermediate data type is the data type used for performing the math
function calculations. If output array dtype is a floating-point data type,
it is also used for the intermediate data type. If output array dtype is an
integral data type, the default floating point data type of the device where
input arrays are allocated on are used for intermediate data type.
Parameters
----------
arrays : {dpnp.ndarray, usm_ndarray}
Input arrays.
dtype : dtype
If not ``None``, data type of the output array.
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
Controls what kind of data casting may occur.
sycl_queue : {SyclQueue}
A SYCL queue to use for determining default floating point datat type.
Returns
-------
op_dtype, res_dtype :
`op_dtype` is the data type used in performing math function calculations.
The input arrays of the math function are cast to `op_dtype` and then
the calculations are performed.
`res_dtype` is the output data type. When the result is obtained, it is cast
to `res_dtype`.
"""

if contig_copy:
copy = contig_copy
else:
copy = x.dtype != dtype if dtype is not None else False
res_dtype = dpnp.result_type(*arrays)
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)

if copy:
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=dpnp.get_usm_ndarray(x),
dst=x_copy.get_array(),
sycl_queue=x.sycl_queue,
)
dep_events.append(copy_ev)
host_events.append(ht_copy_ev)
return x_copy
return x
if dtype is not None:
if dpnp.can_cast(res_dtype, dtype, casting=casting):
res_dtype = dtype
else:
raise TypeError(
f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
)

op_dtype = (
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
)

return op_dtype, res_dtype


def dpnp_dot(
Expand Down Expand Up @@ -394,6 +394,11 @@ def dpnp_matmul(
dtype=gemm_dtype,
)

# TODO: investigate usage of gemv (gemv_batch) function
# from BLAS when one of the inputs is a vector to
# gain performance.
# TODO: investigate usage of syrk function from BLAS in
# case of a.T @ a and a @ a.T to gain performance.
if x1_is_2D and x2_is_2D:
ht_blas_ev, _ = bi._gemm(
exec_q,
Expand Down
25 changes: 3 additions & 22 deletions tests/third_party/cupy/linalg_tests/test_eigenvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ def _get_hermitian(xp, a, UPLO):
return xp.tril(a) + xp.tril(a, k=-1).swapaxes(-2, -1).conj()


# TODO:
# remove once dpnp.dot and dpnp.matmul support complex types
def _wrap_as_numpy_array(xp, a):
return a.asnumpy() if xp is cupy else a


@testing.parameterize(
*testing.product(
{
Expand Down Expand Up @@ -57,20 +51,12 @@ def test_eigh(self, xp, dtype):
else:
tol = 1e-5

# TODO: remove _wrap_as_numpy_array() once @ support complex types
testing.assert_allclose(
_wrap_as_numpy_array(xp, A) @ _wrap_as_numpy_array(xp, v),
_wrap_as_numpy_array(xp, v)
@ numpy.diag(_wrap_as_numpy_array(xp, w)),
atol=tol,
rtol=tol,
)
testing.assert_allclose(A @ v, v @ xp.diag(w), atol=tol, rtol=tol)

# Check if v @ vt is an identity matrix
testing.assert_allclose(
_wrap_as_numpy_array(xp, v)
@ _wrap_as_numpy_array(xp, v).swapaxes(-2, -1).conj(),
numpy.identity(_wrap_as_numpy_array(xp, A).shape[-1], _dtype),
v @ v.swapaxes(-2, -1).conj(),
xp.identity(A.shape[-1], _dtype),
atol=tol,
rtol=tol,
)
Expand Down Expand Up @@ -121,11 +107,6 @@ def test_eigh_complex_batched(self, xp, dtype):
# them through the eigen equation A*v=w*v.
A = _get_hermitian(xp, a, self.UPLO)

# TODO: remove _wrap_as_numpy_array() once dpnp.dot() support complex types
A = _wrap_as_numpy_array(xp, A)
v = _wrap_as_numpy_array(xp, v)
w = _wrap_as_numpy_array(xp, w)

for i in range(a.shape[0]):
testing.assert_allclose(
A[i].dot(v[i]), w[i] * v[i], rtol=1e-5, atol=1e-5
Expand Down
55 changes: 55 additions & 0 deletions tests/third_party/cupy/math_tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,61 @@ def test_cupy_matmul(self, xp, dtype1):
return xp.matmul(x1, x2)


@testing.parameterize(
*testing.product(
{
"shape_pair": [
# dot test
((2, 3), (3, 4), (2, 4)),
# ((0,), (0,), (0,)),
# matmul test
((5, 3, 2), (5, 2, 4), (5, 3, 4)),
((0, 3, 2), (0, 2, 4), (0, 3, 4)),
],
}
)
)
class TestMatmulOut(unittest.TestCase):
@testing.for_all_dtypes(name="dtype1")
@testing.for_all_dtypes(name="dtype2")
@testing.numpy_cupy_allclose(
rtol=1e-3, atol=1e-3, accept_error=TypeError # required for uint8
)
def test_cupy_matmul_noncontiguous(self, xp, dtype1, dtype2):
x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
out = xp.zeros(self.shape_pair[2], dtype=dtype1)[::-1]
ret = xp.matmul(x1, x2, out=out)
assert ret is out
return ret

@testing.for_all_dtypes(name="dtype1")
@testing.for_all_dtypes(name="dtype2")
@testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8
def test_cupy_matmul_out_cast(self, xp, dtype1, dtype2):
x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
out = xp.zeros(self.shape_pair[2], dtype=bool)
ret = xp.matmul(x1, x2, out=out, casting="unsafe")
assert ret is out
return ret


class TestMatmulOutOverlap:
@pytest.mark.parametrize(
"shape",
[
(900, 900),
(2, 600, 600),
],
)
@testing.for_dtypes([numpy.int32, numpy.float64])
@testing.numpy_cupy_allclose(rtol=1e-5, atol=1e-5)
def test_overlap_both(self, xp, dtype, shape):
a = xp.ones(shape, dtype=dtype)
return xp.matmul(a, a, out=a)


@testing.parameterize(
*testing.product(
{
Expand Down

0 comments on commit 815ec86

Please sign in to comment.