Skip to content

Commit d09abe5

Browse files
authored
[Random] Add more comments and a helper function in Xoshiro code (#56144)
Follow up to #55994 and #55997. This should basically be a non-functional change and I see no performance difference, but the comments and the definition of a helper function should make the code easier to follow (I initially struggled in #55997) and extend to other types.
1 parent 9f92989 commit d09abe5

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

stdlib/Random/src/Xoshiro.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,16 @@ rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52Raw{UInt64}}) = ran
296296
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt64) >>> 12
297297
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())
298298

299-
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float16}}) =
300-
Float16(rand(r, UInt16) >>> 5) * Float16(0x1.0p-11)
301-
302-
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float32}}) =
303-
Float32(rand(r, UInt32) >>> 8) * Float32(0x1.0p-24)
304-
305-
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01_64}) =
306-
Float64(rand(r, UInt64) >>> 11) * 0x1.0p-53
299+
for FT in (Float16, Float32, Float64)
300+
UT = Base.uinttype(FT)
301+
# Helper function: scale an unsigned integer to a floating point number of the same size
302+
# in the interval [0, 1). This is equivalent to, but more easily extensible than
303+
# Float16(i >>> 5) * Float16(0x1.0p-11)
304+
# Float32(i >>> 8) * Float32(0x1.0p-24)
305+
# Float32(i >>> 11) * Float64(0x1.0p-53)
306+
@eval @inline _uint2float(i::$(UT), ::Type{$(FT)}) =
307+
$(FT)(i >>> $(8 * sizeof(FT) - precision(FT))) * $(FT(2) ^ -precision(FT))
308+
309+
@eval rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{$(FT)}}) =
310+
_uint2float(rand(r, $(UT)), $(FT))
311+
end

stdlib/Random/src/XoshiroSimd.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module XoshiroSimd
44
# Getting the xoroshiro RNG to reliably vectorize is somewhat of a hassle without Simd.jl.
55
import ..Random: rand!
6-
using ..Random: TaskLocalRNG, rand, Xoshiro, CloseOpen01, UnsafeView, SamplerType, SamplerTrivial, getstate, setstate!
6+
using ..Random: TaskLocalRNG, rand, Xoshiro, CloseOpen01, UnsafeView, SamplerType, SamplerTrivial, getstate, setstate!, _uint2float
77
using Base: BitInteger_types
88
using Base.Libc: memcpy
99
using Core.Intrinsics: llvmcall
@@ -30,7 +30,12 @@ simdThreshold(::Type{Bool}) = 640
3030
Tuple{UInt64, Int64},
3131
x, y)
3232

33-
@inline _bits2float(x::UInt64, ::Type{Float64}) = reinterpret(UInt64, Float64(x >>> 11) * 0x1.0p-53)
33+
# `_bits2float(x::UInt64, T)` takes `x::UInt64` as input, it splits it in `N` parts where
34+
# `N = sizeof(UInt64) / sizeof(T)` (`N = 1` for `Float64`, `N = 2` for `Float32, etc...), it
35+
# truncates each part to the unsigned type of the same size as `T`, scales all of these
36+
# numbers to a value of type `T` in the range [0,1) with `_uint2float`, and then
37+
# recomposes another `UInt64` using all these parts.
38+
@inline _bits2float(x::UInt64, ::Type{Float64}) = reinterpret(UInt64, _uint2float(x, Float64))
3439
@inline function _bits2float(x::UInt64, ::Type{Float32})
3540
#=
3641
# this implementation uses more high bits, but is harder to vectorize
@@ -40,19 +45,19 @@ simdThreshold(::Type{Bool}) = 640
4045
=#
4146
ui = (x>>>32) % UInt32
4247
li = x % UInt32
43-
u = Float32(ui >>> 8) * Float32(0x1.0p-24)
44-
l = Float32(li >>> 8) * Float32(0x1.0p-24)
48+
u = _uint2float(ui, Float32)
49+
l = _uint2float(ui, Float32)
4550
(UInt64(reinterpret(UInt32, u)) << 32) | UInt64(reinterpret(UInt32, l))
4651
end
4752
@inline function _bits2float(x::UInt64, ::Type{Float16})
4853
i1 = (x>>>48) % UInt16
4954
i2 = (x>>>32) % UInt16
5055
i3 = (x>>>16) % UInt16
5156
i4 = x % UInt16
52-
f1 = Float16(i1 >>> 5) * Float16(0x1.0p-11)
53-
f2 = Float16(i2 >>> 5) * Float16(0x1.0p-11)
54-
f3 = Float16(i3 >>> 5) * Float16(0x1.0p-11)
55-
f4 = Float16(i4 >>> 5) * Float16(0x1.0p-11)
57+
f1 = _uint2float(i1, Float16)
58+
f2 = _uint2float(i2, Float16)
59+
f3 = _uint2float(i3, Float16)
60+
f4 = _uint2float(i4, Float16)
5661
return (UInt64(reinterpret(UInt16, f1)) << 48) | (UInt64(reinterpret(UInt16, f2)) << 32) | (UInt64(reinterpret(UInt16, f3)) << 16) | UInt64(reinterpret(UInt16, f4))
5762
end
5863

0 commit comments

Comments
 (0)