@@ -168,9 +168,9 @@ def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
168
168
return dpt .dtype (ti .default_device_int_type (dev ))
169
169
if isinstance (dtype , WeakComplexType ):
170
170
if st_dtype is dpt .float16 or st_dtype is dpt .float32 :
171
- return st_dtype , dpt .complex64
171
+ return dpt .complex64
172
172
return _to_device_supported_dtype (dpt .complex128 , dev )
173
- return ( _to_device_supported_dtype (dpt .float64 , dev ), )
173
+ return _to_device_supported_dtype (dpt .float64 , dev )
174
174
else :
175
175
return st_dtype
176
176
else :
@@ -197,8 +197,6 @@ def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
197
197
198
198
199
199
def _clip_none (x , val , out , order , _binary_fn ):
200
- if order not in ["K" , "C" , "F" , "A" ]:
201
- order = "K"
202
200
q1 , x_usm_type = x .sycl_queue , x .usm_type
203
201
q2 , val_usm_type = _get_queue_usm_type (val )
204
202
if q2 is None :
@@ -391,9 +389,8 @@ def _clip_none(x, val, out, order, _binary_fn):
391
389
return out
392
390
393
391
394
- # need to handle logic for min or max being None
395
- def clip (x , min = None , max = None , out = None , order = "K" ):
396
- """clip(x, min, max, out=None, order="K")
392
+ def clip (x , / , min = None , max = None , out = None , order = "K" ):
393
+ """clip(x, min=None, max=None, out=None, order="K")
397
394
398
395
Clips to the range [`min_i`, `max_i`] for each element `x_i`
399
396
in `x`.
@@ -405,11 +402,9 @@ def clip(x, min=None, max=None, out=None, order="K"):
405
402
min ({None, usm_ndarray}, optional): Array containing minimum values.
406
403
Must be compatible with `x` and `max` according
407
404
to broadcasting rules.
408
- Only one of `min` and `max` can be `None`.
409
405
max ({None, usm_ndarray}, optional): Array containing maximum values.
410
406
Must be compatible with `x` and `min` according
411
407
to broadcasting rules.
412
- Only one of `min` and `max` can be `None`.
413
408
out ({None, usm_ndarray}, optional):
414
409
Output array to populate.
415
410
Array must have the correct shape and the expected data type.
@@ -428,10 +423,67 @@ def clip(x, min=None, max=None, out=None, order="K"):
428
423
"Expected `x` to be of dpctl.tensor.usm_ndarray type, got "
429
424
f"{ type (x )} "
430
425
)
426
+ if order not in ["K" , "C" , "F" , "A" ]:
427
+ order = "K"
431
428
if min is None and max is None :
432
- raise ValueError (
433
- "only one of `min` and `max` is permitted to be `None`"
429
+ exec_q = x .sycl_queue
430
+ orig_out = out
431
+ if out is not None :
432
+ if not isinstance (out , dpt .usm_ndarray ):
433
+ raise TypeError (
434
+ "output array must be of usm_ndarray type, got "
435
+ f"{ type (out )} "
436
+ )
437
+
438
+ if out .shape != x .shape :
439
+ raise ValueError (
440
+ "The shape of input and output arrays are "
441
+ f"inconsistent. Expected output shape is { x .shape } , "
442
+ f"got { out .shape } "
443
+ )
444
+
445
+ if x .dtype != out .dtype :
446
+ raise ValueError (
447
+ f"Output array of type { x .dtype } is needed, "
448
+ f"got { out .dtype } "
449
+ )
450
+
451
+ if (
452
+ dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue ))
453
+ is None
454
+ ):
455
+ raise ExecutionPlacementError (
456
+ "Input and output allocation queues are not compatible"
457
+ )
458
+
459
+ if ti ._array_overlap (x , out ):
460
+ if not ti ._same_logical_tensors (x , out ):
461
+ out = dpt .empty_like (out )
462
+ else :
463
+ return out
464
+ else :
465
+ if order == "K" :
466
+ out = _empty_like_orderK (x , x .dtype )
467
+ else :
468
+ if order == "A" :
469
+ order = "F" if x .flags .f_contiguous else "C"
470
+ out = dpt .empty_like (x , order = order )
471
+
472
+ ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
473
+ src = x , dst = out , sycl_queue = exec_q
434
474
)
475
+ if not (orig_out is None or orig_out is out ):
476
+ # Copy the out data from temporary buffer to original memory
477
+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
478
+ src = out ,
479
+ dst = orig_out ,
480
+ sycl_queue = exec_q ,
481
+ depends = [copy_ev ],
482
+ )
483
+ ht_copy_out_ev .wait ()
484
+ out = orig_out
485
+ ht_copy_ev .wait ()
486
+ return out
435
487
elif max is None :
436
488
return _clip_none (x , min , out , order , tei ._maximum )
437
489
elif min is None :
0 commit comments