Skip to content

Commit 76ef808

Browse files
authored
Merge branch 'master' into dependabot/github_actions/github/codeql-action-3.28.1
2 parents 02c3ef8 + e3b7d07 commit 76ef808

File tree

5 files changed

+83
-25
lines changed

5 files changed

+83
-25
lines changed

.github/workflows/conda-package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,13 @@ jobs:
9999
MAX_BUILD_CMPL_MKL_VERSION: '2025.1a0'
100100

101101
- name: Upload artifact
102-
uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
102+
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
103103
with:
104104
name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }}
105105
path: ${{ env.CONDA_BLD }}${{ env.PACKAGE_NAME }}-*.tar.bz2
106106

107107
- name: Upload wheels artifact
108-
uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
108+
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
109109
with:
110110
name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Wheels Python ${{ matrix.python }}
111111
path: ${{ env.WHEELS_OUTPUT_FOLDER }}${{ env.PACKAGE_NAME }}-*.whl

.github/workflows/openssf-scorecard.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
6161
# format to the repository Actions tab.
6262
- name: "Upload artifact"
63-
uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
63+
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
6464
with:
6565
name: SARIF file
6666
path: results.sarif

dpnp/dpnp_array.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
offset=offset,
9595
order=order,
9696
buffer_ctor_kwargs={"queue": sycl_queue_normalized},
97+
array_namespace=dpnp,
9798
)
9899

99100
@property
@@ -201,6 +202,31 @@ def __and__(self, other):
201202
# '__array_ufunc__',
202203
# '__array_wrap__',
203204

205+
def __array_namespace__(self, /, *, api_version=None):
206+
"""
207+
Returns array namespace, member functions of which implement data API.
208+
209+
Parameters
210+
----------
211+
api_version : str, optional
212+
Request namespace compliant with given version of array API. If
213+
``None``, namespace for the most recent supported version is
214+
returned.
215+
Default: ``None``.
216+
217+
Returns
218+
-------
219+
out : any
220+
An object representing the array API namespace. It should have
221+
every top-level function defined in the specification as
222+
an attribute. It may contain other public names as well, but it is
223+
recommended to only include those names that are part of the
224+
specification.
225+
226+
"""
227+
228+
return self._array_obj.__array_namespace__(api_version=api_version)
229+
204230
def __bool__(self):
205231
"""``True`` if self else ``False``."""
206232
return self._array_obj.__bool__()
@@ -327,15 +353,7 @@ def __getitem__(self, key):
327353
key = _get_unwrapped_index_key(key)
328354

329355
item = self._array_obj.__getitem__(key)
330-
if not isinstance(item, dpt.usm_ndarray):
331-
raise RuntimeError(
332-
"Expected dpctl.tensor.usm_ndarray, got {}"
333-
"".format(type(item))
334-
)
335-
336-
res = self.__new__(dpnp_array)
337-
res._array_obj = item
338-
return res
356+
return dpnp_array._create_from_usm_ndarray(item)
339357

340358
# '__getstate__',
341359

@@ -606,6 +624,7 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray):
606624
)
607625
res = dpnp_array.__new__(dpnp_array)
608626
res._array_obj = usm_ary
627+
res._array_obj._set_namespace(dpnp)
609628
return res
610629

611630
def all(self, axis=None, out=None, keepdims=False, *, where=True):
@@ -1749,17 +1768,16 @@ def transpose(self, *axes):
17491768
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
17501769
axes = axes[0]
17511770

1752-
res = self.__new__(dpnp_array)
17531771
if ndim == 2 and axes_len == 0:
1754-
res._array_obj = self._array_obj.T
1772+
usm_res = self._array_obj.T
17551773
else:
17561774
if len(axes) == 0 or axes[0] is None:
17571775
# self.transpose().shape == self.shape[::-1]
17581776
# self.transpose(None).shape == self.shape[::-1]
17591777
axes = tuple((ndim - x - 1) for x in range(ndim))
17601778

1761-
res._array_obj = dpt.permute_dims(self._array_obj, axes)
1762-
return res
1779+
usm_res = dpt.permute_dims(self._array_obj, axes)
1780+
return dpnp_array._create_from_usm_ndarray(usm_res)
17631781

17641782
def var(
17651783
self,

dpnp/dpnp_iface_indexing.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -622,14 +622,8 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
622622
out_strides = a_straides[:-2] + (1,)
623623
out_offset = a_element_offset
624624

625-
return dpnp_array._create_from_usm_ndarray(
626-
dpt.usm_ndarray(
627-
out_shape,
628-
dtype=a.dtype,
629-
buffer=a.get_array(),
630-
strides=out_strides,
631-
offset=out_offset,
632-
)
625+
return dpnp_array(
626+
out_shape, buffer=a, strides=out_strides, offset=out_offset
633627
)
634628

635629

dpnp/tests/test_ndarray.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import dpctl.tensor as dpt
22
import numpy
33
import pytest
4-
from numpy.testing import assert_allclose, assert_array_equal
4+
from numpy.testing import (
5+
assert_allclose,
6+
assert_array_equal,
7+
assert_raises_regex,
8+
)
59

610
import dpnp
711

@@ -104,6 +108,48 @@ def test_flags_writable():
104108
assert not a.imag.flags.writable
105109

106110

111+
class TestArrayNamespace:
112+
def test_basic(self):
113+
a = dpnp.arange(2)
114+
xp = a.__array_namespace__()
115+
assert xp is dpnp
116+
117+
@pytest.mark.parametrize("api_version", [None, "2023.12"])
118+
def test_api_version(self, api_version):
119+
a = dpnp.arange(2)
120+
xp = a.__array_namespace__(api_version=api_version)
121+
assert xp is dpnp
122+
123+
@pytest.mark.parametrize("api_version", ["2021.12", "2022.12", "2024.12"])
124+
def test_unsupported_api_version(self, api_version):
125+
a = dpnp.arange(2)
126+
assert_raises_regex(
127+
ValueError,
128+
"Only 2023.12 is supported",
129+
a.__array_namespace__,
130+
api_version=api_version,
131+
)
132+
133+
@pytest.mark.parametrize(
134+
"api_version",
135+
[
136+
2023,
137+
(2022,),
138+
[
139+
2021,
140+
],
141+
],
142+
)
143+
def test_wrong_api_version(self, api_version):
144+
a = dpnp.arange(2)
145+
assert_raises_regex(
146+
TypeError,
147+
"Expected type str",
148+
a.__array_namespace__,
149+
api_version=api_version,
150+
)
151+
152+
107153
class TestItem:
108154
@pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)])
109155
def test_basic(self, args):

0 commit comments

Comments
 (0)