Skip to content
This repository was archived by the owner on May 21, 2022. It is now read-only.

Updated to reflect new LearnBase interface for getobs and ObsDim #50

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
fixed getobs and nobs for datasubset, deprecated datasubset constructor without indices, fixed some tests
  • Loading branch information
racinmat committed Nov 6, 2021
commit 88c59ebce73a46b9a19b8eb12e5e49c29b65ca1d
68 changes: 60 additions & 8 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# This file is machine-generated - editing it directly is not advised

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "3533f5a691e60601fe60c90d8bc47a27aa2907ec"
git-tree-sha1 = "f885e7e7c124f8c92650d61b9477b9ac2ee607dd"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.11.0"
version = "1.11.1"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand Down Expand Up @@ -44,6 +50,10 @@ git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.6"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -61,16 +71,28 @@ version = "0.1.1"

[[LearnBase]]
deps = ["StatsBase"]
git-tree-sha1 = "f1b8214972833125cac5c7d52830932600f0ffa9"
git-tree-sha1 = "cf0a6441a65eade5d1786764bc74c77528836e98"
repo-rev = "darsnack/rm-obsdim"
repo-url = "https://github.com/darsnack/LearnBase.jl.git"
uuid = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
version = "0.5.3"

[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[LibGit2]]
deps = ["Printf"]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -89,7 +111,7 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MLLabelUtils]]
deps = ["LearnBase", "MappedArrays", "StatsBase"]
git-tree-sha1 = "222ece6e6cc650f2f5f36f37974a08043a1415ba"
git-tree-sha1 = "b07772b9d422ff2bc866dda6493b20c91e2c4114"
repo-rev = "darsnack/refactor"
repo-url = "https://github.com/darsnack/MLLabelUtils.jl.git"
uuid = "66a33bbf-0c2b-5fc8-a008-9da813334f0a"
Expand All @@ -104,6 +126,10 @@ version = "0.4.1"
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f"
Expand All @@ -113,21 +139,27 @@ version = "1.0.2"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.1"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
Expand Down Expand Up @@ -172,8 +204,16 @@ git-tree-sha1 = "eb35dcc66558b2dda84079b9a1be17557d32091a"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.12"

[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[UUIDs]]
Expand All @@ -182,3 +222,15 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
1 change: 1 addition & 0 deletions src/MLDataPattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,6 @@ include("stratifiedobs.jl")
include("resample.jl")
include("folds.jl")
include("dataiterator.jl")
include("deprecations.jl")

end # module
2 changes: 1 addition & 1 deletion src/dataiterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ Base.eltype(::Type{BufferGetObs{E,T}}) where {E,T} = E
Base.IteratorSize(::Type{BufferGetObs{E,T}}) where {E,T} = Base.IteratorSize(T)
Base.length(b::BufferGetObs) = length(b.iter)
Base.size(b::BufferGetObs, I...) = size(b.iter, I...)
StatsBase.nobs(b::BufferGetObs) = nobs(b.iter)
StatsBase.nobs(b::BufferGetObs; obsdim = default_obsdim(b)) = nobs(b.iter; obsdim = obsdim)
batchsize(b::BufferGetObs) = batchsize(b.iter)

function Base.summary(b::BufferGetObs)
Expand Down
13 changes: 7 additions & 6 deletions src/datasubset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,19 @@ Base.length(subset::DataSubset) = length(subset.indices)

Base.lastindex(subset::DataSubset) = length(subset)

# todo: check this, because it does not seem to make much sense
Base.getindex(subset::DataSubset, idx) =
DataSubset(subset.data, _view(subset.indices, idx), default_obsdim(subset))
DataSubset(subset.data, _view(subset.indices, idx))

LearnBase.default_obsdim(subset::DataSubset) = default_obsdim(subset.data)

LearnBase.nobs(subset::DataSubset) = length(subset)
LearnBase.nobs(subset::DataSubset; obsdim = default_obsdim(subset)) = length(subset)

LearnBase.getobs(subset::DataSubset, idx, obsdim = default_obsdim(subset)) =
getobs(subset.data, _view(subset.indices, idx), obsdim)
LearnBase.getobs(subset::DataSubset, idx; obsdim = default_obsdim(subset)) =
getobs(subset.data, _view(subset.indices, idx); obsdim = obsdim)

LearnBase.getobs!(buffer, subset::DataSubset, idx, obsdim = default_obsdim(subset)) =
getobs!(buffer, subset.data, _view(subset.indices, idx), obsdim)
LearnBase.getobs!(buffer, subset::DataSubset, idx; obsdim = default_obsdim(subset)) =
getobs!(buffer, subset.data, _view(subset.indices, idx); obsdim = obsdim)

# --------------------------------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@deprecate DataSubset(data::T; obsdim = default_obsdim(data)) where {T} DataSubset(data::T, 1:nobs(data; obsdim = obsdim))
2 changes: 1 addition & 1 deletion test/references/DataSubset1.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
DataSubset(::Array{Float64,2}, ::UnitRange{Int64}, ObsDim.Last())
DataSubset(::Array{Float64,2}, ::UnitRange{Int64})
150 observations
16 changes: 8 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,23 @@ Y1 = collect(1:150)
struct EmptyType end

struct CustomType end
StatsBase.nobs(::CustomType) = 100
LearnBase.getobs(::CustomType, i::Int) = i
LearnBase.getobs(::CustomType, i::AbstractVector) = collect(i)
StatsBase.nobs(x::CustomType; obsdim = LearnBase.default_obsdim(x)) = 100
LearnBase.getobs(x::CustomType, i::Int; obsdim = LearnBase.default_obsdim(x)) = i
LearnBase.getobs(x::CustomType, i::AbstractVector; obsdim = LearnBase.default_obsdim(x)) = collect(i)
LearnBase.gettargets(::CustomType, i::Int) = "obs $i"
LearnBase.gettargets(::CustomType, i::AbstractVector) = "batch $i"

struct CustomStorage end
struct CustomObs{T}; data::T end
StatsBase.nobs(::CustomStorage) = 2
LearnBase.getobs(::CustomStorage, i) = CustomObs(i)
StatsBase.nobs(x::CustomStorage; obsdim = LearnBase.default_obsdim(x)) = 2
LearnBase.getobs(x::CustomStorage, i; obsdim = LearnBase.default_obsdim(x)) = CustomObs(i)
LearnBase.gettarget(str::String, obs::CustomObs) = "$str - obs $(obs.data)"
LearnBase.gettarget(obs::CustomObs) = "obs $(obs.data)"

struct ObsDimTriggeredException <: Exception end
struct MetaDataStorage end
StatsBase.nobs(::MetaDataStorage) = 3
LearnBase.getobs(::MetaDataStorage, i) = throw(ObsDimTriggeredException())
StatsBase.nobs(x::MetaDataStorage; obsdim = LearnBase.default_obsdim(x)) = 3
LearnBase.getobs(x::MetaDataStorage, i; obsdim = LearnBase.default_obsdim(x)) = throw(ObsDimTriggeredException())
LearnBase.gettargets(::MetaDataStorage) = "full"
LearnBase.gettargets(::MetaDataStorage, i::Int) = "obs $i"
LearnBase.gettargets(::MetaDataStorage, i::AbstractVector) = "batch $i"
Expand Down Expand Up @@ -84,7 +84,7 @@ end

tests = [
"tst_container.jl"
# "tst_datasubset.jl"
"tst_datasubset.jl"
"tst_randobs.jl"
# "tst_shuffleobs.jl"
# "tst_splitobs.jl"
Expand Down
Loading