Skip to content

Commit edea2de

Browse files
chethegaJeffBezanson
authored andcommitted
Introduce task-local and free-standing xoshiro RNG
Co-Authored-By: Jeff Bezanson <jeff.bezanson@gmail.com> Co-Authored-by: Rafael Fourquet <fourquet.rafael@gmail.com> - try to use high bits instead of low when we take a subset of them - use shift and multiply instead of mask and subtract for generating floats
1 parent 58ffe7e commit edea2de

File tree

9 files changed

+661
-11
lines changed

9 files changed

+661
-11
lines changed

src/jltypes.c

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2524,7 +2524,7 @@ void jl_init_types(void) JL_GC_DISABLED
25242524
NULL,
25252525
jl_any_type,
25262526
jl_emptysvec,
2527-
jl_perm_symsvec(10,
2527+
jl_perm_symsvec(14,
25282528
"next",
25292529
"queue",
25302530
"storage",
@@ -2534,8 +2534,12 @@ void jl_init_types(void) JL_GC_DISABLED
25342534
"code",
25352535
"_state",
25362536
"sticky",
2537-
"_isexception"),
2538-
jl_svec(10,
2537+
"_isexception",
2538+
"rngState0",
2539+
"rngState1",
2540+
"rngState2",
2541+
"rngState3"),
2542+
jl_svec(14,
25392543
jl_any_type,
25402544
jl_any_type,
25412545
jl_any_type,
@@ -2545,7 +2549,11 @@ void jl_init_types(void) JL_GC_DISABLED
25452549
jl_any_type,
25462550
jl_uint8_type,
25472551
jl_bool_type,
2548-
jl_bool_type),
2552+
jl_bool_type,
2553+
jl_uint64_type,
2554+
jl_uint64_type,
2555+
jl_uint64_type,
2556+
jl_uint64_type),
25492557
0, 1, 6);
25502558
jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_nothing_type);
25512559
jl_svecset(jl_task_type->types, 0, listt);

src/julia.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,6 +1805,10 @@ typedef struct _jl_task_t {
18051805
uint8_t _state;
18061806
uint8_t sticky; // record whether this Task can be migrated to a new thread
18071807
uint8_t _isexception; // set if `result` is an exception to throw or that we exited with
1808+
uint64_t rngState0; // really rngState[4], but more convenient to split
1809+
uint64_t rngState1;
1810+
uint64_t rngState2;
1811+
uint64_t rngState3;
18081812

18091813
// hidden state:
18101814
// saved gc stack top for context switches

src/task.c

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "julia_internal.h"
3636
#include "threading.h"
3737
#include "julia_assert.h"
38+
#include "support/hashing.h"
3839

3940
#ifdef __cplusplus
4041
extern "C" {
@@ -648,6 +649,63 @@ JL_DLLEXPORT void jl_rethrow_other(jl_value_t *e JL_MAYBE_UNROOTED)
648649
throw_internal(ct, NULL);
649650
}
650651

652+
/* This is xoshiro256++ 1.0, used for tasklocal random number generation in julia.
653+
This implementation is intended for embedders and internal use by the runtime, and is
654+
based on the reference implementation on http://prng.di.unimi.it
655+
656+
Credits go to Sebastiano Vigna for coming up with this PRNG.
657+
658+
There is a pure julia implementation in stdlib that tends to be faster when used from
659+
within julia, due to inlining and more agressive architecture-specific optimizations.
660+
*/
661+
JL_DLLEXPORT uint64_t jl_tasklocal_genrandom(jl_task_t *task) JL_NOTSAFEPOINT
662+
{
663+
uint64_t s0 = task->rngState0;
664+
uint64_t s1 = task->rngState1;
665+
uint64_t s2 = task->rngState2;
666+
uint64_t s3 = task->rngState3;
667+
668+
uint64_t t = s0 << 17;
669+
uint64_t tmp = s0 + s3;
670+
uint64_t res = ((tmp << 23) | (tmp >> 41)) + s0;
671+
s2 ^= s0;
672+
s3 ^= s1;
673+
s1 ^= s2;
674+
s0 ^= s3;
675+
s2 ^= t;
676+
s3 = (s3 << 45) | (s3 >> 19);
677+
678+
task->rngState0 = s0;
679+
task->rngState1 = s1;
680+
task->rngState2 = s2;
681+
task->rngState3 = s3;
682+
return res;
683+
}
684+
685+
void rng_split(jl_task_t *from, jl_task_t *to) JL_NOTSAFEPOINT
686+
{
687+
/* TODO: consider a less ad-hoc construction
688+
Ideally we could just use the output of the random stream to seed the initial
689+
state of the child. Out of an overabundance of caution we multiply with
690+
effectively random coefficients, to break possible self-interactions.
691+
692+
It is not the goal to mix bits -- we work under the assumption that the
693+
source is well-seeded, and its output looks effectively random.
694+
However, xoshiro has never been studied in the mode where we seed the
695+
initial state with the output of another xoshiro instance.
696+
697+
Constants have nothing up their sleeve:
698+
0x02011ce34bce797f == hash(UInt(1))|0x01
699+
0x5a94851fb48a6e05 == hash(UInt(2))|0x01
700+
0x3688cf5d48899fa7 == hash(UInt(3))|0x01
701+
0x867b4bb4c42e5661 == hash(UInt(4))|0x01
702+
*/
703+
to->rngState0 = 0x02011ce34bce797f * jl_tasklocal_genrandom(from);
704+
to->rngState1 = 0x5a94851fb48a6e05 * jl_tasklocal_genrandom(from);
705+
to->rngState2 = 0x3688cf5d48899fa7 * jl_tasklocal_genrandom(from);
706+
to->rngState3 = 0x867b4bb4c42e5661 * jl_tasklocal_genrandom(from);
707+
}
708+
651709
JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize)
652710
{
653711
jl_task_t *ct = jl_current_task;
@@ -683,6 +741,8 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion
683741
t->_isexception = 0;
684742
// Inherit logger state from parent task
685743
t->logstate = ct->logstate;
744+
// Fork task-local random state from parent
745+
rng_split(ct, t);
686746
// there is no active exception handler available on this stack yet
687747
t->eh = NULL;
688748
t->sticky = 1;

stdlib/Random/src/RNGs.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22

33
## RandomDevice
44

5-
# SamplerUnion(X, Y, ...}) == Union{SamplerType{X}, SamplerType{Y}, ...}
6-
SamplerUnion(U...) = Union{Any[SamplerType{T} for T in U]...}
7-
const SamplerBoolBitInteger = SamplerUnion(Bool, BitInteger_types...)
8-
95
if Sys.iswindows()
106
struct RandomDevice <: AbstractRNG
117
buffer::Vector{UInt128}
@@ -382,6 +378,7 @@ end
382378

383379
function __init__()
384380
resize!(empty!(THREAD_RNGs), Threads.nthreads()) # ensures that we didn't save a bad object
381+
seed!(TaskLocalRNG())
385382
end
386383

387384

stdlib/Random/src/Random.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export rand!, randn!,
2727
shuffle, shuffle!,
2828
randperm, randperm!,
2929
randcycle, randcycle!,
30-
AbstractRNG, MersenneTwister, RandomDevice
30+
AbstractRNG, MersenneTwister, RandomDevice, TaskLocalRNG, Xoshiro
3131

3232
## general definitions
3333

@@ -291,11 +291,17 @@ rand( ::Type{X}, dims::Dims) where {X} = rand(default_rng(), X, d
291291
rand(r::AbstractRNG, ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(r, X, Dims((d, dims...)))
292292
rand( ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(X, Dims((d, dims...)))
293293

294+
# SamplerUnion(X, Y, ...}) == Union{SamplerType{X}, SamplerType{Y}, ...}
295+
SamplerUnion(U...) = Union{Any[SamplerType{T} for T in U]...}
296+
const SamplerBoolBitInteger = SamplerUnion(Bool, BitInteger_types...)
294297

298+
299+
include("Xoshiro.jl")
295300
include("RNGs.jl")
296301
include("generation.jl")
297302
include("normal.jl")
298303
include("misc.jl")
304+
include("XoshiroSimd.jl")
299305

300306
## rand & rand! & seed! docstrings
301307

0 commit comments

Comments
 (0)