Skip to content

Commit 56e98b2

Browse files
committed
Permits clip arguments min and max to both be None
Also resolves gh-1489
1 parent 3d6a635 commit 56e98b2

File tree

2 files changed

+80
-11
lines changed

2 files changed

+80
-11
lines changed

dpctl/tensor/_clip.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
168168
return dpt.dtype(ti.default_device_int_type(dev))
169169
if isinstance(dtype, WeakComplexType):
170170
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
171-
return st_dtype, dpt.complex64
171+
return dpt.complex64
172172
return _to_device_supported_dtype(dpt.complex128, dev)
173-
return (_to_device_supported_dtype(dpt.float64, dev),)
173+
return _to_device_supported_dtype(dpt.float64, dev)
174174
else:
175175
return st_dtype
176176
else:
@@ -197,8 +197,6 @@ def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
197197

198198

199199
def _clip_none(x, val, out, order, _binary_fn):
200-
if order not in ["K", "C", "F", "A"]:
201-
order = "K"
202200
q1, x_usm_type = x.sycl_queue, x.usm_type
203201
q2, val_usm_type = _get_queue_usm_type(val)
204202
if q2 is None:
@@ -391,9 +389,8 @@ def _clip_none(x, val, out, order, _binary_fn):
391389
return out
392390

393391

394-
# need to handle logic for min or max being None
395-
def clip(x, min=None, max=None, out=None, order="K"):
396-
"""clip(x, min, max, out=None, order="K")
392+
def clip(x, /, min=None, max=None, out=None, order="K"):
393+
"""clip(x, min=None, max=None, out=None, order="K")
397394
398395
Clips to the range [`min_i`, `max_i`] for each element `x_i`
399396
in `x`.
@@ -405,11 +402,9 @@ def clip(x, min=None, max=None, out=None, order="K"):
405402
min ({None, usm_ndarray}, optional): Array containing minimum values.
406403
Must be compatible with `x` and `max` according
407404
to broadcasting rules.
408-
Only one of `min` and `max` can be `None`.
409405
max ({None, usm_ndarray}, optional): Array containing maximum values.
410406
Must be compatible with `x` and `min` according
411407
to broadcasting rules.
412-
Only one of `min` and `max` can be `None`.
413408
out ({None, usm_ndarray}, optional):
414409
Output array to populate.
415410
Array must have the correct shape and the expected data type.
@@ -428,10 +423,67 @@ def clip(x, min=None, max=None, out=None, order="K"):
428423
"Expected `x` to be of dpctl.tensor.usm_ndarray type, got "
429424
f"{type(x)}"
430425
)
426+
if order not in ["K", "C", "F", "A"]:
427+
order = "K"
431428
if min is None and max is None:
432-
raise ValueError(
433-
"only one of `min` and `max` is permitted to be `None`"
429+
exec_q = x.sycl_queue
430+
orig_out = out
431+
if out is not None:
432+
if not isinstance(out, dpt.usm_ndarray):
433+
raise TypeError(
434+
"output array must be of usm_ndarray type, got "
435+
f"{type(out)}"
436+
)
437+
438+
if out.shape != x.shape:
439+
raise ValueError(
440+
"The shape of input and output arrays are "
441+
f"inconsistent. Expected output shape is {x.shape}, "
442+
f"got {out.shape}"
443+
)
444+
445+
if x.dtype != out.dtype:
446+
raise ValueError(
447+
f"Output array of type {x.dtype} is needed, "
448+
f"got {out.dtype}"
449+
)
450+
451+
if (
452+
dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
453+
is None
454+
):
455+
raise ExecutionPlacementError(
456+
"Input and output allocation queues are not compatible"
457+
)
458+
459+
if ti._array_overlap(x, out):
460+
if not ti._same_logical_tensors(x, out):
461+
out = dpt.empty_like(out)
462+
else:
463+
return out
464+
else:
465+
if order == "K":
466+
out = _empty_like_orderK(x, x.dtype)
467+
else:
468+
if order == "A":
469+
order = "F" if x.flags.f_contiguous else "C"
470+
out = dpt.empty_like(x, order=order)
471+
472+
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
473+
src=x, dst=out, sycl_queue=exec_q
434474
)
475+
if not (orig_out is None or orig_out is out):
476+
# Copy the out data from temporary buffer to original memory
477+
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
478+
src=out,
479+
dst=orig_out,
480+
sycl_queue=exec_q,
481+
depends=[copy_ev],
482+
)
483+
ht_copy_out_ev.wait()
484+
out = orig_out
485+
ht_copy_ev.wait()
486+
return out
435487
elif max is None:
436488
return _clip_none(x, min, out, order, tei._maximum)
437489
elif min is None:

dpctl/tests/test_tensor_clip.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,15 @@ def test_clip_out_need_temporary():
194194
dpt.clip(x[:6], 2, 3, out=x[-6:])
195195
assert dpt.all(x[:-6] == 1) and dpt.all(x[-6:] == 2)
196196

197+
x = dpt.arange(12, dtype="i4")
198+
dpt.clip(x[:6], out=x[-6:])
199+
expected = dpt.arange(6, dtype="i4")
200+
assert dpt.all(x[:-6] == expected) and dpt.all(x[-6:] == expected)
201+
202+
x = dpt.ones(10, dtype="i4")
203+
dpt.clip(x, out=x)
204+
assert dpt.all(x == 1)
205+
197206
x = dpt.full(6, 3, dtype="i4")
198207
a_min = dpt.full(10, 2, dtype="i4")
199208
a_max = dpt.asarray(4, dtype="i4")
@@ -636,3 +645,11 @@ def test_clip_unaligned():
636645

637646
expected = dpt.full(512, 2, dtype="i4")
638647
assert dpt.all(dpt.clip(x[1:], a_min, a_max) == expected)
648+
649+
650+
def test_clip_none_args():
    """Check `clip` with neither `min` nor `max` supplied.

    Calls `dpt.clip` on an integer range with both bounds left at their
    defaults and asserts that every element of the result equals the
    corresponding element of the input.
    """
    # NOTE(review): presumably skips the test when no usable queue/device
    # is available -- confirm against the helper's definition.
    get_queue_or_skip()

    source = dpt.arange(10, dtype="i4")
    clipped = dpt.clip(source)
    assert dpt.all(source == clipped)

0 commit comments

Comments
 (0)