Skip to content

Commit

Permalink
allow rand! with explicit SIMD to be used for various dense arrays (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
nsajko authored Jan 21, 2025
1 parent cb55389 commit b70761f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 4 additions & 3 deletions stdlib/Random/src/XoshiroSimd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,20 +292,21 @@ end
return i
end

const MutableDenseArray = Union{Base.MutableDenseArrayType{T}, UnsafeView{T}} where {T}

function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{T}, ::SamplerTrivial{CloseOpen01{T}}) where {T<:Union{Float16,Float32,Float64}}
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::MutableDenseArray{T}, ::SamplerTrivial{CloseOpen01{T}}) where {T<:Union{Float16,Float32,Float64}}
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst)*sizeof(T), T, xoshiroWidth(), _bits2float)
dst
end

for T in BitInteger_types
@eval function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Union{Array{$T}, UnsafeView{$T}}, ::SamplerType{$T})
@eval function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::MutableDenseArray{$T}, ::SamplerType{$T})
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst)*sizeof($T), UInt8, xoshiroWidth())
dst
end
end

function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{Bool}, ::SamplerType{Bool})
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::MutableDenseArray{Bool}, ::SamplerType{Bool})
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst), Bool, xoshiroWidth())
dst
end
Expand Down
4 changes: 3 additions & 1 deletion stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,10 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
a8 = rand!(rng..., GenericArray{T}(undef, 2, 3), cc) ::GenericArray{T, 2}
a9 = rand!(rng..., OffsetArray(Array{T}(undef, 5), 9), cc) ::OffsetArray{T, 1}
a10 = rand!(rng..., OffsetArray(Array{T}(undef, 2, 3), (-2, 4)), cc) ::OffsetArray{T, 2}
a11 = rand!(rng..., Memory{T}(undef, 5), cc) ::Memory{T}
@test size(a1) == (5,)
@test size(a2) == size(a3) == (2, 3)
for a in [a0, a1..., a2..., a3..., a4..., a5..., a6..., a7..., a8..., a9..., a10...]
for a in [a0, a1..., a2..., a3..., a4..., a5..., a6..., a7..., a8..., a9..., a10..., a11...]
if C isa Type
@test a isa C
else
Expand All @@ -392,6 +393,7 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
(T <: Tuple || T <: Pair) && continue
X = T == Bool ? T[0,1] : T[0,1,2]
for A in (Vector{T}(undef, 5),
Memory{T}(undef, 5),
Matrix{T}(undef, 2, 3),
GenericArray{T}(undef, 5),
GenericArray{T}(undef, 2, 3),
Expand Down

0 comments on commit b70761f

Please sign in to comment.