diff --git a/NEWS.md b/NEWS.md index c9b2188db0..82c1dcfdfc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,6 @@ +# v0.11 +* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152] + # v0.10.5 * Add option for [same padding](https://github.com/FluxML/Flux.jl/pull/901) to conv and pooling layers by setting `pad=SamePad()`. * Added option to set `bias` to [Flux.Zeros](https://github.com/FluxML/Flux.jl/pull/873) to eliminating `bias` from being trained. diff --git a/Project.toml b/Project.toml index a9af31da01..969991246e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.10.5" +version = "0.11.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/data/Data.jl b/src/data/Data.jl index 16a025a7e0..eb7c4ab0d7 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -51,4 +51,6 @@ export Iris include("housing.jl") export Housing +@deprecate DataLoader(x...; kws...) DataLoader(x; kws...) + end diff --git a/src/data/dataloader.jl b/src/data/dataloader.jl index 07c8f1fd31..4f91dc2a3d 100644 --- a/src/data/dataloader.jl +++ b/src/data/dataloader.jl @@ -1,7 +1,7 @@ # Adapted from Knet's src/data.jl (author: Deniz Yuret) -struct DataLoader - data +struct DataLoader{D} + data::D batchsize::Int nobs::Int partial::Bool @@ -11,21 +11,20 @@ struct DataLoader end """ - DataLoader(data...; batchsize=1, shuffle=false, partial=true) + DataLoader(data; batchsize=1, shuffle=false, partial=true) An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations (except possibly the last one). -Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in -supervised learning. The last dimension in each tensor is considered to be the observation -dimension. +Takes as input a data tensors or a tuple of one or more such tensors. +The last dimension in each tensor is considered to be the observation dimension. If `shuffle=true`, shuffles the observations each time iterations are re-started. If `partial=false`, drops the last mini-batch if it is smaller than the batchsize. -The original data is preserved as a tuple in the `data` field of the DataLoader. +The original data is preserved in the `data` field of the DataLoader. -Example usage: +Usage example: Xtrain = rand(10, 100) train_loader = DataLoader(Xtrain, batchsize=2) @@ -37,9 +36,16 @@ Example usage: train_loader.data # original dataset + # similar, but yielding tuples + train_loader = DataLoader((Xtrain,), batchsize=2) + for (x,) in train_loader + @assert size(x) == (10, 2) + ... + end + Xtrain = rand(10, 100) Ytrain = rand(100) - train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true) + train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true) for epoch in 1:100 for (x, y) in train_loader @assert size(x) == (10, 2) @@ -52,25 +58,18 @@ Example usage: using IterTools: ncycle Flux.train!(loss, ps, ncycle(train_loader, 10), opt) """ -function DataLoader(data...; batchsize=1, shuffle=false, partial=true) - length(data) > 0 || throw(ArgumentError("Need at least one data input")) +function DataLoader(data; batchsize=1, shuffle=false, partial=true) batchsize > 0 || throw(ArgumentError("Need positive batchsize")) - nx = size(data[1])[end] - for i=2:length(data) - nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations")) + n = _nobs(data) + if n < batchsize + @warn "Number of observations less than batchsize, decreasing the batchsize to $n" + batchsize = n end - if nx < batchsize - @warn "Number of data points less than batchsize, decreasing the batchsize to $nx" - batchsize = nx - end - imax = partial ? nx : nx - batchsize + 1 - ids = 1:min(nx, batchsize) - DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle) + imax = partial ? n : n - batchsize + 1 + DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle) end -getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids] - @propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize] i >= d.imax && return nothing if d.shuffle && i == 0 @@ -78,11 +77,7 @@ getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids] end nexti = min(i + d.batchsize, d.nobs) ids = d.indices[i+1:nexti] - if length(d.data) == 1 - batch = getdata(d.data[1], ids) - else - batch = ((getdata(x, ids) for x in d.data)...,) - end + batch = _getobs(d.data, ids) return (batch, nexti) end @@ -90,3 +85,22 @@ function Base.length(d::DataLoader) n = d.nobs / d.batchsize d.partial ? ceil(Int,n) : floor(Int,n) end + +_nobs(data::AbstractArray) = size(data)[end] + +function _nobs(data::Tuple) + length(data) > 0 || throw(ArgumentError("Need at least one data input")) + n = _nobs(data[1]) + if !all(x -> _nobs(x) == n, data[2:end]) + throw(DimensionMismatch("All data should contain same number of observations")) + end + return n +end + +function _getobs(data::AbstractArray{T,N}, i) where {T,N} + getindex(data, ntuple(i->Colon(), N-1)..., i) +end + +_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data) + +Base.eltype(d::DataLoader{D}) where D = D diff --git a/test/data.jl b/test/data.jl index 2049232313..a26878db4d 100644 --- a/test/data.jl +++ b/test/data.jl @@ -4,6 +4,7 @@ d = DataLoader(X, batchsize=2) batches = collect(d) + @test eltype(batches) == eltype(d) == typeof(X) @test length(batches) == 3 @test batches[1] == X[:,1:2] @test batches[2] == X[:,3:4] @@ -11,12 +12,21 @@ d = DataLoader(X, batchsize=2, partial=false) batches = collect(d) + @test eltype(batches) == eltype(d) == typeof(X) @test length(batches) == 2 @test batches[1] == X[:,1:2] @test batches[2] == X[:,3:4] - d = DataLoader(X, Y, batchsize=2) + d = DataLoader((X,), batchsize=2, partial=false) batches = collect(d) + @test eltype(batches) == eltype(d) == Tuple{typeof(X)} + @test length(batches) == 2 + @test batches[1] == (X[:,1:2],) + @test batches[2] == (X[:,3:4],) + + d = DataLoader((X, Y), batchsize=2) + batches = collect(d) + @test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)} @test length(batches) == 3 @test length(batches[1]) == 2 @test length(batches[2]) == 2 @@ -41,7 +51,7 @@ X = ones(2, 10) Y = fill(2, 10) loss(x, y) = sum((y - x'*θ).^2) - d = DataLoader(X, Y) + d = DataLoader((X, Y)) Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1)) @test norm(θ .- 1) < 1e-10 end diff --git a/test/runtests.jl b/test/runtests.jl index c2ea0715cf..23dbc1b870 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,49 +2,45 @@ using Flux using Flux.Data using Test using Random, Statistics, LinearAlgebra -using Documenter using IterTools: ncycle Random.seed!(0) -@testset "Flux" begin - - @testset "Utils" begin - include("utils.jl") - end - - @testset "Onehot" begin - include("onehot.jl") - end - - @testset "Optimise" begin - include("optimise.jl") - end - - @testset "Data" begin - include("data.jl") - end - - @testset "Layers" begin - include("layers/basic.jl") - include("layers/normalisation.jl") - include("layers/stateless.jl") - include("layers/conv.jl") - end - - @testset "CUDA" begin - if Flux.use_cuda[] - include("cuda/cuda.jl") - else - @warn "CUDA unavailable, not testing GPU support" - end +@testset "Utils" begin + include("utils.jl") +end + +@testset "Onehot" begin + include("onehot.jl") +end + +@testset "Optimise" begin + include("optimise.jl") +end + +@testset "Data" begin + include("data.jl") +end + +@testset "Layers" begin + include("layers/basic.jl") + include("layers/normalisation.jl") + include("layers/stateless.jl") + include("layers/conv.jl") +end + +@testset "CUDA" begin + if Flux.use_cuda[] + include("cuda/cuda.jl") + else + @warn "CUDA unavailable, not testing GPU support" end +end +@static if VERSION >= v"1.4" + using Documenter @testset "Docs" begin - if VERSION >= v"1.4" - DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) - doctest(Flux) - end + DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) + doctest(Flux) end - -end # testset Flux +end