Skip to content

Commit

Permalink
Random: introduce gentype, instead of punning on eltype (#27756)
Browse files Browse the repository at this point in the history
In some cases it makes sense to define what type of value `rand(rng, x)`
will produce, via the newly introduced `gentype(x)`, without having
`eltype(x)` be meaningful.
  • Loading branch information
rfourquet authored Jul 10, 2018
1 parent ddc4908 commit 8a22d90
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
35 changes: 27 additions & 8 deletions stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,25 @@ export srand,

abstract type AbstractRNG end

"""
Random.gentype(T)
Determine the type of the elements generated by calling `rand([rng], x)`,
where `x::T`, and `x` is not a type.
The definition `gentype(x) = gentype(typeof(x))` is provided for convenience,
and `gentype(T)` defaults to `eltype(T)`.
NOTE: `rand([rng], X)`, where `X` is a type, is always assumed to produce
an object of type `X`.
# Examples
```jldoctest
julia> gentype(1:10)
Int64
```
"""
gentype(::Type{X}) where {X} = eltype(X)
gentype(x) = gentype(typeof(x))


### integers

Expand Down Expand Up @@ -72,7 +91,7 @@ for UI = (:UInt10, :UInt10Raw, :UInt23, :UInt23Raw, :UInt52, :UInt52Raw,
end
end

Base.eltype(::Type{<:UniformBits{T}}) where {T} = T
gentype(::Type{<:UniformBits{T}}) where {T} = T

### floats

Expand All @@ -88,15 +107,15 @@ const CloseOpen12_64 = CloseOpen12{Float64}
CloseOpen01(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen01{T}()
CloseOpen12(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen12{T}()

Base.eltype(::Type{<:FloatInterval{T}}) where {T<:AbstractFloat} = T
gentype(::Type{<:FloatInterval{T}}) where {T<:AbstractFloat} = T

const BitFloatType = Union{Type{Float16},Type{Float32},Type{Float64}}

### Sampler

abstract type Sampler{E} end

Base.eltype(::Type{<:Sampler{E}}) where {E} = E
gentype(::Type{<:Sampler{E}}) where {E} = E

# temporarily for BaseBenchmarks
RangeGenerator(x) = Sampler(GLOBAL_RNG, x)
Expand Down Expand Up @@ -133,7 +152,7 @@ struct SamplerTrivial{T,E} <: Sampler{E}
self::T
end

SamplerTrivial(x::T) where {T} = SamplerTrivial{T,eltype(T)}(x)
SamplerTrivial(x::T) where {T} = SamplerTrivial{T,gentype(T)}(x)

Sampler(::AbstractRNG, x, ::Repetition) = SamplerTrivial(x)

Expand All @@ -145,14 +164,14 @@ struct SamplerSimple{T,S,E} <: Sampler{E}
data::S
end

SamplerSimple(x::T, data::S) where {T,S} = SamplerSimple{T,S,eltype(T)}(x, data)
SamplerSimple(x::T, data::S) where {T,S} = SamplerSimple{T,S,gentype(T)}(x, data)

Base.getindex(sp::SamplerSimple) = sp.self

# simple sampler carrying a (type) tag T and data
struct SamplerTag{T,S,E} <: Sampler{E}
data::S
SamplerTag{T}(s::S) where {T,S} = new{T,S,eltype(T)}(s)
SamplerTag{T}(s::S) where {T,S} = new{T,S,gentype(T)}(s)
end


Expand Down Expand Up @@ -223,7 +242,7 @@ end
rand(r::AbstractRNG, dims::Integer...) = rand(r, Float64, Dims(dims))
rand( dims::Integer...) = rand(Float64, Dims(dims))

rand(r::AbstractRNG, X, dims::Dims) = rand!(r, Array{eltype(X)}(undef, dims), X)
rand(r::AbstractRNG, X, dims::Dims) = rand!(r, Array{gentype(X)}(undef, dims), X)
rand( X, dims::Dims) = rand(GLOBAL_RNG, X, dims)

rand(r::AbstractRNG, X, d::Integer, dims::Integer...) = rand(r, X, Dims((d, dims...)))
Expand All @@ -232,7 +251,7 @@ rand( X, d::Integer, dims::Integer...) = rand(X, Dims((d, dims...
# rand(r, ()) would match both this method and rand(r, dims::Dims)
# moreover, a call like rand(r, NotImplementedType()) would be an infinite loop

rand(r::AbstractRNG, ::Type{X}, dims::Dims) where {X} = rand!(r, Array{eltype(X)}(undef, dims), X)
rand(r::AbstractRNG, ::Type{X}, dims::Dims) where {X} = rand!(r, Array{X}(undef, dims), X)
rand( ::Type{X}, dims::Dims) where {X} = rand(GLOBAL_RNG, X, dims)

rand(r::AbstractRNG, ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(r, X, Dims((d, dims...)))
Expand Down
8 changes: 4 additions & 4 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,8 @@ end
end
end

@testset "eltype for UniformBits" begin
@test eltype(Random.UInt52()) == UInt64
@test eltype(Random.UInt52(UInt128)) == UInt128
@test eltype(Random.UInt104()) == UInt128
@testset "gentype for UniformBits" begin
@test Random.gentype(Random.UInt52()) == UInt64
@test Random.gentype(Random.UInt52(UInt128)) == UInt128
@test Random.gentype(Random.UInt104()) == UInt128
end

0 comments on commit 8a22d90

Please sign in to comment.