@@ -341,3 +341,58 @@ def test_add_dtype_error(
341
341
assert_raises_regex (
342
342
TypeError , "Output array of type.*is needed" , dpt .add , ar1 , ar2 , y
343
343
)
344
+
345
+
346
+ @pytest .mark .parametrize ("dtype" , _all_dtypes )
347
+ def test_add_inplace_python_scalar (dtype ):
348
+ q = get_queue_or_skip ()
349
+ skip_if_dtype_not_supported (dtype , q )
350
+ X = dpt .zeros ((10 , 10 ), dtype = dtype , sycl_queue = q )
351
+ dt_kind = X .dtype .kind
352
+ if dt_kind in "ui" :
353
+ X += int (0 )
354
+ elif dt_kind == "f" :
355
+ X += float (0 )
356
+ elif dt_kind == "c" :
357
+ X += complex (0 )
358
+ elif dt_kind == "b" :
359
+ X += bool (0 )
360
+
361
+
362
+ @pytest .mark .parametrize ("op1_dtype" , _all_dtypes )
363
+ @pytest .mark .parametrize ("op2_dtype" , _all_dtypes )
364
+ def test_add_inplace_dtype_matrix (op1_dtype , op2_dtype ):
365
+ q = get_queue_or_skip ()
366
+ skip_if_dtype_not_supported (op1_dtype , q )
367
+ skip_if_dtype_not_supported (op2_dtype , q )
368
+
369
+ if dpt .can_cast (op2_dtype , op1_dtype , casting = "safe" ):
370
+ sz = 127
371
+ ar1 = dpt .ones (sz , dtype = op1_dtype )
372
+ ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
373
+
374
+ ar1 += ar2
375
+ assert (
376
+ dpt .asnumpy (ar1 ) == np .full (ar1 .shape , 2 , dtype = ar1 .dtype )
377
+ ).all ()
378
+
379
+ ar3 = dpt .ones (sz , dtype = op1_dtype )
380
+ ar4 = dpt .ones (2 * sz , dtype = op2_dtype )
381
+
382
+ ar3 [::- 1 ] += ar4 [::2 ]
383
+ assert (
384
+ dpt .asnumpy (ar3 ) == np .full (ar3 .shape , 2 , dtype = ar3 .dtype )
385
+ ).all ()
386
+
387
+ else :
388
+ assert pytest .raises (TypeError )
389
+
390
+
391
+ def test_add_inplace_broadcasting ():
392
+ get_queue_or_skip ()
393
+
394
+ m = dpt .ones ((100 , 5 ), dtype = "i4" )
395
+ v = dpt .arange (5 , dtype = "i4" )
396
+
397
+ m += v
398
+ assert (dpt .asnumpy (m ) == np .arange (1 , 6 , dtype = "i4" )[np .newaxis , :]).all ()
0 commit comments