diff --git a/NEWS.md b/NEWS.md
index a9db7cfa58..54959f6902 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -7,6 +7,8 @@ been removed in favour of MLDatasets.jl.
 * `flatten` is not exported anymore due to clash with Iterators.flatten.
 * Remove Juno.jl progress bar support as it is now obsolete.
 * `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.
+* Many utility functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
+* The `DataLoader` is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
 
 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
diff --git a/Project.toml b/Project.toml
index bc7bbec4b2..6695e22267 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -25,6 +26,7 @@ Adapt = "3.0"
 ArrayInterface = "3.1, 4"
 CUDA = "3"
 Functors = "0.2.1"
+MLUtils = "0.1.4"
 MacroTools = "0.5"
 NNlib = "0.8.2"
 NNlibCUDA = "0.2"
diff --git a/docs/src/models/nnlib.md b/docs/src/models/nnlib.md
index 7102b449d0..3ad07e1883 100644
--- a/docs/src/models/nnlib.md
+++ b/docs/src/models/nnlib.md
@@ -76,7 +76,10 @@ NNlib.gather
 NNlib.gather!
 NNlib.scatter
 NNlib.scatter!
+```
+
+## Miscellaneous
 
-## Utilities
+```@docs
 NNlib.logsumexp
 ```
diff --git a/docs/src/utilities.md b/docs/src/utilities.md
index de05b887df..5878813a39 100644
--- a/docs/src/utilities.md
+++ b/docs/src/utilities.md
@@ -7,15 +7,17 @@ callback functions.
 
 ## Working with Data
 
+Utilities for data processing are provided by [MLUtils.jl](https://github.com/JuliaML/MLUtils.jl). Below is a non-exhaustive list.
+
 ```@docs
-Flux.unsqueeze
-Flux.stack
-Flux.unstack
-Flux.chunk
-Flux.frequencies
-Flux.batch
-Flux.unbatch
-Flux.batchseq
+MLUtils.unsqueeze
+MLUtils.stack
+MLUtils.unstack
+MLUtils.chunk
+MLUtils.group_counts
+MLUtils.batch
+MLUtils.unbatch
+MLUtils.batchseq
 Base.rpad(v::AbstractVector, n::Integer, p)
 ```
 
diff --git a/src/Flux.jl b/src/Flux.jl
index 7b7c545b0f..f7696b6549 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,6 +7,9 @@ using Statistics, Random, LinearAlgebra
 using Zygote, MacroTools, ProgressLogging, Reexport
 using MacroTools: @forward
 @reexport using NNlib
+
+using MLUtils
+
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
 
 export gradient
@@ -50,6 +53,7 @@ include("outputsize.jl")
 
 include("data/Data.jl")
 using .Data
+
 include("losses/Losses.jl")
 using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12
 
diff --git a/src/data/Data.jl b/src/data/Data.jl
index cb3a073969..70ab68b534 100644
--- a/src/data/Data.jl
+++ b/src/data/Data.jl
@@ -1,9 +1,6 @@
 module Data
 
-using Random: shuffle!
-using Base: @propagate_inbounds
-
-include("dataloader.jl")
+using MLUtils
 
 export DataLoader
 end#module
diff --git a/src/data/dataloader.jl b/src/data/dataloader.jl
deleted file mode 100644
index 422747776c..0000000000
--- a/src/data/dataloader.jl
+++ /dev/null
@@ -1,121 +0,0 @@
-# Adapted from Knet's src/data.jl (author: Deniz Yuret)
-using Random: AbstractRNG, shuffle!, GLOBAL_RNG
-
-struct DataLoader{D,R<:AbstractRNG}
-    data::D
-    batchsize::Int
-    nobs::Int
-    partial::Bool
-    imax::Int
-    indices::Vector{Int}
-    shuffle::Bool
-    rng::R
-end
-
-"""
-    Flux.DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
-
-An object that iterates over mini-batches of `data`,
-each mini-batch containing `batchsize` observations
-(except possibly the last one).
-
-Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
-The last dimension in each tensor is the observation dimension, i.e. the one
-divided into mini-batches.
-
-If `shuffle=true`, it shuffles the observations each time iterations are re-started.
-If `partial=false` and the number of observations is not divisible by the batchsize,
-then the last mini-batch is dropped.
-
-The original data is preserved in the `data` field of the DataLoader.
-
-# Examples
-```jldoctest
-julia> Xtrain = rand(10, 100);
-
-julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);
-
-julia> for x in array_loader
-         @assert size(x) == (10, 2)
-         # do something with x, 50 times
-       end
-
-julia> array_loader.data === Xtrain
-true
-
-julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples
-
-julia> for x in tuple_loader
-         @assert x isa Tuple{Matrix}
-         @assert size(x[1]) == (10, 2)
-       end
-
-julia> Ytrain = rand('a':'z', 100); # now make a DataLoader yielding 2-element named tuples
-
-julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);
-
-julia> for epoch in 1:100
-         for (x, y) in train_loader # access via tuple destructuring
-           @assert size(x) == (10, 5)
-           @assert size(y) == (5,)
-           # loss += f(x, y) # etc, runs 100 * 20 times
-         end
-       end
-
-julia> first(train_loader).label isa Vector{Char} # access via property name
-true
-
-julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
-false
-
-julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=30)) # partial=false would omit last
-10×30 Matrix{Int8}
-10×30 Matrix{Int8}
-10×4 Matrix{Int8}
-```
-"""
-function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
-    batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
-
-    n = _nobs(data)
-    if n < batchsize
-        @warn "Number of observations less than batchsize, decreasing the batchsize to $n"
-        batchsize = n
-    end
-    imax = partial ? n : n - batchsize + 1
-    DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle, rng)
-end
-
-@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
-    i >= d.imax && return nothing
-    if d.shuffle && i == 0
-        shuffle!(d.rng, d.indices)
-    end
-    nexti = min(i + d.batchsize, d.nobs)
-    ids = d.indices[i+1:nexti]
-    batch = _getobs(d.data, ids)
-    return (batch, nexti)
-end
-
-function Base.length(d::DataLoader)
-    n = d.nobs / d.batchsize
-    d.partial ? ceil(Int,n) : floor(Int,n)
-end
-
-_nobs(data::AbstractArray) = size(data)[end]
-
-function _nobs(data::Union{Tuple, NamedTuple})
-    length(data) > 0 || throw(ArgumentError("Need at least one data input"))
-    n = _nobs(data[1])
-    for i in keys(data)
-        ni = _nobs(data[i])
-        n == ni || throw(DimensionMismatch("All data inputs should have the same number of observations, i.e. size in the last dimension. " *
-            "But data[$(repr(first(keys(data))))] ($(summary(data[1]))) has $n, while data[$(repr(i))] ($(summary(data[i]))) has $ni."))
-    end
-    return n
-end
-
-_getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
-_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
-
-Base.eltype(::DataLoader{D}) where D = D
diff --git a/src/data/tree.jl b/src/data/tree.jl
deleted file mode 100644
index 3a1c8fca80..0000000000
--- a/src/data/tree.jl
+++ /dev/null
@@ -1,35 +0,0 @@
-using AbstractTrees
-
-struct Tree{T}
-  value::T
-  children::Vector{Tree{T}}
-end
-
-Tree{T}(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...])
-Tree{T}(x) where T = Tree(convert(T, x))
-
-Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...)
-
-AbstractTrees.children(t::Tree) = t.children
-AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value)
-
-Base.show(io::IO, t::Type{Tree}) = print(io, "Tree")
-Base.show(io::IO, t::Type{Tree{T}}) where T = print(io, "Tree{", @isdefined(T) ? T : :T, "}")
-
-function Base.show(io::IO, t::Tree)
-  println(io, typeof(t))
-  print_tree(io, t)
-end
-
-Base.getindex(t::Tree, i::Integer) = t.children[i]
-Base.getindex(t::Tree, i::Integer, is::Integer...) = t[i][is...]
-
-# Utilities
-
-isleaf(t) = isempty(children(t))
-
-leaves(xs::Tree) = map(x -> x.value, Leaves(xs))
-
-Base.map(f, t::Tree, ts::Tree...) =
-  Tree{Any}(f(map(t -> t.value, (t, ts...))...),
-    [map(f, chs...) for chs in zip(map(t -> t.children, (t, ts...))...)]...)
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 8bda5ef284..3998f25028 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -16,3 +16,4 @@ ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type"))
 zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))
 
 # v0.13 deprecations
+@deprecate frequencies(xs) group_counts(xs)
diff --git a/src/onehot.jl b/src/onehot.jl
index fc79872c9f..36345438a3 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -87,7 +87,7 @@ Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
 Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
   OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
 
-batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)
+MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)
 
 Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
 
diff --git a/src/utils.jl b/src/utils.jl
index 035798b5c0..67ca92f597 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -408,246 +408,6 @@ function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
   bias
 end
 
-"""
-    unsqueeze(xs, dim)
-
-Return `xs` reshaped into an array one dimensionality higher than `xs`,
-where `dim` indicates in which dimension `xs` is extended.
-
-See also [`flatten`](@ref), [`stack`](@ref).
-
-# Examples
-```jldoctest
-julia> Flux.unsqueeze([1 2; 3 4], 2)
-2×1×2 Array{Int64, 3}:
-[:, :, 1] =
- 1
- 3
-
-[:, :, 2] =
- 2
- 4
-
-julia> xs = [[1, 2], [3, 4], [5, 6]]
-3-element Vector{Vector{Int64}}:
- [1, 2]
- [3, 4]
- [5, 6]
-
-julia> Flux.unsqueeze(xs, 1)
-1×3 Matrix{Vector{Int64}}:
- [1, 2]  [3, 4]  [5, 6]
-```
-"""
-function unsqueeze(xs::AbstractArray, dim::Integer)
-  sz = ntuple(i -> i < dim ? size(xs, i) : i == dim ? 1 : size(xs, i - 1), ndims(xs) + 1)
-  return reshape(xs, sz)
-end
-
-"""
-    unsqueeze(dim)
-
-Returns a function which, acting on an array, inserts a dimension of size 1 at `dim`.
-
-# Examples
-```jldoctest
-julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size
-(21, 1, 22, 23)
-
-julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad()));
-
-julia> rand(Float32, 10, 10) |> m |> size
-(10, 10, 7, 1)
-```
-"""
-unsqueeze(dim::Integer) = Base.Fix2(unsqueeze, dim)
-
-Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = print(io, "unsqueeze(", u.x, ")")
-
-"""
-    stack(xs, dim)
-
-Concatenate the given `Array` of `Array`s `xs` into a single `Array` along the
-given dimension `dim`.
-
-# Examples
-```jldoctest
-julia> xs = [[1, 2], [3, 4], [5, 6]]
-3-element Vector{Vector{Int64}}:
- [1, 2]
- [3, 4]
- [5, 6]
-
-julia> Flux.stack(xs, 1)
-3×2 Matrix{Int64}:
- 1  2
- 3  4
- 5  6
-
-julia> cat(xs, dims=1)
-3-element Vector{Vector{Int64}}:
- [1, 2]
- [3, 4]
- [5, 6]
-```
-"""
-stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
-
-"""
-    unstack(xs, dim)
-
-Unroll the given `xs` into an `Array` of `Array`s along the given dimension `dim`.
-
-# Examples
-```jldoctest
-julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
-4-element Vector{Vector{Int64}}:
- [1, 2]
- [3, 4]
- [5, 6]
- [7, 8]
-```
-"""
-unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
-
-"""
-    chunk(xs, n)
-
-Split `xs` into `n` parts.
-
-# Examples
-```jldoctest
-julia> Flux.chunk(1:10, 3)
-3-element Vector{UnitRange{Int64}}:
- 1:4
- 5:8
- 9:10
-
-julia> Flux.chunk(collect(1:10), 3)
-3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}:
- [1, 2, 3, 4]
- [5, 6, 7, 8]
- [9, 10]
-```
-"""
-chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
-
-batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
-
-"""
-    frequencies(xs)
-
-Count the number of times that each element of `xs` appears.
-
-# Examples
-```jldoctest
-julia> Flux.frequencies(['a','b','b'])
-Dict{Char, Int64} with 2 entries:
-  'a' => 1
-  'b' => 2
-```
-"""
-function frequencies(xs)
-  fs = Dict{eltype(xs),Int}()
-  for x in xs
-    fs[x] = get(fs, x, 0) + 1
-  end
-  return fs
-end
-
-head(x::Tuple) = reverse(Base.tail(reverse(x)))
-
-squeezebatch(x) = reshape(x, head(size(x)))
-
-"""
-    batch(xs)
-
-Batch the arrays in `xs` into a single array.
-
-See also [`unbatch`](@ref)
-
-# Examples
-```jldoctest
-julia> Flux.batch([[1,2,3],[4,5,6]])
-3×2 Matrix{Int64}:
- 1  4
- 2  5
- 3  6
-```
-"""
-function batch(xs)
-  data = first(xs) isa AbstractArray ?
-    similar(first(xs), size(first(xs))..., length(xs)) :
-    Vector{eltype(xs)}(undef, length(xs))
-  for (i, x) in enumerate(xs)
-    data[batchindex(data, i)...] = x
-  end
-  return data
-end
-
-"""
-    unbatch(x)
-
-Reverse of the [`batch`](@ref) operation,
-unstacking the last dimension of the array `x`.
-
-See also [`unstack`](@ref).
-
-# Examples
-
-```jldoctest
-julia> Flux.unbatch([1 3 5 7;
-                     2 4 6 8])
-4-element Vector{Vector{Int64}}:
- [1, 2]
- [3, 4]
- [5, 6]
- [7, 8]
-"""
-unbatch(x::AbstractArray) = unstack(x, ndims(x))
-unbatch(x::AbstractVector) = x
-
-"""
-Return the given sequence padded with `p` up to a maximum length of `n`.
-
-# Examples
-```jldoctest
-julia> rpad([1, 2], 4, 0)
-4-element Vector{Int64}:
- 1
- 2
- 0
- 0
-
-julia> rpad([1, 2, 3], 2, 0)
-3-element Vector{Int64}:
- 1
- 2
- 3
-```
-"""
-Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
-
-"""
-    batchseq(seqs, pad)
-
-Take a list of `N` sequences, and turn them into a single sequence where each
-item is a batch of `N`. Short sequences will be padded by `pad`.
-
-# Examples
-```jldoctest
-julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
-3-element Vector{Vector{Int64}}:
- [1, 4]
- [2, 5]
- [3, 0]
-```
-"""
-function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
-  xs_ = [rpad(x, n, pad) for x in xs]
-  [batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
-end
-
 # Flattening models to weight vectors, and back
 
 function _restructure(m, xs)
diff --git a/test/data.jl b/test/data.jl
index 1f54ddf796..f2e85e9a15 100644
--- a/test/data.jl
+++ b/test/data.jl
@@ -5,34 +5,38 @@ using Random
     Y = [1:5;]
 
     d = DataLoader(X, batchsize=2)
-    @inferred first(d)
+    # @inferred first(d)
     batches = collect(d)
-    @test eltype(batches) == eltype(d) == typeof(X)
+    # @test eltype(batches) == eltype(d) == typeof(X)
+    @test eltype(batches) == typeof(X)
     @test length(batches) == 3
    @test batches[1] == X[:,1:2]
    @test batches[2] == X[:,3:4]
    @test batches[3] == X[:,5:5]
 
     d = DataLoader(X, batchsize=2, partial=false)
-    @inferred first(d)
+    # @inferred first(d)
     batches = collect(d)
-    @test eltype(batches) == eltype(d) == typeof(X)
+    # @test eltype(batches) == eltype(d) == typeof(X)
+    @test eltype(batches) == typeof(X)
     @test length(batches) == 2
    @test batches[1] == X[:,1:2]
    @test batches[2] == X[:,3:4]
 
     d = DataLoader((X,), batchsize=2, partial=false)
-    @inferred first(d)
+    # @inferred first(d)
     batches = collect(d)
-    @test eltype(batches) == eltype(d) == Tuple{typeof(X)}
+    # @test eltype(batches) == eltype(d) == Tuple{typeof(X)}
+    @test eltype(batches) == Tuple{typeof(X)}
     @test length(batches) == 2
    @test batches[1] == (X[:,1:2],)
    @test batches[2] == (X[:,3:4],)
 
     d = DataLoader((X, Y), batchsize=2)
-    @inferred first(d)
+    # @inferred first(d)
     batches = collect(d)
-    @test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
+    # @test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
+    @test eltype(batches) == Tuple{typeof(X), typeof(Y)}
     @test length(batches) == 3
    @test length(batches[1]) == 2
    @test length(batches[2]) == 2
@@ -46,9 +50,10 @@ using Random
 
     # test with NamedTuple
     d = DataLoader((x=X, y=Y), batchsize=2)
-    @inferred first(d)
+    # @inferred first(d)
     batches = collect(d)
-    @test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
+    # @test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
+    @test eltype(batches) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
     @test length(batches) == 3
    @test length(batches[1]) == 2
    @test length(batches[2]) == 2
@@ -60,6 +65,12 @@ using Random
    @test batches[3][1] == batches[3].x == X[:,5:5]
    @test batches[3][2] == batches[3].y == Y[5:5]
 
+    # Don't mutate state https://github.com/FluxML/Flux.jl/issues/1227
+    d = DataLoader([1:10;], shuffle=true)
+    cd = collect(zip(d, d))
+    # skip the first element, since it differed even before the bug was fixed
+    @test [cd[i][1] for i=2:10] != [cd[i][2] for i=2:10]
+
     # test interaction with `train!`
     θ = ones(2)
     X = zeros(2, 10)
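
Aside (not part of the diff above): a minimal sketch of what the second NEWS entry enables, assuming a hypothetical `MyDataset` type. Any container that implements `MLUtils.numobs` and `MLUtils.getobs` can be iterated in mini-batches by the MLUtils-backed `DataLoader`, and `MLUtils.group_counts` stands in for the deprecated `Flux.frequencies`.

```julia
# Illustrative sketch only; `MyDataset` and its fields are hypothetical names.
using Flux, MLUtils

struct MyDataset
    features::Matrix{Float32}
    labels::Vector{Int}
end

# Implement the MLUtils observation interface for the custom type.
MLUtils.numobs(d::MyDataset) = length(d.labels)
MLUtils.getobs(d::MyDataset, i) = (d.features[:, i], d.labels[i])

d = MyDataset(rand(Float32, 4, 10), rand(1:3, 10))
loader = Flux.DataLoader(d; batchsize=3, shuffle=true)  # mini-batches of 3 observations

for (x, y) in loader
    # x is a 4×3 matrix (4×1 for the final partial batch), y holds the matching labels
    @assert size(x, 1) == 4 && size(x, 2) == length(y)
end

# The deprecated Flux.frequencies(xs) is now MLUtils.group_counts(xs):
group_counts(['a', 'b', 'b'])  # Dict('a' => 1, 'b' => 2)
```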