Skip to content

Commit

Permalink
Implement asynchronous fill method using dpctl kernels (#2055)
Browse files Browse the repository at this point in the history
* Enhance `dpnp_array.fill` method

Leverages dpctl's strided fill and memset for setting contiguous memory to 0

* Fix missing disclaimer in dpnp_arraycreation.py

* Import `dpnp_array` directly

* Skip `test_fill_with_numpy_scalar_ndarray`

New fill implementation does not permit NumPy array values, consistent with fill_diagonal

* Add dependencies to zeros and full kernels in `dpnp_fill`

* Remove redundant validation of first `dpnp_fill` argument

* Improve `dpnp_fill` array/scalar path logic

* Disallow inputs to `dpnp_fill` on separate queues

* Adjust skip message for `test_fill_with_numpy_scalar_ndarray`

* Tweak error messages in `dpnp_fill`

* Add tests for new `fill` method

* Update docstring for `fill` method

* Fix pre-commit in cupy fill tests

* Change `asarray` to `astype` in `dpnp_fill`

NumPy arrays are no longer permitted and queue coercion does not occur in the `fill` method, so `astype` is sufficient

* Expand TEST_SCOPE to include `test_fill.py`

* Remove redundant check from `dpnp_fill`

* Use `_cast_fill_val` private function from `dpctl.tensor._ctors`

* Add tests per PR review by @antonwolfy

* Improve validation of `val` for `fill` method

* Permit NumPy bools as `dpnp_fill` scalar fill values

* Use `dpnp.bool` in `dpnp_fill` and make `isinstance` check more efficient

* Replace branching for `fill` scalar type with `_cast_fill_value`

* Add additional tests for `fill`

`test_fill_non_scalar` now checks for strings and `test_fill_bool` added to verify bools are properly cast to 1

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
ndgrigorian and antonwolfy authored Oct 25, 2024
1 parent e236ad9 commit 29239b6
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ env:
test_copy.py
test_counting.py
test_fft.py
test_fill.py
test_flat.py
test_histogram.py
test_indexing.py
Expand Down
26 changes: 26 additions & 0 deletions dpnp/dpnp_algo/dpnp_arraycreation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
# -*- coding: utf-8 -*-
# *****************************************************************************
# Copyright (c) 2016-2024, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# - Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************

import math
import operator

Expand Down
78 changes: 78 additions & 0 deletions dpnp/dpnp_algo/dpnp_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# *****************************************************************************
# Copyright (c) 2016-2024, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# - Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************

from numbers import Number

import dpctl.tensor as dpt
import dpctl.utils as dpu
from dpctl.tensor._ctors import _cast_fill_val
from dpctl.tensor._tensor_impl import (
_copy_usm_ndarray_into_usm_ndarray,
_full_usm_ndarray,
_zeros_usm_ndarray,
)

import dpnp


def _negative_zero_scalar(val, kind):
    """
    Return `val` as a Python scalar if it is a zero carrying a negative sign.

    Parameters
    ----------
    val : scalar
        Fill value that already compared equal to zero.
    kind : str
        ``dtype.kind`` of the destination array, ``"f"`` or ``"c"``.

    Returns
    -------
    out : {float, complex, None}
        A Python ``float`` (for ``"f"``) or ``complex`` (for ``"c"``) that
        preserves the sign bits of `val` when any relevant component is
        ``-0.0``; ``None`` when `val` is an ordinary zero or its components
        cannot be inspected.

    """
    # function-scope import keeps the module's import block untouched
    import math

    try:
        real, imag = float(val.real), float(val.imag)
    except (AttributeError, TypeError, ValueError):
        # cannot inspect the sign: caller keeps its default behavior
        return None

    if kind == "f":
        # only the real component is stored in a real floating array
        return real if math.copysign(1.0, real) < 0 else None
    if math.copysign(1.0, real) < 0 or math.copysign(1.0, imag) < 0:
        return complex(real, imag)
    return None


def dpnp_fill(arr, val):
    """
    Fill every element of `arr` with `val`, in place.

    The work is submitted to the SYCL queue of `arr` and ordered after the
    events already tracked by the queue's ``SequentialOrderManager``; the
    resulting events are registered with that manager and this function
    does not wait on them itself.

    Parameters
    ----------
    arr : {dpnp.ndarray, usm_ndarray}
        Array to be filled.
    val : {dpnp.ndarray, usm_ndarray, scalar}
        Fill value. An array value must be 0-dimensional and must be
        allocated on a queue compatible with the queue of `arr`.

    Raises
    ------
    ValueError
        If `val` is an array whose shape is not ``()``.
    dpctl.utils.ExecutionPlacementError
        If `val` is an array and no common execution queue can be
        determined for `arr` and `val`.
    TypeError
        If `val` is neither a supported array nor a numeric/boolean scalar.

    """
    arr = dpnp.get_usm_ndarray(arr)
    exec_q = arr.sycl_queue

    # if val is an array, process it
    if dpnp.is_supported_array_type(val):
        val = dpnp.get_usm_ndarray(val)
        if val.shape != ():
            raise ValueError("`val` must be a scalar or 0D-array")
        if dpu.get_execution_queue((exec_q, val.sycl_queue)) is None:
            raise dpu.ExecutionPlacementError(
                "Input arrays have incompatible queues."
            )
        # cast to the destination dtype, then view the 0-d value with the
        # destination shape so that a regular copy kernel performs the fill
        a_val = dpt.astype(val, arr.dtype)
        a_val = dpt.broadcast_to(a_val, arr.shape)
        _manager = dpu.SequentialOrderManager[exec_q]
        dep_evs = _manager.submitted_events
        h_ev, c_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=a_val, dst=arr, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(h_ev, c_ev)
        return
    elif not isinstance(val, (Number, dpnp.bool)):
        raise TypeError(
            f"array cannot be filled with `val` of type {type(val)}"
        )
    val = _cast_fill_val(val, arr.dtype)

    _manager = dpu.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    # can leverage efficient memset when val is 0
    use_memset = arr.flags["FORC"] and val == 0
    if use_memset and arr.dtype.kind in "fc":
        # `-0.0 == 0` is True, but zeroing the memory would store `+0.0`
        # and drop the sign; send negative zeros through the fill kernel
        # so contiguous and strided arrays end up with the same values
        signed_zero = _negative_zero_scalar(val, arr.dtype.kind)
        if signed_zero is not None:
            val = signed_zero
            use_memset = False

    if use_memset:
        h_ev, zeros_ev = _zeros_usm_ndarray(arr, exec_q, depends=dep_evs)
        _manager.add_event_pair(h_ev, zeros_ev)
    else:
        h_ev, fill_ev = _full_usm_ndarray(val, arr, exec_q, depends=dep_evs)
        _manager.add_event_pair(h_ev, fill_ev)
11 changes: 8 additions & 3 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,13 +928,16 @@ def fill(self, value):
"""
Fill the array with a scalar value.
For full documentation refer to :obj:`numpy.ndarray.fill`.
Parameters
----------
value : scalar
value : {dpnp.ndarray, usm_ndarray, scalar}
All elements of `a` will be assigned this value.
Examples
--------
>>> import dpnp as np
>>> a = np.array([1, 2])
>>> a.fill(0)
>>> a
Expand All @@ -946,8 +949,10 @@ def fill(self, value):
"""

for i in range(self.size):
self.flat[i] = value
# lazy import avoids circular imports
from .dpnp_algo.dpnp_fill import dpnp_fill

dpnp_fill(self, value)

@property
def flags(self):
Expand Down
87 changes: 87 additions & 0 deletions tests/test_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import dpctl
import numpy as np
import pytest
from dpctl.utils import ExecutionPlacementError
from numpy.testing import assert_array_equal

import dpnp as dnp


@pytest.mark.parametrize(
    "val, error",
    [
        pytest.param(dnp.ones(2, dtype="i4"), ValueError, id="array"),
        pytest.param({}, TypeError, id="dictionary"),
        pytest.param("0", TypeError, id="string"),
    ],
)
def test_fill_non_scalar(val, error):
    """Fill values that are not scalars or 0-d arrays raise `error`."""
    target = dnp.ones(5, dtype="i4")

    with pytest.raises(error):
        target.fill(val)


def test_fill_compute_follows_data():
    """An array fill value created on a different queue is rejected."""
    arr_queue = dpctl.SyclQueue()
    val_queue = dpctl.SyclQueue()

    target = dnp.ones(5, dtype="i4", sycl_queue=arr_queue)
    fill_value = dnp.ones((), dtype=target.dtype, sycl_queue=val_queue)

    with pytest.raises(ExecutionPlacementError):
        target.fill(fill_value)


def test_fill_strided_array():
    """Filling a reversed, step-2 view only touches the viewed elements."""
    base = dnp.zeros(100, dtype="i4")
    view = base[::-2]

    # the view covers the odd indices of `base`, giving a 0, 1, 0, 1, ... pattern
    pattern = dnp.asarray([0, 1], dtype=base.dtype)
    expected = dnp.tile(pattern, 50)

    view.fill(1)

    assert_array_equal(view, 1)
    assert_array_equal(base, expected)


@pytest.mark.parametrize("order", ["C", "F"])
def test_fill_strided_2d_array(order):
    """Filling a strided 2-D view updates only the selected elements."""
    # equivalent to indexing with [::-2, ::2]
    selection = (slice(None, None, -2), slice(None, None, 2))

    base = dnp.zeros((10, 10), dtype="i4", order=order)
    view = base[selection]

    expected = dnp.copy(base)
    expected[selection] = 1

    view.fill(1)

    assert_array_equal(view, 1)
    assert_array_equal(base, expected)


@pytest.mark.parametrize("order", ["C", "F"])
def test_fill_memset(order):
    """Filling a freshly allocated C- or F-ordered array with 0 zeroes it."""
    target = dnp.ones((10, 10), dtype="i4", order=order)

    target.fill(0)

    assert_array_equal(target, 0)


def test_fill_float_complex_to_int():
    """Complex and float fill values are accepted by an integer array."""
    target = dnp.ones((10, 10), dtype="i4")

    # same order as before: complex value first, then float value
    for fill_value, expected in ((complex(2, 0), 2), (float(3), 3)):
        target.fill(fill_value)
        assert_array_equal(target, expected)


def test_fill_complex_to_float():
    """A complex fill value is accepted by a floating-point array."""
    target = dnp.ones((10, 10), dtype="f4")
    fill_value = complex(2, 0)

    target.fill(fill_value)

    assert_array_equal(target, 2)


def test_fill_bool():
    """Filling an integer array with ``True`` stores 1 in every element."""
    target = dnp.full(5, fill_value=7, dtype="i4")

    target.fill(True)

    assert_array_equal(target, 1)
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,7 @@ def test_fill(self, xp, dtype):
a.fill(1)
return a

@testing.with_requires("numpy>=1.24.0")
@testing.for_all_dtypes_combination(("dtype1", "dtype2"))
@testing.numpy_cupy_array_equal(accept_error=ComplexWarning)
@pytest.mark.skip("Numpy allows Numpy scalar arrays as fill value")
def test_fill_with_numpy_scalar_ndarray(self, xp, dtype1, dtype2):
a = testing.shaped_arange((2, 3, 4), xp, dtype1)
a.fill(numpy.ones((), dtype=dtype2))
Expand Down

0 comments on commit 29239b6

Please sign in to comment.