Skip to content

Commit 9e81f89

Browse files
Merge pull request #1337 from IntelPython/master
Merging developmental snapshot tagged 0.14.6dev2 to gold/2021
2 parents ccc7053 + 074ec3a commit 9e81f89

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+1953
-354
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ jobs:
486486
done
487487
488488
array-api-conformity:
489-
needs: test_linux
489+
needs: build_linux
490490
runs-on: ${{ matrix.runner }}
491491

492492
strategy:

dpctl/_backend.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_manager.h":
299299

300300

301301
cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
302+
cdef bool DPCTLPlatform_AreEq(const DPCTLSyclPlatformRef, const DPCTLSyclPlatformRef)
302303
cdef DPCTLSyclPlatformRef DPCTLPlatform_Copy(const DPCTLSyclPlatformRef)
303304
cdef DPCTLSyclPlatformRef DPCTLPlatform_Create()
304305
cdef DPCTLSyclPlatformRef DPCTLPlatform_CreateFromSelector(
@@ -308,6 +309,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
308309
cdef const char *DPCTLPlatform_GetName(const DPCTLSyclPlatformRef)
309310
cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef)
310311
cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef)
312+
cdef size_t DPCTLPlatform_Hash(const DPCTLSyclPlatformRef)
311313
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
312314
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
313315
const DPCTLSyclPlatformRef)

dpctl/_sycl_platform.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
SYCL platform-related helper functions.
2222
"""
2323

24+
from libcpp cimport bool
25+
2426
from ._backend cimport DPCTLSyclDeviceSelectorRef, DPCTLSyclPlatformRef
2527

2628

@@ -40,6 +42,7 @@ cdef class SyclPlatform(_SyclPlatform):
4042
cdef int _init_from_selector(self, DPCTLSyclDeviceSelectorRef DSRef)
4143
cdef int _init_from__SyclPlatform(self, _SyclPlatform other)
4244
cdef DPCTLSyclPlatformRef get_platform_ref(self)
45+
cdef bool equals(self, SyclPlatform)
4346

4447

4548
cpdef list get_platforms()

dpctl/_sycl_platform.pyx

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121
""" Implements SyclPlatform Cython extension type.
2222
"""
2323

24+
from libcpp cimport bool
25+
2426
from ._backend cimport ( # noqa: E211
2527
DPCTLCString_Delete,
2628
DPCTLDeviceSelector_Delete,
2729
DPCTLFilterSelector_Create,
30+
DPCTLPlatform_AreEq,
2831
DPCTLPlatform_Copy,
2932
DPCTLPlatform_Create,
3033
DPCTLPlatform_CreateFromSelector,
@@ -35,6 +38,7 @@ from ._backend cimport ( # noqa: E211
3538
DPCTLPlatform_GetPlatforms,
3639
DPCTLPlatform_GetVendor,
3740
DPCTLPlatform_GetVersion,
41+
DPCTLPlatform_Hash,
3842
DPCTLPlatformMgr_GetInfo,
3943
DPCTLPlatformMgr_PrintInfo,
4044
DPCTLPlatformVector_Delete,
@@ -274,6 +278,42 @@ cdef class SyclPlatform(_SyclPlatform):
274278
else:
275279
return SyclContext._create(CRef)
276280

281+
cdef bool equals(self, SyclPlatform other):
    """
    Compares this :class:`dpctl.SyclPlatform` with another one.

    The comparison is delegated to ``DPCTLPlatform_AreEq``, which is
    called with the platform reference held by this instance and the
    reference obtained from ``other.get_platform_ref()``.

    Args:
        other (:class:`dpctl.SyclPlatform`): platform to compare against.

    Returns:
        :obj:`bool`: the value reported by ``DPCTLPlatform_AreEq`` for
        the two platform references.
    """
    cdef DPCTLSyclPlatformRef other_ref = other.get_platform_ref()
    return DPCTLPlatform_AreEq(self._platform_ref, other_ref)
293+
294+
def __eq__(self, other):
    """
    Equality comparison for :class:`dpctl.SyclPlatform` objects.

    Args:
        other: object to compare with this instance.

    Returns:
        :obj:`bool`: ``False`` when ``other`` is not a
        :class:`dpctl.SyclPlatform`; otherwise the result of
        ``self.equals(other)``.
    """
    # Anything that is not a SyclPlatform can never compare equal.
    if not isinstance(other, SyclPlatform):
        return False
    return self.equals(<SyclPlatform> other)
309+
310+
def __hash__(self):
    """
    Returns the hash value produced by ``DPCTLPlatform_Hash`` for the
    platform reference held by this instance.
    """
    cdef size_t hash_value = DPCTLPlatform_Hash(self._platform_ref)
    return hash_value
316+
277317

278318
def lsplatform(verbosity=0):
279319
"""

dpctl/_sycl_queue.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ from ._sycl_event cimport SyclEvent
2929
from .program._program cimport SyclKernel
3030

3131

32-
cdef void default_async_error_handler(int) nogil except *
32+
cdef void default_async_error_handler(int) except * nogil
3333

3434
cdef public api class _SyclQueue [
3535
object Py_SyclQueueObject, type Py_SyclQueueType

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@
135135
logical_not,
136136
logical_or,
137137
logical_xor,
138+
maximum,
139+
minimum,
138140
multiply,
139141
negative,
140142
not_equal,
@@ -274,6 +276,8 @@
274276
"log1p",
275277
"log2",
276278
"log10",
279+
"maximum",
280+
"minimum",
277281
"multiply",
278282
"negative",
279283
"not_equal",

dpctl/tensor/_copy_utils.py

Lines changed: 147 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import builtins
1617
import operator
1718

1819
import numpy as np
@@ -245,6 +246,23 @@ def _broadcast_shapes(sh1, sh2):
245246
).shape
246247

247248

249+
def _broadcast_strides(X_shape, X_strides, res_ndim):
250+
"""
251+
Broadcasts strides to match the given dimensions;
252+
returns tuple type strides.
253+
"""
254+
out_strides = [0] * res_ndim
255+
X_shape_len = len(X_shape)
256+
str_dim = -X_shape_len
257+
for i in range(X_shape_len):
258+
shape_value = X_shape[i]
259+
if not shape_value == 1:
260+
out_strides[str_dim] = X_strides[i]
261+
str_dim += 1
262+
263+
return tuple(out_strides)
264+
265+
248266
def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
249267
if any(
250268
not isinstance(arg, dpt.usm_ndarray)
@@ -267,7 +285,7 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
267285
except ValueError as exc:
268286
raise ValueError("Shapes of two arrays are not compatible") from exc
269287

270-
if dst.size < src.size:
288+
if dst.size < src.size and dst.size < np.prod(common_shape):
271289
raise ValueError("Destination is smaller ")
272290

273291
if len(common_shape) > dst.ndim:
@@ -278,17 +296,127 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
278296
common_shape = common_shape[ones_count:]
279297

280298
if src.ndim < len(common_shape):
281-
new_src_strides = (0,) * (len(common_shape) - src.ndim) + src.strides
299+
new_src_strides = _broadcast_strides(
300+
src.shape, src.strides, len(common_shape)
301+
)
302+
src_same_shape = dpt.usm_ndarray(
303+
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
304+
)
305+
elif src.ndim == len(common_shape):
306+
new_src_strides = _broadcast_strides(
307+
src.shape, src.strides, len(common_shape)
308+
)
282309
src_same_shape = dpt.usm_ndarray(
283310
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
284311
)
285312
else:
286-
src_same_shape = src
287-
src_same_shape.shape = common_shape
313+
# since broadcasting succeeded, src.ndim is greater because of
314+
# leading sequence of ones, so we trim it
315+
n = len(common_shape)
316+
new_src_strides = _broadcast_strides(
317+
src.shape[-n:], src.strides[-n:], n
318+
)
319+
src_same_shape = dpt.usm_ndarray(
320+
common_shape,
321+
dtype=src.dtype,
322+
buffer=src.usm_data,
323+
strides=new_src_strides,
324+
offset=src._element_offset,
325+
)
288326

289327
_copy_same_shape(dst, src_same_shape)
290328

291329

330+
def _empty_like_orderK(X, dt, usm_type=None, dev=None):
    """Allocates an empty array shaped like `X`, using order='K'.

    C-contiguous inputs (and inputs with at most one element) yield a
    C-ordered result, F-contiguous inputs yield an F-ordered result.
    Otherwise a C-contiguous buffer is allocated with axes sorted by
    decreasing absolute stride of `X`, axes whose stride in `X` is
    negative are reversed, and the axes are permuted back to the order
    used by `X`.

    Args:
        X: array whose shape and layout are mimicked.
        dt: data type of the returned array.
        usm_type: USM allocation type; defaults to ``X.usm_type``.
        dev: device of the returned array; defaults to ``X.device``.

    Raises:
        TypeError: if `X` is not a ``dpt.usm_ndarray``.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray, got {type(X)}")
    if usm_type is None:
        usm_type = X.usm_type
    if dev is None:
        dev = X.device

    flags = X.flags
    if flags["C"] or X.size <= 1:
        return dpt.empty_like(
            X, dtype=dt, usm_type=usm_type, device=dev, order="C"
        )
    if flags["F"]:
        return dpt.empty_like(
            X, dtype=dt, usm_type=usm_type, device=dev, order="F"
        )

    strides = list(X.strides)
    axes = range(X.ndim)
    # Axes ordered from the largest to the smallest absolute stride.
    by_stride = sorted(
        axes, key=lambda ax: builtins.abs(strides[ax]), reverse=True
    )
    # Inverse permutation, used to restore the axis order of `X`.
    restore = sorted(axes, key=lambda pos: by_stride[pos])
    strides_sorted = [strides[ax] for ax in by_stride]
    shape = X.shape
    shape_sorted = tuple(shape[ax] for ax in by_stride)

    res = dpt.empty(
        shape_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C"
    )
    if any(step < 0 for step in strides_sorted):
        # Reverse every axis that `X` traverses backwards.
        flips = tuple(
            slice(None, None, -1) if step < 0 else slice(None, None, None)
            for step in strides_sorted
        )
        res = res[flips]
    return dpt.permute_dims(res, restore)
370+
371+
372+
def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
    """Allocates an empty array of shape `res_shape` for a pair of
    inputs, using order='K'.

    Layout selection, in order:

    * if the input of higher rank already has shape `res_shape`, the
      layout of that input alone is mimicked via `_empty_like_orderK`;
    * if either input is C-contiguous the result is C-ordered;
    * if both inputs are F-contiguous the result is F-ordered;
    * otherwise a C-contiguous buffer is allocated with axes sorted by
      decreasing ``(abs(stride1), abs(stride2))``, axes on which both
      inputs have negative strides are reversed, and the axes are
      permuted back to the original order.

    Args:
        X1: first input array.
        X2: second input array.
        dt: data type of the returned array.
        res_shape: shape of the returned array.
        usm_type: USM allocation type of the returned array.
        dev: device of the returned array.

    Raises:
        TypeError: if `X1` or `X2` is not a ``dpt.usm_ndarray``.
    """
    if not isinstance(X1, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
    if not isinstance(X2, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
    nd1 = X1.ndim
    nd2 = X2.ndim
    if nd1 > nd2 and X1.shape == res_shape:
        return _empty_like_orderK(X1, dt, usm_type, dev)
    elif nd1 < nd2 and X2.shape == res_shape:
        return _empty_like_orderK(X2, dt, usm_type, dev)
    fl1 = X1.flags
    fl2 = X2.flags
    if fl1["C"] or fl2["C"]:
        return dpt.empty(
            res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
        )
    if fl1["F"] and fl2["F"]:
        return dpt.empty(
            res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
        )
    st1 = list(X1.strides)
    st2 = list(X2.strides)
    max_ndim = max(nd1, nd2)
    # Pad the stride list of the lower-rank input with zeros so both
    # lists have max_ndim entries.
    st1 += [0] * (max_ndim - len(st1))
    st2 += [0] * (max_ndim - len(st2))
    perm = sorted(
        range(max_ndim),
        key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])),
        reverse=True,
    )
    inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
    st1_sorted = [st1[i] for i in perm]
    st2_sorted = [st2[i] for i in perm]
    sh = res_shape
    sh_sorted = tuple(sh[i] for i in perm)
    R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
    if max(min(st1_sorted), min(st2_sorted)) < 0:
        # Visit every one of the max_ndim sorted axes. Iterating only
        # over range(nd1) would leave trailing axes unreversed whenever
        # X1 has lower rank than X2.
        sl = tuple(
            slice(None, None, -1)
            if (s1 < 0 and s2 < 0)
            else slice(None, None, None)
            for s1, s2 in zip(st1_sorted, st2_sorted)
        )
        R = R[sl]
    return dpt.permute_dims(R, inv_perm)
418+
419+
292420
def copy(usm_ary, order="K"):
293421
"""copy(ary, order="K")
294422
@@ -334,28 +462,15 @@ def copy(usm_ary, order="K"):
334462
"Unrecognized value of the order keyword. "
335463
"Recognized values are 'A', 'C', 'F', or 'K'"
336464
)
337-
c_contig = usm_ary.flags.c_contiguous
338-
f_contig = usm_ary.flags.f_contiguous
339-
R = dpt.usm_ndarray(
340-
usm_ary.shape,
341-
dtype=usm_ary.dtype,
342-
buffer=usm_ary.usm_type,
343-
order=copy_order,
344-
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
345-
)
346-
if order == "K" and (not c_contig and not f_contig):
347-
original_strides = usm_ary.strides
348-
ind = sorted(
349-
range(usm_ary.ndim),
350-
key=lambda i: abs(original_strides[i]),
351-
reverse=True,
352-
)
353-
new_strides = tuple(R.strides[ind[i]] for i in ind)
465+
if order == "K":
466+
R = _empty_like_orderK(usm_ary, usm_ary.dtype)
467+
else:
354468
R = dpt.usm_ndarray(
355469
usm_ary.shape,
356470
dtype=usm_ary.dtype,
357-
buffer=R.usm_data,
358-
strides=new_strides,
471+
buffer=usm_ary.usm_type,
472+
order=copy_order,
473+
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
359474
)
360475
_copy_same_shape(R, usm_ary)
361476
return R
@@ -432,26 +547,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
432547
"Unrecognized value of the order keyword. "
433548
"Recognized values are 'A', 'C', 'F', or 'K'"
434549
)
435-
R = dpt.usm_ndarray(
436-
usm_ary.shape,
437-
dtype=target_dtype,
438-
buffer=usm_ary.usm_type,
439-
order=copy_order,
440-
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
441-
)
442-
if order == "K" and (not c_contig and not f_contig):
443-
original_strides = usm_ary.strides
444-
ind = sorted(
445-
range(usm_ary.ndim),
446-
key=lambda i: abs(original_strides[i]),
447-
reverse=True,
448-
)
449-
new_strides = tuple(R.strides[ind[i]] for i in ind)
550+
if order == "K":
551+
R = _empty_like_orderK(usm_ary, target_dtype)
552+
else:
450553
R = dpt.usm_ndarray(
451554
usm_ary.shape,
452555
dtype=target_dtype,
453-
buffer=R.usm_data,
454-
strides=new_strides,
556+
buffer=usm_ary.usm_type,
557+
order=copy_order,
558+
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
455559
)
456560
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
457561
return R
@@ -492,6 +596,8 @@ def _extract_impl(ary, ary_mask, axis=0):
492596
dst = dpt.empty(
493597
dst_shape, dtype=ary.dtype, usm_type=ary.usm_type, device=ary.device
494598
)
599+
if dst.size == 0:
600+
return dst
495601
hev, _ = ti._extract(
496602
src=ary,
497603
cumsum=cumsum,
@@ -517,7 +623,7 @@ def _nonzero_impl(ary):
517623
mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C"
518624
)
519625
mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q)
520-
indexes_dt = ti.default_device_int_type(exec_q.sycl_device)
626+
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
521627
indexes = dpt.empty(
522628
(ary.ndim, mask_count),
523629
dtype=indexes_dt,

0 commit comments

Comments
 (0)