@@ -77,24 +77,78 @@ void dpnp_rng_srand_c(size_t seed)
7777}
7878
// Runs a single mkl_rng::generate call and returns its event as a DPCTLSyclEventRef.
//
// distr  - distribution object, forwarded unchanged to mkl_rng::generate
// engine - random engine, forwarded by non-const reference to mkl_rng::generate
// size   - forwarded to mkl_rng::generate as the amount to generate
// result - output pointer, forwarded to mkl_rng::generate
//
// Returns the value of DPCTLEvent_Copy(event_ref) on success, or nullptr when a
// std::exception is caught (the exception is swallowed; see the TODO below).
//
// NOTE(review): in this rendering the lines tagged '-' and '+' appear to be the old and the
// new brace/argument layout of the same statements; the logic is the same in both.
7979template <typename _DistrType, typename _EngineType, typename _DataType>
80- static inline DPCTLSyclEventRef dpnp_rng_generate (const _DistrType& distr,
81- _EngineType& engine,
82- const int64_t size,
83- _DataType* result) {
80+ static inline DPCTLSyclEventRef
81+ dpnp_rng_generate (const _DistrType& distr, _EngineType& engine, const int64_t size, _DataType* result)
82+ {
8483 DPCTLSyclEventRef event_ref = nullptr ;
// `event` is a function-local object; event_ref below is only a reinterpreted pointer to it.
8584 sycl::event event;
8685
8786 // perform rng generation
88- try {
87+ try
88+ {
8989 event = mkl_rng::generate<_DistrType, _EngineType>(distr, engine, size, result);
9090 event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
91- } catch (const std::exception &e) {
91+ }
92+ catch (const std::exception& e)
93+ {
// event_ref is assigned only after generate() has returned, so it is still nullptr here.
9294 // TODO: add error reporting
9395 return event_ref;
9496 }
// event_ref points at the local `event`, so the raw pointer is not returned directly.
// Presumably DPCTLEvent_Copy yields a handle that outlives this frame - confirm against dpctl.
9597 return DPCTLEvent_Copy (event_ref);
9698}
9799
// Generates uniformly distributed values on [a, b) bounds given by `a` and `b` with the supplied
// engine, choosing the oneMKL uniform method from the data type and the device of `q`.
// NOTE(review): whether the upper bound is exclusive is a oneMKL property not shown here - confirm.
//
// engine - random engine passed by reference to the generation call
// q      - queue whose device is queried for sycl::aspect::fp64; it is dereferenced without a
//          null check, and only when _DataType is int32_t
// a, b   - bounds handed to the mkl_rng::uniform constructor
// size   - amount to generate, forwarded to the generation call
// result - output pointer, forwarded to the generation call
//
// Returns the event reference of the generation, or nullptr when a std::exception other than
// oneapi::mkl::unsupported_device / oneapi::mkl::unimplemented is caught on the accurate path.
100+ template <typename _EngineType, typename _DataType>
101+ static inline DPCTLSyclEventRef dpnp_rng_generate_uniform (
102+ _EngineType& engine, sycl::queue* q, const _DataType a, const _DataType b, const int64_t size, _DataType* result)
103+ {
104+ DPCTLSyclEventRef event_ref = nullptr ;
105+
// The accurate method is attempted only for int32_t on a device that reports fp64 support;
// every other combination goes straight to the standard method at the bottom.
106+ if constexpr (std::is_same<_DataType, int32_t >::value)
107+ {
108+ if (q->get_device ().has (sycl::aspect::fp64))
109+ {
110+ /* *
111+ * A note from oneMKL for oneapi::mkl::rng::uniform (Discrete):
112+ * The oneapi::mkl::rng::uniform_method::standard uses the s BRNG type on GPU devices.
113+ * This might cause the produced numbers to have incorrect statistics (due to rounding error)
114+ * when abs(b-a) > 2^23 || abs(b) > 2^23 || abs(a) > 2^23. To get proper statistics for this case,
115+ * use the oneapi::mkl::rng::uniform_method::accurate method instead.
116+ */
117+ using method_type = mkl_rng::uniform_method::accurate;
118+ mkl_rng::uniform<_DataType, method_type> distribution (a, b);
119+
120+ // perform generation
121+ try
122+ {
123+ sycl::event event = mkl_rng::generate (distribution, engine, size, result);
124+
// `event` lives only inside this try block, so the copy is taken before leaving it.
125+ event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
126+ return DPCTLEvent_Copy (event_ref);
127+ }
// The two empty handlers below deliberately discard the exception so that control reaches
// the standard-method code after the if constexpr block.
128+ catch (const oneapi::mkl::unsupported_device&)
129+ {
130+ // fall through to try with uniform_method::standard
131+ }
132+ catch (const oneapi::mkl::unimplemented&)
133+ {
134+ // fall through to try with uniform_method::standard
135+ }
136+ catch (const std::exception& e)
137+ {
// NOTE(review): event_ref is nullptr here if generate() threw; if DPCTLEvent_Copy were the
// thrower it would refer to the already-destroyed try-scoped event - confirm it cannot throw.
138+ // TODO: add error reporting
139+ return event_ref;
140+ }
141+ }
142+ }
143+
144+ // uniform_method::standard is a method used by default
145+ using method_type = mkl_rng::uniform_method::standard;
146+ mkl_rng::uniform<_DataType, method_type> distribution (a, b);
147+
// Delegates to dpnp_rng_generate, which also catches std::exception and returns nullptr then.
148+ // perform generation
149+ return dpnp_rng_generate (distribution, engine, size, result);
150+ }
151+
98152template <typename _DataType>
99153DPCTLSyclEventRef dpnp_rng_beta_c (DPCTLSyclQueueRef q_ref,
100154 void * result,
@@ -1392,49 +1446,75 @@ DPCTLSyclEventRef dpnp_rng_normal_c(DPCTLSyclQueueRef q_ref,
13921446{
13931447 // avoid warning unused variable
13941448 (void )dep_event_vec_ref;
1395- (void )q_ref;
13961449
13971450 DPCTLSyclEventRef event_ref = nullptr ;
1451+ sycl::queue* q = reinterpret_cast <sycl::queue*>(q_ref);
13981452
13991453 if (!size)
14001454 {
14011455 return event_ref;
14021456 }
1457+ assert (q != nullptr );
14031458
1404- mt19937_struct* random_state = static_cast <mt19937_struct *>(random_state_in);
1405- _DataType* result = static_cast <_DataType *>(result_out);
1459+ _DataType* result = static_cast <_DataType*>(result_out);
14061460
14071461 // set mean of distribution
14081462 const _DataType mean = static_cast <_DataType>(mean_in);
14091463 // set standard deviation of distribution
14101464 const _DataType stddev = static_cast <_DataType>(stddev_in);
14111465
14121466 mkl_rng::gaussian<_DataType> distribution (mean, stddev);
1413- mkl_rng::mt19937 *engine = static_cast <mkl_rng::mt19937 *>(random_state->engine );
14141467
1415- // perform generation
1416- return dpnp_rng_generate<mkl_rng::gaussian<_DataType>, mkl_rng::mt19937, _DataType>(
1417- distribution, *engine, size, result);
1468+ if (q->get_device ().is_cpu ())
1469+ {
1470+ mt19937_struct* random_state = static_cast <mt19937_struct*>(random_state_in);
1471+ mkl_rng::mt19937* engine = static_cast <mkl_rng::mt19937*>(random_state->engine );
1472+
1473+ // perform generation with MT19937 engine
1474+ event_ref = dpnp_rng_generate (distribution, *engine, size, result);
1475+ }
1476+ else
1477+ {
1478+ mcg59_struct* random_state = static_cast <mcg59_struct*>(random_state_in);
1479+ mkl_rng::mcg59* engine = static_cast <mkl_rng::mcg59*>(random_state->engine );
1480+
1481+ // perform generation with MCG59 engine
1482+ event_ref = dpnp_rng_generate (distribution, *engine, size, result);
1483+ }
1484+ return event_ref;
14181485}
14191486
// Blocking convenience overload of dpnp_rng_normal_c: builds a temporary random-state wrapper
// around a global engine, calls the event-returning overload on DPNP_QUEUE, waits for the
// returned event and then releases the event and the wrapper.
//
// result - output buffer, passed through as void*
// mean   - forwarded as the mean argument
// stddev - forwarded as the standard deviation argument
// size   - element count; converted to int64_t for the event-returning overload
//
// NOTE(review): the '-' lines appear to be the previous MT19937-only version of this body and
// the '+' lines the version that selects the engine by device type.
14201487template <typename _DataType>
14211488void dpnp_rng_normal_c (void * result, const _DataType mean, const _DataType stddev, const size_t size)
14221489{
1423- DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(&DPNP_QUEUE);
1490+ sycl::queue* q = &DPNP_QUEUE;
1491+ DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(q);
14241492 DPCTLEventVectorRef dep_event_vec_ref = nullptr ;
1425- mt19937_struct* mt19937 = new mt19937_struct ();
1426- mt19937->engine = &DPNP_RNG_ENGINE;
1493+ DPCTLSyclEventRef event_ref = nullptr ;
14271494
1428- DPCTLSyclEventRef event_ref = dpnp_rng_normal_c<_DataType>(q_ref,
1429- result,
1430- mean,
1431- stddev,
1432- static_cast <int64_t >(size),
1433- mt19937,
1434- dep_event_vec_ref);
1435- DPCTLEvent_WaitAndThrow (event_ref);
1436- DPCTLEvent_Delete (event_ref);
1437- delete mt19937;
// The branch mirrors the is_cpu() test in the event-returning overload, which casts the opaque
// state pointer to mt19937_struct on CPU and to mcg59_struct otherwise; the wrapper type passed
// here therefore has to match the device type.
1495+ if (q->get_device ().is_cpu ())
1496+ {
// The wrapper only stores the address of DPNP_RNG_ENGINE.
// NOTE(review): presumably `delete` on the wrapper leaves the engine itself alone - confirm
// against the mt19937_struct definition.
1497+ mt19937_struct* mt19937 = new mt19937_struct ();
1498+ mt19937->engine = &DPNP_RNG_ENGINE;
1499+
1500+ event_ref = dpnp_rng_normal_c<_DataType>(
1501+ q_ref, result, mean, stddev, static_cast <int64_t >(size), mt19937, dep_event_vec_ref);
// NOTE(review): the event-returning overload returns a null event_ref when size is 0; confirm
// that the two DPCTLEvent_* calls below accept a null reference.
// NOTE(review): if DPCTLEvent_WaitAndThrow throws, the event and the wrapper are not released.
1502+ DPCTLEvent_WaitAndThrow (event_ref);
1503+ DPCTLEvent_Delete (event_ref);
1504+ delete mt19937;
1505+ }
1506+ else
1507+ {
1508+ // MCG59 engine is assumed to provide a better performance on GPU than MT19937
1509+ mcg59_struct* mcg59 = new mcg59_struct ();
1510+ mcg59->engine = &DPNP_RNG_MCG59_ENGINE;
1511+
1512+ event_ref = dpnp_rng_normal_c<_DataType>(
1513+ q_ref, result, mean, stddev, static_cast <int64_t >(size), mcg59, dep_event_vec_ref);
// Same wait / release sequence as the CPU branch, with the same review caveats.
1514+ DPCTLEvent_WaitAndThrow (event_ref);
1515+ DPCTLEvent_Delete (event_ref);
1516+ delete mcg59;
1517+ }
14381518}
14391519
14401520template <typename _DataType>
@@ -2149,74 +2229,75 @@ DPCTLSyclEventRef dpnp_rng_uniform_c(DPCTLSyclQueueRef q_ref,
21492229 return event_ref;
21502230 }
21512231
2152- sycl::queue * q = reinterpret_cast <sycl::queue *>(q_ref);
2232+ sycl::queue* q = reinterpret_cast <sycl::queue*>(q_ref);
21532233
2154- mt19937_struct* random_state = static_cast <mt19937_struct *>(random_state_in);
2155- _DataType* result = static_cast <_DataType *>(result_out);
2234+ _DataType* result = static_cast <_DataType*>(result_out);
21562235
21572236 // set left bound of distribution
21582237 const _DataType a = static_cast <_DataType>(low);
21592238 // set right bound of distribution
21602239 const _DataType b = static_cast <_DataType>(high);
21612240
2162- mkl_rng::mt19937 *engine = static_cast <mkl_rng::mt19937 *>(random_state->engine );
2163-
2164- if constexpr (std::is_same<_DataType, int32_t >::value) {
2165- if (q->get_device ().has (sycl::aspect::fp64)) {
2166- /* *
2167- * A note from oneMKL for oneapi::mkl::rng::uniform (Discrete):
2168- * The oneapi::mkl::rng::uniform_method::standard uses the s BRNG type on GPU devices.
2169- * This might cause the produced numbers to have incorrect statistics (due to rounding error)
2170- * when abs(b-a) > 2^23 || abs(b) > 2^23 || abs(a) > 2^23. To get proper statistics for this case,
2171- * use the oneapi::mkl::rng::uniform_method::accurate method instead.
2172- */
2173- using method_type = mkl_rng::uniform_method::accurate;
2174- mkl_rng::uniform<_DataType, method_type> distribution (a, b);
2241+ if (q->get_device ().is_cpu ())
2242+ {
2243+ mt19937_struct* random_state = static_cast <mt19937_struct*>(random_state_in);
2244+ mkl_rng::mt19937* engine = static_cast <mkl_rng::mt19937*>(random_state->engine );
21752245
2176- // perform generation
2177- try {
2178- auto event = mkl_rng::generate<mkl_rng::uniform<_DataType, method_type>, mkl_rng::mt19937>(
2179- distribution, *engine, size, result);
2180- event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
2181- return DPCTLEvent_Copy (event_ref);
2182- } catch (const oneapi::mkl::unsupported_device&) {
2183- // fall through to try with uniform_method::standard
2184- } catch (const oneapi::mkl::unimplemented&) {
2185- // fall through to try with uniform_method::standard
2186- } catch (const std::exception &e) {
2187- // TODO: add error reporting
2188- return event_ref;
2189- }
2190- }
2246+ // perform generation with MT19937 engine
2247+ event_ref = dpnp_rng_generate_uniform (*engine, q, a, b, size, result);
21912248 }
2249+ else
2250+ {
2251+ mcg59_struct* random_state = static_cast <mcg59_struct*>(random_state_in);
2252+ mkl_rng::mcg59* engine = static_cast <mkl_rng::mcg59*>(random_state->engine );
21922253
2193- // uniform_method::standard is a method used by default
2194- using method_type = mkl_rng::uniform_method::standard;
2195- mkl_rng::uniform<_DataType, method_type> distribution (a, b);
2196-
2197- // perform generation
2198- return dpnp_rng_generate<mkl_rng::uniform<_DataType, method_type>, mkl_rng::mt19937, _DataType>(
2199- distribution, *engine, size, result);
2254+ // perform generation with MCG59 engine
2255+ event_ref = dpnp_rng_generate_uniform (*engine, q, a, b, size, result);
2256+ }
2257+ return event_ref;
22002258}
22012259
// Blocking convenience overload of dpnp_rng_uniform_c: builds a temporary random-state wrapper
// around a global engine, calls the event-returning overload on DPNP_QUEUE, waits for the
// returned event and then releases the event and the wrapper.
//
// result - output buffer, passed through as void*
// low    - lower bound; converted from long to double before being forwarded
// high   - upper bound; converted from long to double before being forwarded
// size   - element count; converted to int64_t for the event-returning overload
//
// NOTE(review): the '-' lines appear to be the previous MT19937-only version of this body and
// the '+' lines the version that selects the engine by device type.
22022260template <typename _DataType>
22032261void dpnp_rng_uniform_c (void * result, const long low, const long high, const size_t size)
22042262{
2205- DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(&DPNP_QUEUE);
2263+ sycl::queue* q = &DPNP_QUEUE;
2264+ DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(q);
22062265 DPCTLEventVectorRef dep_event_vec_ref = nullptr ;
2207- mt19937_struct* mt19937 = new mt19937_struct ();
2208- mt19937->engine = &DPNP_RNG_ENGINE;
2266+ DPCTLSyclEventRef event_ref = nullptr ;
22092267
2210- DPCTLSyclEventRef event_ref = dpnp_rng_uniform_c<_DataType>(q_ref,
2211- result,
2212- static_cast <double >(low),
2213- static_cast <double >(high),
2214- static_cast <int64_t >(size),
2215- mt19937,
2216- dep_event_vec_ref);
2217- DPCTLEvent_WaitAndThrow (event_ref);
2218- DPCTLEvent_Delete (event_ref);
2219- delete mt19937;
// The branch mirrors the is_cpu() test in the event-returning overload, which casts the opaque
// state pointer to mt19937_struct on CPU and to mcg59_struct otherwise; the wrapper type passed
// here therefore has to match the device type.
2268+ if (q->get_device ().is_cpu ())
2269+ {
// The wrapper only stores the address of DPNP_RNG_ENGINE.
// NOTE(review): presumably `delete` on the wrapper leaves the engine itself alone - confirm
// against the mt19937_struct definition.
2270+ mt19937_struct* mt19937 = new mt19937_struct ();
2271+ mt19937->engine = &DPNP_RNG_ENGINE;
2272+
2273+ event_ref = dpnp_rng_uniform_c<_DataType>(q_ref,
2274+ result,
2275+ static_cast <double >(low),
2276+ static_cast <double >(high),
2277+ static_cast <int64_t >(size),
2278+ mt19937,
2279+ dep_event_vec_ref);
// NOTE(review): the event-returning overload has an early `return event_ref` path whose
// condition is not visible here; confirm the DPCTLEvent_* calls below accept a null reference.
// NOTE(review): if DPCTLEvent_WaitAndThrow throws, the event and the wrapper are not released.
2280+ DPCTLEvent_WaitAndThrow (event_ref);
2281+ DPCTLEvent_Delete (event_ref);
2282+ delete mt19937;
2283+ }
2284+ else
2285+ {
2286+ // MCG59 engine is assumed to provide a better performance on GPU than MT19937
2287+ mcg59_struct* mcg59 = new mcg59_struct ();
2288+ mcg59->engine = &DPNP_RNG_MCG59_ENGINE;
2289+
2290+ event_ref = dpnp_rng_uniform_c<_DataType>(q_ref,
2291+ result,
2292+ static_cast <double >(low),
2293+ static_cast <double >(high),
2294+ static_cast <int64_t >(size),
2295+ mcg59,
2296+ dep_event_vec_ref);
// Same wait / release sequence as the CPU branch, with the same review caveats.
2297+ DPCTLEvent_WaitAndThrow (event_ref);
2298+ DPCTLEvent_Delete (event_ref);
2299+ delete mcg59;
2300+ }
22202301}
22212302
22222303template <typename _DataType>
0 commit comments