Skip to content

Commit ab992b9

Browse files
authored
MersenneTwister: hash seeds like for Xoshiro (#51436)
This addresses a part of #37165: > It's common that sequential seeds for RNGs are not as independent as one might like. This clears out this problem for `MersenneTwister`, and makes it easy to add the same feature to other RNGs via a new `hash_seed` function, which replaces `make_seed`. This is an alternative to #37766.
1 parent 3a85776 commit ab992b9

File tree

9 files changed

+170
-182
lines changed

9 files changed

+170
-182
lines changed

stdlib/Random/docs/src/index.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ Random.SamplerSimple
126126
Decoupling pre-computation from actually generating the values is part of the API, and is also available to the user. As an example, assume that `rand(rng, 1:20)` has to be called repeatedly in a loop: the way to take advantage of this decoupling is as follows:
127127

128128
```julia
129-
rng = MersenneTwister()
130-
sp = Random.Sampler(rng, 1:20) # or Random.Sampler(MersenneTwister, 1:20)
129+
rng = Xoshiro()
130+
sp = Random.Sampler(rng, 1:20) # or Random.Sampler(Xoshiro, 1:20)
131131
for x in X
132132
n = rand(rng, sp) # similar to n = rand(rng, 1:20)
133133
# use n
@@ -159,8 +159,8 @@ Scalar and array methods for `Die` now work as expected:
159159
julia> rand(Die)
160160
Die(5)
161161
162-
julia> rand(MersenneTwister(0), Die)
163-
Die(11)
162+
julia> rand(Xoshiro(0), Die)
163+
Die(10)
164164
165165
julia> rand(Die, 3)
166166
3-element Vector{Die}:

stdlib/Random/src/DSFMT.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ function dsfmt_init_gen_rand(s::DSFMT_state, seed::UInt32)
6565
s.val, seed)
6666
end
6767

68-
function dsfmt_init_by_array(s::DSFMT_state, seed::Vector{UInt32})
68+
function dsfmt_init_by_array(s::DSFMT_state, seed::StridedVector{UInt32})
69+
strides(seed) == (1,) || throw(ArgumentError("seed must have its stride equal to 1"))
6970
ccall((:dsfmt_init_by_array,:libdSFMT),
7071
Cvoid,
7172
(Ptr{Cvoid}, Ptr{UInt32}, Int32),

stdlib/Random/src/RNGs.jl

Lines changed: 68 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The entropy is obtained from the operating system.
1212
"""
1313
struct RandomDevice <: AbstractRNG; end
1414
RandomDevice(seed::Nothing) = RandomDevice()
15-
seed!(rng::RandomDevice) = rng
15+
seed!(rng::RandomDevice, ::Nothing) = rng
1616

1717
rand(rd::RandomDevice, sp::SamplerBoolBitInteger) = Libc.getrandom!(Ref{sp[]}())[]
1818
rand(rd::RandomDevice, ::SamplerType{Bool}) = rand(rd, UInt8) % Bool
@@ -44,7 +44,7 @@ const MT_CACHE_I = 501 << 4 # number of bytes in the UInt128 cache
4444
@assert dsfmt_get_min_array_size() <= MT_CACHE_F
4545

4646
mutable struct MersenneTwister <: AbstractRNG
47-
seed::Vector{UInt32}
47+
seed::Any
4848
state::DSFMT_state
4949
vals::Vector{Float64}
5050
ints::Vector{UInt128}
@@ -70,7 +70,7 @@ mutable struct MersenneTwister <: AbstractRNG
7070
end
7171
end
7272

73-
MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
73+
MersenneTwister(seed, state::DSFMT_state) =
7474
MersenneTwister(seed, state,
7575
Vector{Float64}(undef, MT_CACHE_F),
7676
Vector{UInt128}(undef, MT_CACHE_I >> 4),
@@ -92,19 +92,17 @@ See the [`seed!`](@ref) function for reseeding an already existing `MersenneTwis
9292
9393
# Examples
9494
```jldoctest
95-
julia> rng = MersenneTwister(1234);
95+
julia> rng = MersenneTwister(123);
9696
9797
julia> x1 = rand(rng, 2)
9898
2-element Vector{Float64}:
99-
0.5908446386657102
100-
0.7667970365022592
99+
0.37453777969575874
100+
0.8735343642013971
101101
102-
julia> rng = MersenneTwister(1234);
103-
104-
julia> x2 = rand(rng, 2)
102+
julia> x2 = rand(MersenneTwister(123), 2)
105103
2-element Vector{Float64}:
106-
0.5908446386657102
107-
0.7667970365022592
104+
0.37453777969575874
105+
0.8735343642013971
108106
109107
julia> x1 == x2
110108
true
@@ -115,7 +113,7 @@ MersenneTwister(seed=nothing) =
115113

116114

117115
function copy!(dst::MersenneTwister, src::MersenneTwister)
118-
copyto!(resize!(dst.seed, length(src.seed)), src.seed)
116+
dst.seed = src.seed
119117
copy!(dst.state, src.state)
120118
copyto!(dst.vals, src.vals)
121119
copyto!(dst.ints, src.ints)
@@ -129,7 +127,7 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
129127
end
130128

131129
copy(src::MersenneTwister) =
132-
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
130+
MersenneTwister(src.seed, copy(src.state), copy(src.vals), copy(src.ints),
133131
src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints)
134132

135133

@@ -144,12 +142,10 @@ hash(r::MersenneTwister, h::UInt) =
144142

145143
function show(io::IO, rng::MersenneTwister)
146144
# seed
147-
seed = from_seed(rng.seed)
148-
seed_str = seed <= typemax(Int) ? string(seed) : "0x" * string(seed, base=16) # DWIM
149145
if rng.adv_jump == 0 && rng.adv == 0
150-
return print(io, MersenneTwister, "(", seed_str, ")")
146+
return print(io, MersenneTwister, "(", repr(rng.seed), ")")
151147
end
152-
print(io, MersenneTwister, "(", seed_str, ", (")
148+
print(io, MersenneTwister, "(", repr(rng.seed), ", (")
153149
# state
154150
adv = Integer[rng.adv_jump, rng.adv]
155151
if rng.adv_vals != -1 || rng.adv_ints != -1
@@ -277,76 +273,84 @@ end
277273

278274
### seeding
279275

280-
#### make_seed()
276+
#### random_seed() & hash_seed()
281277

282-
# make_seed produces values of type Vector{UInt32}, suitable for MersenneTwister seeding
283-
function make_seed()
278+
# random_seed tries to produce a random seed of type UInt128 from system entropy
279+
function random_seed()
284280
try
285-
return rand(RandomDevice(), UInt32, 4)
281+
# as MersenneTwister prints its seed when `show`ed, 128 bits is a good compromise for
282+
# almost surely always getting distinct seeds, while having them printed reasonably tersely
283+
return rand(RandomDevice(), UInt128)
286284
catch ex
287285
ex isa IOError || rethrow()
288286
@warn "Entropy pool not available to seed RNG; using ad-hoc entropy sources."
289-
return make_seed(Libc.rand())
287+
return Libc.rand()
290288
end
291289
end
292290

293-
"""
294-
make_seed(n::Integer) -> Vector{UInt32}
295-
296-
Transform `n` into a bit pattern encoded as a `Vector{UInt32}`, suitable for
297-
RNG seeding routines.
298-
299-
`make_seed` is "injective" : if `n != m`, then `make_seed(n) != `make_seed(m)`.
300-
Moreover, if `n == m`, then `make_seed(n) == make_seed(m)`.
301-
302-
This is an internal function, subject to change.
303-
"""
304-
function make_seed(n::Integer)
305-
neg = signbit(n)
291+
function hash_seed(seed::Integer)
292+
ctx = SHA.SHA2_256_CTX()
293+
neg = signbit(seed)
306294
if neg
307-
n = ~n
308-
end
309-
@assert n >= 0
310-
seed = UInt32[]
311-
# we directly encode the bit pattern of `n` into the resulting vector `seed`;
312-
# to greatly limit breaking the streams of random numbers, we encode the sign bit
313-
# as the upper bit of `seed[end]` (i.e. for most positive seeds, `make_seed` returns
314-
# the same vector as when we didn't encode the sign bit)
315-
while !iszero(n)
316-
push!(seed, n & 0xffffffff)
317-
n >>>= 32
295+
seed = ~seed
318296
end
319-
if isempty(seed) || !iszero(seed[end] & 0x80000000)
320-
push!(seed, zero(UInt32))
321-
end
322-
if neg
323-
seed[end] |= 0x80000000
297+
@assert seed >= 0
298+
while true
299+
word = (seed % UInt32) & 0xffffffff
300+
seed >>>= 32
301+
SHA.update!(ctx, reinterpret(NTuple{4, UInt8}, word))
302+
iszero(seed) && break
324303
end
325-
seed
304+
# make sure the hash of negative numbers is different from the hash of positive numbers
305+
neg && SHA.update!(ctx, (0x01,))
306+
SHA.digest!(ctx)
326307
end
327308

328-
# inverse of make_seed(::Integer)
329-
function from_seed(a::Vector{UInt32})::BigInt
330-
neg = !iszero(a[end] & 0x80000000)
331-
seed = sum((i == length(a) ? a[i] & 0x7fffffff : a[i]) * big(2)^(32*(i-1))
332-
for i in 1:length(a))
333-
neg ? ~seed : seed
309+
function hash_seed(seed::Union{AbstractArray{UInt32}, AbstractArray{UInt64}})
310+
ctx = SHA.SHA2_256_CTX()
311+
for xx in seed
312+
SHA.update!(ctx, reinterpret(NTuple{8, UInt8}, UInt64(xx)))
313+
end
314+
# discriminate from hash_seed(::Integer)
315+
SHA.update!(ctx, (0x10,))
316+
SHA.digest!(ctx)
334317
end
335318

336319

320+
"""
321+
hash_seed(seed) -> AbstractVector{UInt8}
322+
323+
Return a cryptographic hash of `seed` of size 256 bits (32 bytes).
324+
`seed` can currently be of type `Union{Integer, DenseArray{UInt32}, DenseArray{UInt64}}`,
325+
but modules can extend this function for types they own.
326+
327+
`hash_seed` is "injective" : if `n != m`, then `hash_seed(n) != `hash_seed(m)`.
328+
Moreover, if `n == m`, then `hash_seed(n) == hash_seed(m)`.
329+
330+
This is an internal function subject to change.
331+
"""
332+
hash_seed
333+
337334
#### seed!()
338335

339-
function seed!(r::MersenneTwister, seed::Vector{UInt32})
340-
copyto!(resize!(r.seed, length(seed)), seed)
341-
dsfmt_init_by_array(r.state, r.seed)
336+
function initstate!(r::MersenneTwister, data::StridedVector, seed)
337+
# we deepcopy `seed` because the caller might mutate it, and it's useful
338+
# to keep it constant inside `MersenneTwister`; but multiple instances
339+
# can share the same seed without any problem (e.g. in `copy`)
340+
r.seed = deepcopy(seed)
341+
dsfmt_init_by_array(r.state, reinterpret(UInt32, data))
342342
reset_caches!(r)
343343
r.adv = 0
344344
r.adv_jump = 0
345345
return r
346346
end
347347

348-
seed!(r::MersenneTwister) = seed!(r, make_seed())
349-
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
348+
# when a seed is not provided, we generate one via `RandomDevice()` in `random_seed()` rather
349+
# than calling directly `initstate!` with `rand(RandomDevice(), UInt32, whatever)` because the
350+
# seed is printed in `show(::MersenneTwister)`, so we need one; the cost of `hash_seed` is a
351+
# small overhead compared to `initstate!`, so this simple solution is fine
352+
seed!(r::MersenneTwister, ::Nothing) = seed!(r, random_seed())
353+
seed!(r::MersenneTwister, seed) = initstate!(r, hash_seed(seed), seed)
350354

351355

352356
### Global RNG
@@ -713,7 +717,7 @@ end
713717
function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X)
714718
adv = r.adv
715719
adv_jump = r.adv_jump
716-
s = MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly))
720+
s = MersenneTwister(r.seed, DSFMT.dsfmt_jump(r.state, jumppoly))
717721
reset_caches!(s)
718722
s.adv = adv
719723
s.adv_jump = adv_jump

stdlib/Random/src/Random.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,8 @@ julia> rand(Int, 2)
356356
357357
julia> using Random
358358
359-
julia> rand(MersenneTwister(0), Dict(1=>2, 3=>4))
360-
1=>2
359+
julia> rand(Xoshiro(0), Dict(1=>2, 3=>4))
360+
3 => 4
361361
362362
julia> rand((2, 3))
363363
3
@@ -389,15 +389,13 @@ but without allocating a new array.
389389
390390
# Examples
391391
```jldoctest
392-
julia> rng = MersenneTwister(1234);
393-
394-
julia> rand!(rng, zeros(5))
392+
julia> rand!(Xoshiro(123), zeros(5))
395393
5-element Vector{Float64}:
396-
0.5908446386657102
397-
0.7667970365022592
398-
0.5662374165061859
399-
0.4600853424625171
400-
0.7940257103317943
394+
0.521213795535383
395+
0.5868067574533484
396+
0.8908786980927811
397+
0.19090669902576285
398+
0.5256623915420473
401399
```
402400
"""
403401
rand!
@@ -452,6 +450,11 @@ julia> rand(Xoshiro(), Bool) # not reproducible either
452450
true
453451
```
454452
"""
455-
seed!(rng::AbstractRNG, ::Nothing) = seed!(rng)
453+
seed!(rng::AbstractRNG) = seed!(rng, nothing)
454+
#=
455+
We have this generic definition instead of the alternative option
456+
`seed!(rng::AbstractRNG, ::Nothing) = seed!(rng)`
457+
because it would lead too easily to ambiguities, e.g. when we define `seed!(::Xoshiro, seed)`.
458+
=#
456459

457460
end # module

stdlib/Random/src/Xoshiro.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,20 @@ rng_native_52(::TaskLocalRNG) = UInt64
230230
## Shared implementation between Xoshiro and TaskLocalRNG
231231

232232
# this variant of setstate! initializes the internal splitmix state, a.k.a. `s4`
233-
@inline initstate!(x::Union{TaskLocalRNG, Xoshiro}, (s0, s1, s2, s3)::NTuple{4, UInt64}) =
233+
@inline function initstate!(x::Union{TaskLocalRNG, Xoshiro}, state)
234+
length(state) == 4 && eltype(state) == UInt64 ||
235+
throw(ArgumentError("initstate! expects a list of 4 `UInt64` values"))
236+
s0, s1, s2, s3 = state
234237
setstate!(x, (s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3))
238+
end
235239

236240
copy(rng::Union{TaskLocalRNG, Xoshiro}) = Xoshiro(getstate(rng)...)
237241
copy!(dst::Union{TaskLocalRNG, Xoshiro}, src::Union{TaskLocalRNG, Xoshiro}) = setstate!(dst, getstate(src))
238242
==(x::Union{TaskLocalRNG, Xoshiro}, y::Union{TaskLocalRNG, Xoshiro}) = getstate(x) == getstate(y)
239243
# use a magic (random) number to scramble `h` so that `hash(x)` is distinct from `hash(getstate(x))`
240244
hash(x::Union{TaskLocalRNG, Xoshiro}, h::UInt) = hash(getstate(x), h + 0x49a62c2dda6fa9be % UInt)
241245

242-
function seed!(rng::Union{TaskLocalRNG, Xoshiro})
246+
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, ::Nothing)
243247
# as we get good randomness from RandomDevice, we can skip hashing
244248
rd = RandomDevice()
245249
s0 = rand(rd, UInt64)
@@ -249,14 +253,9 @@ function seed!(rng::Union{TaskLocalRNG, Xoshiro})
249253
initstate!(rng, (s0, s1, s2, s3))
250254
end
251255

252-
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
253-
c = SHA.SHA2_256_CTX()
254-
SHA.update!(c, reinterpret(UInt8, seed))
255-
s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c))
256-
initstate!(rng, (s0, s1, s2, s3))
257-
end
256+
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed) =
257+
initstate!(rng, reinterpret(UInt64, hash_seed(seed)))
258258

259-
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))
260259

261260
@inline function rand(x::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt64})
262261
s0, s1, s2, s3 = getstate(x)

0 commit comments

Comments
 (0)