Skip to content

Commit 66147d0

Browse files
Tasks: don't advance task RNG on task spawn
Previously we had this unfortunate behavior: julia> Random.seed!(123) TaskLocalRNG() julia> randn() -0.6457306721039767 julia> Random.seed!(123) TaskLocalRNG() julia> fetch(@async nothing) julia> randn() 0.4922456865251828 In other words: the mere act of spawning a child task affects the parent task's RNG (by advancing it four times). This PR preserves the desirable parts of the previous situation: when seeded, the parent and child RNG streams are reproducible. Moreover, it fixes the undesirable behavior: julia> Random.seed!(123) TaskLocalRNG() julia> randn() -0.6457306721039767 julia> Random.seed!(123) TaskLocalRNG() julia> fetch(@async nothing) julia> randn() -0.6457306721039767 In other words: the parent RNG is unaffected by spawning a child. The design is documented in detail in a comment preceding the jl_rng_split function.
1 parent 6b934f9 commit 66147d0

File tree

7 files changed

+227
-38
lines changed

7 files changed

+227
-38
lines changed

base/sysimg.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ let
2727
task.rngState1 = 0x7431eaead385992c
2828
task.rngState2 = 0x503e1d32781c2608
2929
task.rngState3 = 0x3a77f7189200c20b
30+
task.rngState4 = 0x5502376d099035ae
3031

3132
# Stdlibs sorted in dependency, then alphabetical, order by contrib/print_sorted_stdlibs.jl
3233
# Run with the `--exclude-jlls` option to filter out all JLL packages

src/gc.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,9 @@ static void jl_gc_run_finalizers_in_list(jl_task_t *ct, arraylist_t *list) JL_NO
382382
ct->sticky = sticky;
383383
}
384384

385-
static uint64_t finalizer_rngState[4];
385+
static uint64_t finalizer_rngState[JL_RNG_SIZE];
386386

387-
void jl_rng_split(uint64_t to[4], uint64_t from[4]) JL_NOTSAFEPOINT;
387+
void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE]) JL_NOTSAFEPOINT;
388388

389389
JL_DLLEXPORT void jl_gc_init_finalizer_rng_state(void)
390390
{
@@ -413,7 +413,7 @@ static void run_finalizers(jl_task_t *ct)
413413
jl_atomic_store_relaxed(&jl_gc_have_pending_finalizers, 0);
414414
arraylist_new(&to_finalize, 0);
415415

416-
uint64_t save_rngState[4];
416+
uint64_t save_rngState[JL_RNG_SIZE];
417417
memcpy(&save_rngState[0], &ct->rngState[0], sizeof(save_rngState));
418418
jl_rng_split(ct->rngState, finalizer_rngState);
419419

src/jltypes.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2768,7 +2768,7 @@ void jl_init_types(void) JL_GC_DISABLED
27682768
NULL,
27692769
jl_any_type,
27702770
jl_emptysvec,
2771-
jl_perm_symsvec(15,
2771+
jl_perm_symsvec(16,
27722772
"next",
27732773
"queue",
27742774
"storage",
@@ -2780,11 +2780,12 @@ void jl_init_types(void) JL_GC_DISABLED
27802780
"rngState1",
27812781
"rngState2",
27822782
"rngState3",
2783+
"rngState4",
27832784
"_state",
27842785
"sticky",
27852786
"_isexception",
27862787
"priority"),
2787-
jl_svec(15,
2788+
jl_svec(16,
27882789
jl_any_type,
27892790
jl_any_type,
27902791
jl_any_type,
@@ -2796,6 +2797,7 @@ void jl_init_types(void) JL_GC_DISABLED
27962797
jl_uint64_type,
27972798
jl_uint64_type,
27982799
jl_uint64_type,
2800+
jl_uint64_type,
27992801
jl_uint8_type,
28002802
jl_bool_type,
28012803
jl_bool_type,

src/julia.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1910,6 +1910,8 @@ typedef struct _jl_handler_t {
19101910
size_t world_age;
19111911
} jl_handler_t;
19121912

1913+
#define JL_RNG_SIZE 5 // xoshiro 4 + splitmix 1
1914+
19131915
typedef struct _jl_task_t {
19141916
JL_DATA_TYPE
19151917
jl_value_t *next; // invasive linked list for scheduler
@@ -1921,7 +1923,7 @@ typedef struct _jl_task_t {
19211923
jl_function_t *start;
19221924
// 4 byte padding on 32-bit systems
19231925
// uint32_t padding0;
1924-
uint64_t rngState[4];
1926+
uint64_t rngState[JL_RNG_SIZE];
19251927
_Atomic(uint8_t) _state;
19261928
uint8_t sticky; // record whether this Task can be migrated to a new thread
19271929
_Atomic(uint8_t) _isexception; // set if `result` is an exception to throw or that we exited with

src/task.c

Lines changed: 153 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -866,28 +866,160 @@ uint64_t jl_genrandom(uint64_t rngState[4]) JL_NOTSAFEPOINT
866866
return res;
867867
}
868868

869-
void jl_rng_split(uint64_t to[4], uint64_t from[4]) JL_NOTSAFEPOINT
869+
/*
870+
The jl_rng_split function forks a task's RNG state in a way that is essentially
871+
guaranteed to avoid collisions between the RNG streams of all forked tasks. The
872+
main RNG is the xoshiro256++ RNG whose state is stored in rngState[0..3]. There
873+
is a small internal RNG used for task forking stored in rngState[4]. This state
874+
is a LCG (linear congruential generator), which is put through four different
875+
variations of the strongest PCG output function, referred to as PCG-RXS-M-XS-64.
876+
This output function is invertible: it maps a 64-bit state to 64-bit output, so
877+
it's not recommended for general purpose RNG usage. In our usage, however, the
878+
invertibility is actually a benefit, and we only use the RNG output internally.
879+
880+
The goal of this function is to perturb the state of each child task's RNG in
881+
such a way that for an entire tree of tasks spawned starting with a given seed
882+
in a root task, no two tasks have the same RNG state. Moreover, we want to do
883+
this in a way that is deterministic and repeatable based on the root task's seed
884+
and the task tree structure. The RNG state of a parent task is allowed to alter
885+
the RNG state of a child task. The mere fact that a child was spawned should not
886+
alter the RNG output of the parent, but, of course, children spawned after that
887+
should have distinct RNG states from previously spawned children.
888+
889+
The basic approach is that used by the DotMix [1] and SplitMix [2] systems: each
890+
task is uniquely identified by a sequence of "pedigree" numbers, indicating
891+
where in the task tree it was spawned. This vector of pedigree coordinates is
892+
then reduced to a single value by computing a dot product with a common vector
893+
of random weights. The DotMix paper provides a proof that this dot product hash
894+
value (referred to as a "compression function") is collision resistant in the
895+
sense that the pairwise collision probability of two distinct tasks is 1/N where
896+
N is the number of possible weight values. Both DotMix and SplitMix use a prime
897+
value of N because the proof requires that the difference between two distinct
898+
pedigree coordinates must be invertible, which is guaranteed by N being prime.
899+
We take a different approach, however---we limit pedigree coordinates to being
900+
binary instead: when a task spawns a child, both tasks share the same pedigree
901+
prefix, with the parent appending a zero and the child appending a one. This
902+
way a binary vector uniquely identifies each task. Since the coordinates are
903+
binary, the difference between coordinates in the proof can be taken to always
904+
be one, which must be invertible, regardless of whether N is prime or not. This
905+
allows us to compute the dot product using native machine arithmetic, modulo
906+
2^64 instead of arithmetic in a prime modulus. It also means that when updating
907+
the dot product incrementally, as described in SplitMix, we don't need to
908+
multiply weights by anything, since the weight is always zero in the parent (no
909+
change) and one in the child, which simply entails adding the weight.
910+
911+
We use the internal LCG maintained in rngState[4] to generate random weights:
912+
each time a child is forked, we update the LCG in both parent and child tasks.
913+
In the parent, that's all we do; the main RNG state is unchanged, but the next
914+
time the parent forks a child, the Dot/SplitMix weight used will be different,
915+
corresponding to being a level deeper in the binary task tree. In the child, we
916+
use the LCG state to generate four pseudorandom 64-bit weights (more below) and
917+
add each weight to one of the xoshiro256 state registers, rngState[0..3]. If we
918+
assume the main RNG remains unused in all tasks, each register rngState[0..3]
919+
accumulates a different Dot/SplitMix dot product hash as additional child tasks
920+
are spawned. Each one is collision resistant with a pairwise collision chance of
921+
only 1/2^64. Assuming that the four pseudorandom 64-bit weight streams are
922+
sufficiently independent, the pairwise collision probability for distinct tasks
923+
is 1/2^256. If we somehow managed to spawn a quadrillion tasks, the probability
924+
of a collision would be on the order of 1/10^48. Practically impossible.
925+
926+
What about the random "junk" that's in the xoshiro256 state registers? For a
927+
tree of tasks spawned with no intervening samples taken from the main RNG, they
928+
all start with the same junk which doesn't affect the chance of collision; the
929+
Dot/SplitMix papers suggest adding a random base value to the dot product
930+
anyway, so we can consider whatever happens to be in the xoshiro256 registers to
931+
be that. What if the main RNG is used between task forks? In that case, the
932+
state registers bits are "shuffled" according to the xoshiro256 update
933+
implemented in jl_genrandom above. The unmodified DotMix collision resistance
934+
proof obviously doesn't apply then, but we can modify the setup by adding a
935+
constant difference between the two compression functions and note that we still
936+
have a 1/N chance of the weight value hitting that exact difference. This proves
937+
collision resistance even between tasks whose dot product hashes are computed
938+
with arbitrary offsets. Thus we can conclude collision resistance even in the
939+
face of different starting states of the main RNG. Does this seem too good to be
940+
true? Perhaps another way of thinking of it will help: suppose we seeded each
941+
task randomly? Then there would only be a 1/2^256 chance of collision as well.
942+
So essentially what the proof is telling us is that the dot product construction
943+
is a good way to randomly seed each task. From that perspective, adding
944+
arbitrary junk to each random seed doesn't worsen (or improve) its randomness.
945+
946+
The random weights added to rngState[0..3] in successive child tasks are
947+
generated by applying four different variations on the PCG-RXS-M-XS-64 output
948+
function to the same 64-bit LCG state. Another obvious way to generate four
949+
weights would be to iterate the LCG four times per child task split. A reason
950+
not to do that is that the LCG update is highly linear and there is a risk that
951+
if the weights are linearly related, they will not provide independent collision
952+
resistance, and we would lose the pairwise collision probability of 1/2^256. The PCG
953+
output function is designed to obfuscate linear relationships between outputs
954+
and does so quite well, as PCG-RXS-M-XS manages to pass various statistical RNG
955+
tests with only 36 bits of state, let alone the 64 bits we're using. Different
956+
output functions seem like a better way to expand a single state into four
957+
streams. It also means that the full period of the LCG is available to each
958+
rngState[0..3] register, rather than just 2^60. Since collision resistance is
959+
proportional to the number of possible weights, this is a benefit. It's an
960+
obvious concern to worry about whether the approach of using different output
961+
functions produces weights that are independent enough to provide full collision
962+
resistance. We obviously can't test that with 256 bits, but we have tested it
963+
with a reduced state analogue, using an 8-bit LCG and four variations on the
964+
PCG-RXS-M-XS-8 output function to generate four 8-bit dot products. This test
965+
does indicate sufficient independence: one register has collisions at 2^5 while
966+
four registers only start having collisions at 2^20, which is what we'd expect
967+
if they were truly independent.
968+
969+
It may also be worth noting that in the specific case where a parent task spawns
970+
a sequence of child tasks with no intervening usage of its main RNG, then the
971+
parent and child tasks are actually guaranteed to have different RNG states.
972+
This is true because each of the four PCG streams produces each possible
973+
64-bit output exactly once in the full 2^64 period of the LCG generator. Thus,
974+
each of up to 2^64 children will be perturbed by different weights. But what
975+
about the parent colliding with a child? That can only happen if each of the
976+
rngState[0..3] registers is perturbed by zero, which cannot happen. Consider
977+
this part of each output function:
978+
979+
p ^= p >> ((p >> 59) + 5);
980+
p *= m[i];
981+
p ^= p >> 43
982+
983+
It's easy to check that this maps zero to zero and, being invertible, maps no other value to zero. Thus, if the different `p`
984+
values are zero in the end, then they all had to be zero at the beginning, which
985+
is impossible since they each differ from `x` by different additive constants.
986+
Of course, this doesn't help if the task tree structure is more deeply nested or
987+
if there are intervening uses of the main RNG, in which case we're back to
988+
relying on "merely" 256 bits of collision resistance, but it's nice to know that
989+
in what is likely the most common case RNG collisions are actually impossible.
990+
991+
[1]: http://supertech.csail.mit.edu/papers/dprng.pdf
992+
993+
[2]: https://gee.cs.oswego.edu/dl/papers/oopsla14.pdf
994+
*/
995+
void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE])
870996
{
871-
/* TODO: consider a less ad-hoc construction
872-
Ideally we could just use the output of the random stream to seed the initial
873-
state of the child. Out of an overabundance of caution we multiply with
874-
effectively random coefficients, to break possible self-interactions.
875-
876-
It is not the goal to mix bits -- we work under the assumption that the
877-
source is well-seeded, and its output looks effectively random.
878-
However, xoshiro has never been studied in the mode where we seed the
879-
initial state with the output of another xoshiro instance.
880-
881-
Constants have nothing up their sleeve:
882-
0x02011ce34bce797f == hash(UInt(1))|0x01
883-
0x5a94851fb48a6e05 == hash(UInt(2))|0x01
884-
0x3688cf5d48899fa7 == hash(UInt(3))|0x01
885-
0x867b4bb4c42e5661 == hash(UInt(4))|0x01
886-
*/
887-
to[0] = 0x02011ce34bce797f * jl_genrandom(from);
888-
to[1] = 0x5a94851fb48a6e05 * jl_genrandom(from);
889-
to[2] = 0x3688cf5d48899fa7 * jl_genrandom(from);
890-
to[3] = 0x867b4bb4c42e5661 * jl_genrandom(from);
997+
// load and advance the internal LCG state
998+
uint64_t x = src[4];
999+
src[4] = dst[4] = x * 0xd1342543de82ef95 + 1;
1000+
// high spectrum multiplier from https://arxiv.org/abs/2001.05304
1001+
1002+
static const uint64_t a[4] = {
1003+
0xe5f8fa077b92a8a8, // random additive offsets...
1004+
0x7a0cd918958c124d,
1005+
0x86222f7d388588d4,
1006+
0xd30cbd35f2b64f52
1007+
};
1008+
static const uint64_t m[4] = {
1009+
0xaef17502108ef2d9, // standard PCG multiplier
1010+
0xf34026eeb86766af, // random odd multipliers...
1011+
0x38fd70ad58dd9fbb,
1012+
0x6677f9b93ab0c04d
1013+
};
1014+
1015+
// PCG-RXS-M-XS output with four variants
1016+
for (int i = 0; i < 4; i++) {
1017+
uint64_t p = x + a[i];
1018+
p ^= p >> ((p >> 59) + 5);
1019+
p *= m[i];
1020+
p ^= p >> 43;
1021+
dst[i] = src[i] + p; // SplitMix dot product
1022+
}
8911023
}
8921024

8931025
JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize)

stdlib/Random/src/Xoshiro.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,17 @@ struct TaskLocalRNG <: AbstractRNG end
113113
TaskLocalRNG(::Nothing) = TaskLocalRNG()
114114
rng_native_52(::TaskLocalRNG) = UInt64
115115

116-
function setstate!(x::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
116+
function setstate!(
117+
x::TaskLocalRNG,
118+
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
119+
s4::UInt64 = hash((s0, s1, s2, s3)), # splitmix weight rng state
120+
)
117121
t = current_task()
118122
t.rngState0 = s0
119123
t.rngState1 = s1
120124
t.rngState2 = s2
121125
t.rngState3 = s3
126+
t.rngState4 = s4
122127
x
123128
end
124129

@@ -128,11 +133,11 @@ end
128133
tmp = s0 + s3
129134
res = ((tmp << 23) | (tmp >> 41)) + s0
130135
t = s1 << 17
131-
s2 = xor(s2, s0)
132-
s3 = xor(s3, s1)
133-
s1 = xor(s1, s2)
134-
s0 = xor(s0, s3)
135-
s2 = xor(s2, t)
136+
s2 ⊻= s0
137+
s3 ⊻= s1
138+
s1 ⊻= s2
139+
s0 ⊻= s3
140+
s2 ⊻= t
136141
s3 = s3 << 45 | s3 >> 19
137142
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
138143
res
@@ -159,7 +164,7 @@ seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(s
159164
@inline function rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt128})
160165
first = rand(rng, UInt64)
161166
second = rand(rng,UInt64)
162-
second + UInt128(first)<<64
167+
second + UInt128(first) << 64
163168
end
164169

165170
@inline rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128
@@ -178,14 +183,14 @@ end
178183

179184
function copy!(dst::TaskLocalRNG, src::Xoshiro)
180185
t = current_task()
181-
t.rngState0, t.rngState1, t.rngState2, t.rngState3 = src.s0, src.s1, src.s2, src.s3
182-
dst
186+
setstate!(dst, src.s0, src.s1, src.s2, src.s3)
187+
return dst
183188
end
184189

185190
function copy!(dst::Xoshiro, src::TaskLocalRNG)
186191
t = current_task()
187-
dst.s0, dst.s1, dst.s2, dst.s3 = t.rngState0, t.rngState1, t.rngState2, t.rngState3
188-
dst
192+
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3)
193+
return dst
189194
end
190195

191196
function ==(a::Xoshiro, b::TaskLocalRNG)

stdlib/Random/test/runtests.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,3 +1018,50 @@ guardseed() do
10181018
@test f42752(true) === val
10191019
end
10201020
end
1021+
1022+
@testset "TaskLocalRNG: stream collision smoke test" begin
1023+
# spawn a trinary tree of tasks:
1024+
# - spawn three recursive child tasks in each
1025+
# - generate a random UInt64 in each before, after and between
1026+
# - collect and count all the generated random values
1027+
# these should all be distinct across all tasks
1028+
function gen(d)
1029+
r = rand(UInt64)
1030+
vals = [r]
1031+
if d ≥ 0
1032+
append!(vals, gent(d - 1))
1033+
isodd(r) && append!(vals, gent(d - 1))
1034+
push!(vals, rand(UInt64))
1035+
iseven(r) && append!(vals, gent(d - 1))
1036+
end
1037+
push!(vals, rand(UInt64))
1038+
end
1039+
gent(d) = fetch(@async gen(d))
1040+
seeds = rand(RandomDevice(), UInt64, 5)
1041+
for seed in seeds
1042+
Random.seed!(seed)
1043+
vals = gen(6)
1044+
@test allunique(vals)
1045+
end
1046+
end
1047+
1048+
@testset "TaskLocalRNG: child doesn't affect parent" begin
1049+
seeds = rand(RandomDevice(), UInt64, 5)
1050+
for seed in seeds
1051+
Random.seed!(seed)
1052+
x = rand(UInt64)
1053+
y = rand(UInt64)
1054+
n = 3
1055+
for i = 1:n
1056+
Random.seed!(seed)
1057+
@sync for j = 0:i
1058+
@async rand(UInt64)
1059+
end
1060+
@test x == rand(UInt64)
1061+
@sync for j = 0:(n-i)
1062+
@async rand(UInt64)
1063+
end
1064+
@test y == rand(UInt64)
1065+
end
1066+
end
1067+
end

0 commit comments

Comments
 (0)