@@ -1464,18 +1464,35 @@ def shuffle(x1):
1464
1464
return
1465
1465
1466
1466
1467
- def seed (seed = None ):
1467
+ def seed (seed = None ,
1468
+ device = None ,
1469
+ sycl_queue = None ):
1468
1470
"""
1469
- Reseed a legacy mt19937 random number generator engine.
1471
+ Reseed a legacy MT19937 random number generator engine.
1472
+
1473
+ Parameters
1474
+ ----------
1475
+ device : {None, string, SyclDevice, SyclQueue}, optional
1476
+ An array API concept of device where an array with generated numbers will be created.
1477
+ The `device` can be ``None`` (the default), an OneAPI filter selector string,
1478
+ an instance of :class:`dpctl.SyclDevice` corresponding to a non-partitioned SYCL device,
1479
+ an instance of :class:`dpctl.SyclQueue`, or a `Device` object returned by
1480
+ :obj:`dpnp.dpnp_array.dpnp_array.device` property.
1481
+ sycl_queue : {None, SyclQueue}, optional
1482
+ A SYCL queue to use for an array with generated numbers.
1470
1483
1471
1484
Limitations
1472
1485
-----------
1473
- Parameter ``seed`` is supported as a scalar.
1474
- Otherwise, the function will use :obj:`numpy.random.seed` on the backend
1475
- and will be executed on fallback backend.
1486
+ Parameter `seed` is supported as either a scalar or an array of maximumum three integer scalars.
1476
1487
1477
1488
"""
1478
1489
1490
+ # update a mt19937 random number for both RandomState and legacy functionality
1491
+ global _dpnp_random_states
1492
+
1493
+ sycl_queue = dpnp .get_normalized_queue_device (device = device , sycl_queue = sycl_queue )
1494
+ _dpnp_random_states [sycl_queue ] = RandomState (seed = seed , sycl_queue = sycl_queue )
1495
+
1479
1496
if not use_origin_backend (seed ):
1480
1497
# TODO:
1481
1498
# array_like of ints for `seed`
@@ -1488,12 +1505,6 @@ def seed(seed=None):
1488
1505
else :
1489
1506
# TODO:
1490
1507
# migrate to a single approach with RandomState class
1491
-
1492
- # update a mt19937 random number for both RandomState and legacy functionality
1493
- global _dpnp_random_states
1494
- for sycl_queue in _dpnp_random_states .keys ():
1495
- _dpnp_random_states [sycl_queue ] = RandomState (seed = seed , sycl_queue = sycl_queue )
1496
-
1497
1508
dpnp_rng_srand (seed )
1498
1509
1499
1510
# always reseed numpy engine also
0 commit comments