@@ -77,24 +77,78 @@ void dpnp_rng_srand_c(size_t seed)
77
77
}
78
78
79
79
/**
 * @brief Enqueue RNG generation for an arbitrary oneMKL distribution/engine pair.
 *
 * Thin wrapper over oneapi::mkl::rng::generate() that converts the resulting
 * sycl::event into a DPCTL event reference for callers on the C interface.
 *
 * @param distr   Configured oneMKL distribution object (read-only).
 * @param engine  oneMKL RNG engine; its internal state advances on success.
 * @param size    Number of values to generate.
 * @param result  Destination buffer for `size` generated values.
 * @return A DPCTL event reference copy on success, or nullptr on failure.
 */
template <typename _DistrType, typename _EngineType, typename _DataType>
static inline DPCTLSyclEventRef
    dpnp_rng_generate(const _DistrType& distr, _EngineType& engine, const int64_t size, _DataType* result)
{
    DPCTLSyclEventRef event_ref = nullptr;
    sycl::event event;

    // perform rng generation
    try
    {
        event = mkl_rng::generate<_DistrType, _EngineType>(distr, engine, size, result);
        // NOTE(review): `event` is a local; this ref is only valid until return.
        // DPCTLEvent_Copy below presumably deep-copies it before the local dies — confirm against dpctl docs.
        event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
    }
    catch (const std::exception& e)
    {
        // TODO: add error reporting
        return event_ref;  // still nullptr here — signals failure to the caller
    }
    return DPCTLEvent_Copy(event_ref);
}
97
99
100
/**
 * @brief Enqueue uniform-distribution RNG generation, picking the best oneMKL method.
 *
 * For int32 output on devices with fp64 support, the `accurate` uniform method is
 * tried first (see the oneMKL note below); on unsupported_device / unimplemented
 * it falls back to the default `standard` method via dpnp_rng_generate().
 *
 * @param engine  oneMKL RNG engine; state advances on successful generation.
 * @param q       SYCL queue used only to query device aspects (must be non-null).
 * @param a       Left (inclusive) bound of the distribution.
 * @param b       Right bound of the distribution.
 * @param size    Number of values to generate.
 * @param result  Destination buffer for `size` generated values.
 * @return A DPCTL event reference copy on success, or nullptr on failure.
 */
template <typename _EngineType, typename _DataType>
static inline DPCTLSyclEventRef dpnp_rng_generate_uniform(
    _EngineType& engine, sycl::queue* q, const _DataType a, const _DataType b, const int64_t size, _DataType* result)
{
    DPCTLSyclEventRef event_ref = nullptr;

    if constexpr (std::is_same<_DataType, int32_t>::value)
    {
        // accurate method requires double precision internally, hence the fp64 aspect check
        if (q->get_device().has(sycl::aspect::fp64))
        {
            /**
             * A note from oneMKL for oneapi::mkl::rng::uniform (Discrete):
             * The oneapi::mkl::rng::uniform_method::standard uses the s BRNG type on GPU devices.
             * This might cause the produced numbers to have incorrect statistics (due to rounding error)
             * when abs(b-a) > 2^23 || abs(b) > 2^23 || abs(a) > 2^23. To get proper statistics for this case,
             * use the oneapi::mkl::rng::uniform_method::accurate method instead.
             */
            using method_type = mkl_rng::uniform_method::accurate;
            mkl_rng::uniform<_DataType, method_type> distribution(a, b);

            // perform generation
            try
            {
                sycl::event event = mkl_rng::generate(distribution, engine, size, result);

                // NOTE(review): ref points at a local event; DPCTLEvent_Copy is
                // expected to copy it before this scope exits — confirm.
                event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
                return DPCTLEvent_Copy(event_ref);
            }
            catch (const oneapi::mkl::unsupported_device&)
            {
                // fall through to try with uniform_method::standard
            }
            catch (const oneapi::mkl::unimplemented&)
            {
                // fall through to try with uniform_method::standard
            }
            catch (const std::exception& e)
            {
                // TODO: add error reporting
                return event_ref;  // nullptr — unrecoverable failure
            }
        }
    }

    // uniform_method::standard is a method used by default
    using method_type = mkl_rng::uniform_method::standard;
    mkl_rng::uniform<_DataType, method_type> distribution(a, b);

    // perform generation
    return dpnp_rng_generate(distribution, engine, size, result);
}
151
+
98
152
template <typename _DataType>
99
153
DPCTLSyclEventRef dpnp_rng_beta_c (DPCTLSyclQueueRef q_ref,
100
154
void * result,
@@ -1392,49 +1446,75 @@ DPCTLSyclEventRef dpnp_rng_normal_c(DPCTLSyclQueueRef q_ref,
1392
1446
{
1393
1447
// avoid warning unused variable
1394
1448
(void )dep_event_vec_ref;
1395
- (void )q_ref;
1396
1449
1397
1450
DPCTLSyclEventRef event_ref = nullptr ;
1451
+ sycl::queue* q = reinterpret_cast <sycl::queue*>(q_ref);
1398
1452
1399
1453
if (!size)
1400
1454
{
1401
1455
return event_ref;
1402
1456
}
1457
+ assert (q != nullptr );
1403
1458
1404
- mt19937_struct* random_state = static_cast <mt19937_struct *>(random_state_in);
1405
- _DataType* result = static_cast <_DataType *>(result_out);
1459
+ _DataType* result = static_cast <_DataType*>(result_out);
1406
1460
1407
1461
// set mean of distribution
1408
1462
const _DataType mean = static_cast <_DataType>(mean_in);
1409
1463
// set standard deviation of distribution
1410
1464
const _DataType stddev = static_cast <_DataType>(stddev_in);
1411
1465
1412
1466
mkl_rng::gaussian<_DataType> distribution (mean, stddev);
1413
- mkl_rng::mt19937 *engine = static_cast <mkl_rng::mt19937 *>(random_state->engine );
1414
1467
1415
- // perform generation
1416
- return dpnp_rng_generate<mkl_rng::gaussian<_DataType>, mkl_rng::mt19937, _DataType>(
1417
- distribution, *engine, size, result);
1468
+ if (q->get_device ().is_cpu ())
1469
+ {
1470
+ mt19937_struct* random_state = static_cast <mt19937_struct*>(random_state_in);
1471
+ mkl_rng::mt19937* engine = static_cast <mkl_rng::mt19937*>(random_state->engine );
1472
+
1473
+ // perform generation with MT19937 engine
1474
+ event_ref = dpnp_rng_generate (distribution, *engine, size, result);
1475
+ }
1476
+ else
1477
+ {
1478
+ mcg59_struct* random_state = static_cast <mcg59_struct*>(random_state_in);
1479
+ mkl_rng::mcg59* engine = static_cast <mkl_rng::mcg59*>(random_state->engine );
1480
+
1481
+ // perform generation with MCG59 engine
1482
+ event_ref = dpnp_rng_generate (distribution, *engine, size, result);
1483
+ }
1484
+ return event_ref;
1418
1485
}
1419
1486
1420
1487
template <typename _DataType>
1421
1488
void dpnp_rng_normal_c (void * result, const _DataType mean, const _DataType stddev, const size_t size)
1422
1489
{
1423
- DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(&DPNP_QUEUE);
1490
+ sycl::queue* q = &DPNP_QUEUE;
1491
+ DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(q);
1424
1492
DPCTLEventVectorRef dep_event_vec_ref = nullptr ;
1425
- mt19937_struct* mt19937 = new mt19937_struct ();
1426
- mt19937->engine = &DPNP_RNG_ENGINE;
1493
+ DPCTLSyclEventRef event_ref = nullptr ;
1427
1494
1428
- DPCTLSyclEventRef event_ref = dpnp_rng_normal_c<_DataType>(q_ref,
1429
- result,
1430
- mean,
1431
- stddev,
1432
- static_cast <int64_t >(size),
1433
- mt19937,
1434
- dep_event_vec_ref);
1435
- DPCTLEvent_WaitAndThrow (event_ref);
1436
- DPCTLEvent_Delete (event_ref);
1437
- delete mt19937;
1495
+ if (q->get_device ().is_cpu ())
1496
+ {
1497
+ mt19937_struct* mt19937 = new mt19937_struct ();
1498
+ mt19937->engine = &DPNP_RNG_ENGINE;
1499
+
1500
+ event_ref = dpnp_rng_normal_c<_DataType>(
1501
+ q_ref, result, mean, stddev, static_cast <int64_t >(size), mt19937, dep_event_vec_ref);
1502
+ DPCTLEvent_WaitAndThrow (event_ref);
1503
+ DPCTLEvent_Delete (event_ref);
1504
+ delete mt19937;
1505
+ }
1506
+ else
1507
+ {
1508
+ // MCG59 engine is assumed to provide a better performance on GPU than MT19937
1509
+ mcg59_struct* mcg59 = new mcg59_struct ();
1510
+ mcg59->engine = &DPNP_RNG_MCG59_ENGINE;
1511
+
1512
+ event_ref = dpnp_rng_normal_c<_DataType>(
1513
+ q_ref, result, mean, stddev, static_cast <int64_t >(size), mcg59, dep_event_vec_ref);
1514
+ DPCTLEvent_WaitAndThrow (event_ref);
1515
+ DPCTLEvent_Delete (event_ref);
1516
+ delete mcg59;
1517
+ }
1438
1518
}
1439
1519
1440
1520
template <typename _DataType>
@@ -2149,74 +2229,75 @@ DPCTLSyclEventRef dpnp_rng_uniform_c(DPCTLSyclQueueRef q_ref,
2149
2229
return event_ref;
2150
2230
}
2151
2231
2152
- sycl::queue * q = reinterpret_cast <sycl::queue *>(q_ref);
2232
+ sycl::queue* q = reinterpret_cast <sycl::queue*>(q_ref);
2153
2233
2154
- mt19937_struct* random_state = static_cast <mt19937_struct *>(random_state_in);
2155
- _DataType* result = static_cast <_DataType *>(result_out);
2234
+ _DataType* result = static_cast <_DataType*>(result_out);
2156
2235
2157
2236
// set left bound of distribution
2158
2237
const _DataType a = static_cast <_DataType>(low);
2159
2238
// set right bound of distribution
2160
2239
const _DataType b = static_cast <_DataType>(high);
2161
2240
2162
- mkl_rng::mt19937 *engine = static_cast <mkl_rng::mt19937 *>(random_state->engine );
2163
-
2164
- if constexpr (std::is_same<_DataType, int32_t >::value) {
2165
- if (q->get_device ().has (sycl::aspect::fp64)) {
2166
- /* *
2167
- * A note from oneMKL for oneapi::mkl::rng::uniform (Discrete):
2168
- * The oneapi::mkl::rng::uniform_method::standard uses the s BRNG type on GPU devices.
2169
- * This might cause the produced numbers to have incorrect statistics (due to rounding error)
2170
- * when abs(b-a) > 2^23 || abs(b) > 2^23 || abs(a) > 2^23. To get proper statistics for this case,
2171
- * use the oneapi::mkl::rng::uniform_method::accurate method instead.
2172
- */
2173
- using method_type = mkl_rng::uniform_method::accurate;
2174
- mkl_rng::uniform<_DataType, method_type> distribution (a, b);
2241
+ if (q->get_device ().is_cpu ())
2242
+ {
2243
+ mt19937_struct* random_state = static_cast <mt19937_struct*>(random_state_in);
2244
+ mkl_rng::mt19937* engine = static_cast <mkl_rng::mt19937*>(random_state->engine );
2175
2245
2176
- // perform generation
2177
- try {
2178
- auto event = mkl_rng::generate<mkl_rng::uniform<_DataType, method_type>, mkl_rng::mt19937>(
2179
- distribution, *engine, size, result);
2180
- event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
2181
- return DPCTLEvent_Copy (event_ref);
2182
- } catch (const oneapi::mkl::unsupported_device&) {
2183
- // fall through to try with uniform_method::standard
2184
- } catch (const oneapi::mkl::unimplemented&) {
2185
- // fall through to try with uniform_method::standard
2186
- } catch (const std::exception &e) {
2187
- // TODO: add error reporting
2188
- return event_ref;
2189
- }
2190
- }
2246
+ // perform generation with MT19937 engine
2247
+ event_ref = dpnp_rng_generate_uniform (*engine, q, a, b, size, result);
2191
2248
}
2249
+ else
2250
+ {
2251
+ mcg59_struct* random_state = static_cast <mcg59_struct*>(random_state_in);
2252
+ mkl_rng::mcg59* engine = static_cast <mkl_rng::mcg59*>(random_state->engine );
2192
2253
2193
- // uniform_method::standard is a method used by default
2194
- using method_type = mkl_rng::uniform_method::standard;
2195
- mkl_rng::uniform<_DataType, method_type> distribution (a, b);
2196
-
2197
- // perform generation
2198
- return dpnp_rng_generate<mkl_rng::uniform<_DataType, method_type>, mkl_rng::mt19937, _DataType>(
2199
- distribution, *engine, size, result);
2254
+ // perform generation with MCG59 engine
2255
+ event_ref = dpnp_rng_generate_uniform (*engine, q, a, b, size, result);
2256
+ }
2257
+ return event_ref;
2200
2258
}
2201
2259
2202
2260
template <typename _DataType>
2203
2261
void dpnp_rng_uniform_c (void * result, const long low, const long high, const size_t size)
2204
2262
{
2205
- DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(&DPNP_QUEUE);
2263
+ sycl::queue* q = &DPNP_QUEUE;
2264
+ DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(q);
2206
2265
DPCTLEventVectorRef dep_event_vec_ref = nullptr ;
2207
- mt19937_struct* mt19937 = new mt19937_struct ();
2208
- mt19937->engine = &DPNP_RNG_ENGINE;
2266
+ DPCTLSyclEventRef event_ref = nullptr ;
2209
2267
2210
- DPCTLSyclEventRef event_ref = dpnp_rng_uniform_c<_DataType>(q_ref,
2211
- result,
2212
- static_cast <double >(low),
2213
- static_cast <double >(high),
2214
- static_cast <int64_t >(size),
2215
- mt19937,
2216
- dep_event_vec_ref);
2217
- DPCTLEvent_WaitAndThrow (event_ref);
2218
- DPCTLEvent_Delete (event_ref);
2219
- delete mt19937;
2268
+ if (q->get_device ().is_cpu ())
2269
+ {
2270
+ mt19937_struct* mt19937 = new mt19937_struct ();
2271
+ mt19937->engine = &DPNP_RNG_ENGINE;
2272
+
2273
+ event_ref = dpnp_rng_uniform_c<_DataType>(q_ref,
2274
+ result,
2275
+ static_cast <double >(low),
2276
+ static_cast <double >(high),
2277
+ static_cast <int64_t >(size),
2278
+ mt19937,
2279
+ dep_event_vec_ref);
2280
+ DPCTLEvent_WaitAndThrow (event_ref);
2281
+ DPCTLEvent_Delete (event_ref);
2282
+ delete mt19937;
2283
+ }
2284
+ else
2285
+ {
2286
+ // MCG59 engine is assumed to provide a better performance on GPU than MT19937
2287
+ mcg59_struct* mcg59 = new mcg59_struct ();
2288
+ mcg59->engine = &DPNP_RNG_MCG59_ENGINE;
2289
+
2290
+ event_ref = dpnp_rng_uniform_c<_DataType>(q_ref,
2291
+ result,
2292
+ static_cast <double >(low),
2293
+ static_cast <double >(high),
2294
+ static_cast <int64_t >(size),
2295
+ mcg59,
2296
+ dep_event_vec_ref);
2297
+ DPCTLEvent_WaitAndThrow (event_ref);
2298
+ DPCTLEvent_Delete (event_ref);
2299
+ delete mcg59;
2300
+ }
2220
2301
}
2221
2302
2222
2303
template <typename _DataType>
0 commit comments