Skip to content

Commit ca0ef9b

Browse files
Merge pull request #816 from IntelPython/add-zeros
Adding dpctl.tensor.zeros
1 parent 7528ce8 commit ca0ef9b

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"""
2323

2424
from dpctl.tensor._copy_utils import asnumpy, astype, copy, from_numpy, to_numpy
25-
from dpctl.tensor._ctors import arange, asarray, empty
25+
from dpctl.tensor._ctors import arange, asarray, empty, zeros
2626
from dpctl.tensor._device import Device
2727
from dpctl.tensor._dlpack import from_dlpack
2828
from dpctl.tensor._manipulation_functions import (
@@ -45,6 +45,7 @@
4545
"astype",
4646
"copy",
4747
"empty",
48+
"zeros",
4849
"flip",
4950
"reshape",
5051
"roll",

dpctl/tensor/_ctors.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,3 +512,51 @@ def arange(
512512
hev, _ = ti._linspace_step(start, _step, res, sycl_queue)
513513
hev.wait()
514514
return res
515+
516+
517+
def zeros(
518+
sh, dtype="f8", order="C", device=None, usm_type="device", sycl_queue=None
519+
):
520+
"""
521+
Creates `usm_ndarray` with zero elements.
522+
523+
Args:
524+
shape (tuple): Dimensions of the array to be created.
525+
dtype (optional): data type of the array. Can be typestring,
526+
a `numpy.dtype` object, `numpy` char string, or a numpy
527+
scalar type. Default: "f8"
528+
order ("C", or F"): memory layout for the array. Default: "C"
529+
device (optional): array API concept of device where the output array
530+
is created. `device` can be `None`, a oneAPI filter selector string,
531+
an instance of :class:`dpctl.SyclDevice` corresponding to a
532+
non-partitioned SYCL device, an instance of
533+
:class:`dpctl.SyclQueue`, or a `Device` object returnedby
534+
`dpctl.tensor.usm_array.device`. Default: `None`.
535+
usm_type ("device"|"shared"|"host", optional): The type of SYCL USM
536+
allocation for the output array. Default: `"device"`.
537+
sycl_queue (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
538+
for output array allocation and copying. `sycl_queue` and `device`
539+
are exclusive keywords, i.e. use one or another. If both are
540+
specified, a `TypeError` is raised unless both imply the same
541+
underlying SYCL queue to be used. If both a `None`, the
542+
`dpctl.SyclQueue()` is used for allocation and copying.
543+
Default: `None`.
544+
"""
545+
dtype = np.dtype(dtype)
546+
if not isinstance(order, str) or len(order) == 0 or order[0] not in "CcFf":
547+
raise ValueError(
548+
"Unrecognized order keyword value, expecting 'F' or 'C'."
549+
)
550+
else:
551+
order = order[0].upper()
552+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
553+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
554+
res = dpt.usm_ndarray(
555+
sh,
556+
dtype=dtype,
557+
buffer=usm_type,
558+
order=order,
559+
buffer_ctor_kwargs={"queue": sycl_queue},
560+
)
561+
res.usm_data.memset()
562+
return res

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,3 +957,27 @@ def test_real_imag_views():
957957
assert np.array_equal(dpt.to_numpy(X.imag), Xnp.imag)
958958
assert np.array_equal(dpt.to_numpy(X[1:].real), Xnp[1:].real)
959959
assert np.array_equal(dpt.to_numpy(X[1:].imag), Xnp[1:].imag)
960+
961+
962+
@pytest.mark.parametrize(
963+
"dtype",
964+
[
965+
"b1",
966+
"i1",
967+
"u1",
968+
"i2",
969+
"u2",
970+
"i4",
971+
"u4",
972+
"i8",
973+
"u8",
974+
"f2",
975+
"f4",
976+
"f8",
977+
"c8",
978+
"c16",
979+
],
980+
)
981+
def test_zeros(dtype):
982+
X = dpt.zeros(10, dtype=dtype)
983+
assert np.array_equal(dpt.asnumpy(X), np.zeros(10, dtype=dtype))

0 commit comments

Comments
 (0)