Skip to content

Commit b747c13

Browse files
committed
Add device and sycl_queue to dpnp.random.seed() & use random values if seed is None
1 parent 20b6f79 commit b747c13

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

dpnp/random/dpnp_iface_random.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,18 +1464,35 @@ def shuffle(x1):
14641464
return
14651465

14661466

1467-
def seed(seed=None):
1467+
def seed(seed=None,
1468+
device=None,
1469+
sycl_queue=None):
14681470
"""
1469-
Reseed a legacy mt19937 random number generator engine.
1471+
Reseed a legacy MT19937 random number generator engine.
1472+
1473+
Parameters
1474+
----------
1475+
device : {None, string, SyclDevice, SyclQueue}, optional
1476+
An array API concept of device where an array with generated numbers will be created.
1477+
The `device` can be ``None`` (the default), an OneAPI filter selector string,
1478+
an instance of :class:`dpctl.SyclDevice` corresponding to a non-partitioned SYCL device,
1479+
an instance of :class:`dpctl.SyclQueue`, or a `Device` object returned by
1480+
:obj:`dpnp.dpnp_array.dpnp_array.device` property.
1481+
sycl_queue : {None, SyclQueue}, optional
1482+
A SYCL queue to use for an array with generated numbers.
14701483
14711484
Limitations
14721485
-----------
1473-
Parameter ``seed`` is supported as a scalar.
1474-
Otherwise, the function will use :obj:`numpy.random.seed` on the backend
1475-
and will be executed on fallback backend.
1486+
Parameter `seed` is supported as either a scalar or an array of maximumum three integer scalars.
14761487
14771488
"""
14781489

1490+
# update a mt19937 random number for both RandomState and legacy functionality
1491+
global _dpnp_random_states
1492+
1493+
sycl_queue = dpnp.get_normalized_queue_device(device=device, sycl_queue=sycl_queue)
1494+
_dpnp_random_states[sycl_queue] = RandomState(seed=seed, sycl_queue=sycl_queue)
1495+
14791496
if not use_origin_backend(seed):
14801497
# TODO:
14811498
# array_like of ints for `seed`
@@ -1488,12 +1505,6 @@ def seed(seed=None):
14881505
else:
14891506
# TODO:
14901507
# migrate to a single approach with RandomState class
1491-
1492-
# update a mt19937 random number for both RandomState and legacy functionality
1493-
global _dpnp_random_states
1494-
for sycl_queue in _dpnp_random_states.keys():
1495-
_dpnp_random_states[sycl_queue] = RandomState(seed=seed, sycl_queue=sycl_queue)
1496-
14971508
dpnp_rng_srand(seed)
14981509

14991510
# always reseed numpy engine also

dpnp/random/dpnp_random_state.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,12 @@ class RandomState:
7676
"""
7777

7878
def __init__(self, seed=None, device=None, sycl_queue=None):
79-
self._seed = 1 if seed is None else seed
79+
if seed is None:
80+
# ask NumPy to generate an array of three random integers as default seed value
81+
self._seed = numpy.random.randint(low=0, high=numpy.iinfo(numpy.uint32).max + 1, size=3)
82+
else:
83+
self._seed = seed
84+
8085
self._sycl_queue = dpnp.get_normalized_queue_device(device=device, sycl_queue=sycl_queue)
8186
self._sycl_device = self._sycl_queue.sycl_device
8287

0 commit comments

Comments
 (0)