Skip to content
This repository was archived by the owner on May 23, 2022. It is now read-only.

add getobs and nobs common implementations #51

Merged
merged 6 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/LearnBase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module LearnBase

import StatsBase
using StatsBase: nobs

# AGGREGATION MODES
Expand Down
58 changes: 54 additions & 4 deletions src/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ Specify the default observation dimension for `data`.
Falls back to `nothing` when an observation dimension is undefined.

By default, the following implementations are provided:
- `default_obsdim(A::AbstractArray) = ndims(A)`
- `default_obsdim(tup::Tuple) = map(default_obsdim, tup)`
```julia
default_obsdim(x) = nothing
default_obsdim(x::AbstractArray) = ndims(x)
```
"""
default_obsdim(data) = nothing
# The default observation dimension of an array is its last dimension `N`.
# Only one array method is defined: a second, equivalent `AbstractArray` definition
# would merely overwrite this one. Tuples intentionally use the generic fallback
# above, because a tuple-valued result would not satisfy the
# `obsdim::Union{Int,Nothing}` keyword of the container methods of `getobs`/`nobs`.
default_obsdim(A::AbstractArray{T,N}) where {T,N} = N

"""
getobs(data, idx; obsdim = default_obsdim(data))
Expand Down Expand Up @@ -74,6 +75,24 @@ getobs(dataset, 1:2) # -> (X[:,1:2], Y[1:2])
"""
function getobs end

function getobs(data::AbstractArray{T,N}, idx; obsdim::Union{Int,Nothing}=nothing) where {T, N}
    # Use the array's default observation dimension unless one was given explicitly.
    dim = obsdim === nothing ? default_obsdim(data) : obsdim
    # Select `idx` along `dim` and everything (`:`) along every other dimension.
    selectors = ntuple(N) do d
        d == dim ? idx : Colon()
    end
    return data[selectors...]
end

function getobs(data::Union{Tuple, NamedTuple}, i; obsdim::Union{Int,Nothing}=default_obsdim(data))
    # The `obsdim` keyword is only forwarded when it is set, so element types whose
    # `getobs` method does not handle the keyword keep working.
    if obsdim === nothing
        return map(Base.Fix2(getobs, i), data)
    end
    return map(x -> getobs(x, i; obsdim=obsdim), data)
end

function getobs(data::D, i; obsdim::Union{Int,Nothing}=default_obsdim(data)) where {D<:AbstractDict}
    # Forward `obsdim` only when it is set; otherwise call the two-argument `getobs`.
    pick = obsdim === nothing ? Base.Fix2(getobs, i) : (val -> getobs(val, i; obsdim=obsdim))
    # A plain `Dict` is returned instead of a `D` because the value type can change.
    return Dict(key => pick(val) for (key, val) in pairs(data))
end

"""
getobs!(buffer, data, idx; obsdim = default_obsdim(data))

Expand Down Expand Up @@ -123,6 +142,37 @@ to dispatch on which dimension of `data` denotes the observations.
"""
function datasubset end


# We don't own nobs but pirate it for basic types
"""
nobs(data; [obsdim])

Return the number of observations in the dataset `data`.

If it makes sense for the type of `data`, `obsdim` can be used
to indicate which dimension of `data` denotes the observations.
See [`default_obsdim`](@ref) for defining a default dimension.
"""
function StatsBase.nobs(data::AbstractArray; obsdim::Union{Int,Nothing}=nothing)
    # Use the array's default observation dimension unless one was given explicitly.
    dim = obsdim === nothing ? default_obsdim(data) : obsdim
    # The number of observations is the extent of the array along that dimension.
    return size(data, dim)
end

function StatsBase.nobs(data::Union{Tuple, NamedTuple, AbstractDict}; obsdim::Union{Int,Nothing} = default_obsdim(data))
    length(data) == 0 && throw(ArgumentError("Need at least one data input"))

    # The `obsdim` keyword is only forwarded when it is set, so element types whose
    # `nobs` method does not handle the keyword keep working.
    count_obs = obsdim === nothing ? nobs : (x -> nobs(x; obsdim=obsdim))

    # Every contained element must report the same number of observations
    # as the first one.
    elements = values(data)
    expected = count_obs(first(elements))
    for element in elements
        found = count_obs(element)
        expected == found || throw(DimensionMismatch("All data inputs should have the same number of observations, i.e. size in the last dimension. "))
    end
    return expected
end

# todeprecate
function target end
function gettarget end
102 changes: 71 additions & 31 deletions test/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,78 @@ using LearnBase: getobs, nobs, default_obsdim
@test typeof(LearnBase.gettargets) <: Function
@test typeof(LearnBase.datasubset) <: Function

@testset "getobs" begin

function LearnBase.getobs(x::AbstractArray{T,N}, idx; obsdim=default_obsdim(x)) where {T,N}
_idx = ntuple(i-> i == obsdim ? idx : Colon(), N)
return x[_idx...]
@testset "getobs and nobs" begin

@testset "array" begin
    A = rand(2, 3)
    @test nobs(A) == 3
    # NOTE(review): in `@test @inferred f(x) == y` the `@inferred` macro receives the
    # whole comparison, so these lines check inference of the comparison call rather
    # than of `getobs` alone. The parentheses only make that grouping explicit.
    @test @inferred(getobs(A, 1) == A[:, 1])
    @test @inferred(getobs(A, 2) == A[:, 2])
    @test @inferred(getobs(A, 1:2) == A[:, 1:2])
    @test @inferred(getobs(A, 1, obsdim=1) == A[1, :])
    @test @inferred(getobs(A, 2, obsdim=1) == A[2, :])
    @test @inferred(getobs(A, 2, obsdim=nothing) ≈ A[:, 2])
end

@testset "tuple" begin
    # A dataset holding 3 observations with 2 input features each.
    X, Y = rand(2, 3), rand(3)
    dataset = (X, Y)
    @test nobs(dataset) == 3

    # Type inference of `getobs` is only asserted on Julia 1.6 and later.
    check_inference = VERSION >= v"1.6"

    single = check_inference ? (@inferred getobs(dataset, 2)) : getobs(dataset, 2)
    @test single[1] == X[:, 2]
    @test single[2] == Y[2]

    batch = check_inference ? (@inferred getobs(dataset, 1:2)) : getobs(dataset, 1:2)
    @test batch[1] == X[:, 1:2]
    @test batch[2] == Y[1:2]
end


@testset "named tuple" begin
    X, Y = rand(2, 3), rand(3)
    dataset = (x=X, y=Y)
    @test nobs(dataset) == 3

    # Type inference of `getobs` is only asserted on Julia 1.6 and later.
    check_inference = VERSION >= v"1.6"

    single = check_inference ? (@inferred getobs(dataset, 2)) : getobs(dataset, 2)
    @test single.x == X[:, 2]
    @test single.y == Y[2]

    batch = check_inference ? (@inferred getobs(dataset, 1:2)) : getobs(dataset, 1:2)
    @test batch.x == X[:, 1:2]
    @test batch.y == Y[1:2]
end

@testset "dict" begin
    X, Y = rand(2, 3), rand(3)
    dataset = Dict("X" => X, "Y" => Y)
    @test nobs(dataset) == 3

    # No `@inferred` here: `getobs` on a `Dict` is not inferred.
    single = getobs(dataset, 2)
    @test single["X"] == X[:, 2]
    @test single["Y"] == Y[2]

    batch = getobs(dataset, 1:2)
    @test batch["X"] == X[:, 1:2]
    @test batch["Y"] == Y[1:2]
end
LearnBase.nobs(x::AbstractArray; obsdim=default_obsdim(x)) = size(x, obsdim)

a = rand(2,3)
@test nobs(a) == 3
@test getobs(a, 1) ≈ a[:,1]
@test getobs(a, 2) ≈ a[:,2]
@test getobs(a, 1, obsdim=1) ≈ a[1,:]
@test getobs(a, 2, obsdim=1) ≈ a[2,:]

# Here we use Ref to protect idx against broadcasting
LearnBase.getobs(t::Tuple, idx) = getobs.(t, Ref(idx))
# Assume all elements have the same number of observations.
# It would be safer to check explicitly though.
LearnBase.nobs(t::Tuple) = nobs(t[1])

# A dataset with 3 observations, each with 2 input features
X, Y = rand(2, 3), rand(3)
dataset = (X, Y)

o = getobs(dataset, 2) # -> (X[:,2], Y[2])
@test o[1] ≈ X[:,2]
@test o[2] == Y[2]

o = getobs(dataset, 1:2) # -> (X[:,1:2], Y[1:2])
@test o[1] ≈ X[:,1:2]
@test o[2] == Y[1:2]
end