Skip to content

Commit

Permalink
Merge #1152
Browse files Browse the repository at this point in the history
1152: extend dataloader r=CarloLucibello a=CarloLucibello

cfr discussion in #1149. Currently DataLoader interface supports

1. `for x in DataLoader(X)`
2. `for (x, y) in DataLoader(X, Y)`

This PR adds

3. `for (x,) in DataLoader((X,))`
4. `for (x, y) in DataLoader((X, Y))`

Edit:
the constructor in 2. is removed in this PR

Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
  • Loading branch information
bors[bot] and CarloLucibello authored Jun 8, 2020
2 parents d9b0747 + 0cf4643 commit a7bbd3d
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 69 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# v0.11
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]

# v0.10.5
* Add option for [same padding](https://github.com/FluxML/Flux.jl/pull/901) to conv and pooling layers by setting `pad=SamePad()`.
* Added option to set `bias` to [Flux.Zeros](https://github.com/FluxML/Flux.jl/pull/873) to eliminating `bias` from being trained.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.10.5"
version = "0.11.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
2 changes: 2 additions & 0 deletions src/data/Data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,6 @@ export Iris
include("housing.jl")
export Housing

@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)

end
70 changes: 42 additions & 28 deletions src/data/dataloader.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Adapted from Knet's src/data.jl (author: Deniz Yuret)

struct DataLoader
data
struct DataLoader{D}
data::D
batchsize::Int
nobs::Int
partial::Bool
Expand All @@ -11,21 +11,20 @@ struct DataLoader
end

"""
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
DataLoader(data; batchsize=1, shuffle=false, partial=true)
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
supervised learning. The last dimension in each tensor is considered to be the observation
dimension.
Takes as input a data tensors or a tuple of one or more such tensors.
The last dimension in each tensor is considered to be the observation dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
The original data is preserved as a tuple in the `data` field of the DataLoader.
The original data is preserved in the `data` field of the DataLoader.
Example usage:
Usage example:
Xtrain = rand(10, 100)
train_loader = DataLoader(Xtrain, batchsize=2)
Expand All @@ -37,9 +36,16 @@ Example usage:
train_loader.data # original dataset
# similar, but yielding tuples
train_loader = DataLoader((Xtrain,), batchsize=2)
for (x,) in train_loader
@assert size(x) == (10, 2)
...
end
Xtrain = rand(10, 100)
Ytrain = rand(100)
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
for epoch in 1:100
for (x, y) in train_loader
@assert size(x) == (10, 2)
Expand All @@ -52,41 +58,49 @@ Example usage:
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
"""
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))

nx = size(data[1])[end]
for i=2:length(data)
nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations"))
n = _nobs(data)
if n < batchsize
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
batchsize = n
end
if nx < batchsize
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
batchsize = nx
end
imax = partial ? nx : nx - batchsize + 1
ids = 1:min(nx, batchsize)
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
imax = partial ? n : n - batchsize + 1
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
end

getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]

@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
i >= d.imax && return nothing
if d.shuffle && i == 0
shuffle!(d.indices)
end
nexti = min(i + d.batchsize, d.nobs)
ids = d.indices[i+1:nexti]
if length(d.data) == 1
batch = getdata(d.data[1], ids)
else
batch = ((getdata(x, ids) for x in d.data)...,)
end
batch = _getobs(d.data, ids)
return (batch, nexti)
end

function Base.length(d::DataLoader)
n = d.nobs / d.batchsize
d.partial ? ceil(Int,n) : floor(Int,n)
end

_nobs(data::AbstractArray) = size(data)[end]

function _nobs(data::Tuple)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
n = _nobs(data[1])
if !all(x -> _nobs(x) == n, data[2:end])
throw(DimensionMismatch("All data should contain same number of observations"))
end
return n
end

function _getobs(data::AbstractArray{T,N}, i) where {T,N}
getindex(data, ntuple(i->Colon(), N-1)..., i)
end

_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)

Base.eltype(d::DataLoader{D}) where D = D
14 changes: 12 additions & 2 deletions test/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,29 @@

d = DataLoader(X, batchsize=2)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 3
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]
@test batches[3] == X[:,5:5]

d = DataLoader(X, batchsize=2, partial=false)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 2
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]

d = DataLoader(X, Y, batchsize=2)
d = DataLoader((X,), batchsize=2, partial=false)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
@test length(batches) == 2
@test batches[1] == (X[:,1:2],)
@test batches[2] == (X[:,3:4],)

d = DataLoader((X, Y), batchsize=2)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
@test length(batches) == 3
@test length(batches[1]) == 2
@test length(batches[2]) == 2
Expand All @@ -41,7 +51,7 @@
X = ones(2, 10)
Y = fill(2, 10)
loss(x, y) = sum((y - x'*θ).^2)
d = DataLoader(X, Y)
d = DataLoader((X, Y))
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
@test norm.- 1) < 1e-10
end
Expand Down
72 changes: 34 additions & 38 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,45 @@ using Flux
using Flux.Data
using Test
using Random, Statistics, LinearAlgebra
using Documenter
using IterTools: ncycle

Random.seed!(0)

@testset "Flux" begin

@testset "Utils" begin
include("utils.jl")
end

@testset "Onehot" begin
include("onehot.jl")
end

@testset "Optimise" begin
include("optimise.jl")
end

@testset "Data" begin
include("data.jl")
end

@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")
end

@testset "CUDA" begin
if Flux.use_cuda[]
include("cuda/cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
@testset "Utils" begin
include("utils.jl")
end

@testset "Onehot" begin
include("onehot.jl")
end

@testset "Optimise" begin
include("optimise.jl")
end

@testset "Data" begin
include("data.jl")
end

@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")
end

@testset "CUDA" begin
if Flux.use_cuda[]
include("cuda/cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
end

@static if VERSION >= v"1.4"
using Documenter
@testset "Docs" begin
if VERSION >= v"1.4"
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
end
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
end

end # testset Flux
end

0 comments on commit a7bbd3d

Please sign in to comment.