Skip to content

Commit

Permalink
extend dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jun 6, 2020
1 parent c444226 commit 19ee886
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 115 deletions.
110 changes: 65 additions & 45 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ version = "0.5.0"

[[AbstractTrees]]
deps = ["Markdown"]
git-tree-sha1 = "86d092c2599f1f7bb01668bf8eb3412f98d61e47"
git-tree-sha1 = "33e450545eaf7699da1a6e755f9ea65f14077a45"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.3.2"
version = "0.3.3"

[[Adapt]]
deps = ["LinearAlgebra"]
Expand All @@ -20,23 +20,23 @@ version = "1.0.1"

[[ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra"]
git-tree-sha1 = "41956a49a8a4fefa1bf6664bca4a3035aba4c3a0"
git-tree-sha1 = "a504dca2ac7eda8761c8f7c1ed52427a1be75a3c"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
version = "0.2.3"
version = "0.2.6"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BinaryProvider]]
deps = ["Libdl", "SHA"]
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "428e9106b1ff27593cbd979afac9b45b82372b8c"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.8"
version = "0.5.9"

[[CEnum]]
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
git-tree-sha1 = "1b77a77c3b28e0b3f413f7567c9bb8dd9bdccd14"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.2.0"
version = "0.3.0"

[[CUDAapi]]
deps = ["Libdl", "Logging"]
Expand All @@ -46,21 +46,21 @@ version = "4.0.0"

[[CUDAdrv]]
deps = ["CEnum", "CUDAapi", "Printf"]
git-tree-sha1 = "e650cbaee92b60433313157926b1e80d0c3a0e2e"
git-tree-sha1 = "f56bbf18c86bcff7a961a32a4947a5abb2963a29"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "6.2.2"
version = "6.3.0"

[[CUDAnative]]
deps = ["Adapt", "BinaryProvider", "CEnum", "CUDAapi", "CUDAdrv", "Cthulhu", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "MacroTools", "Pkg", "Printf", "TimerOutputs"]
git-tree-sha1 = "d1fc99635d0002c8a819b78cb1f441eb44310725"
deps = ["Adapt", "BinaryProvider", "CEnum", "CUDAapi", "CUDAdrv", "ExprTools", "GPUCompiler", "LLVM", "Libdl", "Pkg", "Printf"]
git-tree-sha1 = "ac86db2b05fdfec96b011e25a504ffe7476e8a68"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "3.0.2"
version = "3.1.0"

[[CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "0becdab7e6fbbcb7b88d8de5b72e5bb2f28239f3"
git-tree-sha1 = "cab4da992adc0a64f63fa30d2db2fd8bec40cab4"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "0.5.8"
version = "0.5.11"

[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
Expand All @@ -70,9 +70,9 @@ version = "0.7.0"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "c4c1cca28748906265ed62c788d6fe6f0134d264"
git-tree-sha1 = "c73d9cfc2a9d8433dc77f5bff4bddf46b1d78c20"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.10.0"
version = "0.10.3"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"]
Expand All @@ -94,26 +94,26 @@ version = "0.3.3+0"

[[Cthulhu]]
deps = ["CodeTracking", "InteractiveUtils", "REPL", "Unicode"]
git-tree-sha1 = "484790098c85c26f8e59051f8ff1a0745c034a7d"
git-tree-sha1 = "a4849ec61df9659423cc63b298ed895904ee9743"
uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
version = "1.0.1"
version = "1.0.2"

[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"]
git-tree-sha1 = "e8c55b38dcca955f5aed8ec4479cdc95810db1e1"
git-tree-sha1 = "870a4ac61e99c36f42d15e496fd290c841541d90"
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
version = "2.0.1"
version = "2.2.0"

[[DataAPI]]
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
git-tree-sha1 = "176e23402d80e7743fc26c19c681bfb11246af32"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.1.0"
version = "1.3.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "73eb18320fe3ba58790c8b8f6f89420f0a622773"
git-tree-sha1 = "6166ecfaf2b8bbf2b68d791bc1d54501f345d314"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.11"
version = "0.17.15"

[[Dates]]
deps = ["Printf"]
Expand All @@ -139,11 +139,16 @@ version = "1.0.1"
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[ExprTools]]
git-tree-sha1 = "6f0517056812fd6aa3af23d4b70d5325a2ae4e95"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.1"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "51cc2f9bc4eb9c6c0e81ec2f779d1085583cc956"
git-tree-sha1 = "6c89d5b673e59b8173c546c84127e5f623d865f6"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.8.7"
version = "0.8.9"

[[FixedPointNumbers]]
git-tree-sha1 = "3ba9ea634d4c8b289d590403b4a06f8e227a6238"
Expand All @@ -156,17 +161,33 @@ git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.10"

[[Functors]]
deps = ["MacroTools"]
git-tree-sha1 = "f40adc6422f548176bb4351ebd29e4abf773040a"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.1.0"

[[Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "d586762b08dcda13228df8967119b9cb6f22ade5"
git-tree-sha1 = "ce4579ebffef43e07318e9544ffeb6532c95d04d"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "3.1.0"
version = "3.3.0"

[[GPUCompiler]]
deps = ["Cthulhu", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "TimerOutputs"]
git-tree-sha1 = "5275aa268ecd09640b32560e1eae90c78816e4d1"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.2.0"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "1a4355e4b5b50be2311ebb644f34f3306dbd0410"
git-tree-sha1 = "8845400bd2d9815d37720251f1b53d27a335e1f4"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.3.1"
version = "0.3.2"

[[InteractiveUtils]]
deps = ["Markdown"]
Expand All @@ -180,9 +201,9 @@ version = "0.8.1"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b6b86801ae2f2682e0a4889315dc76b68db2de71"
git-tree-sha1 = "93d2e1e960fe47db1a9015e86fad1d47cf67cf59"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.3.4"
version = "1.4.1"

[[LibGit2]]
deps = ["Printf"]
Expand Down Expand Up @@ -241,10 +262,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+3"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
git-tree-sha1 = "12ce190210d278e12644bcadf5b21cbdcf225cd3"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"
version = "1.2.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down Expand Up @@ -305,9 +325,9 @@ version = "0.10.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c"
git-tree-sha1 = "5c06c0aeb81bef54aed4b3f446847905eb6cbda0"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.12.1"
version = "0.12.3"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand All @@ -325,9 +345,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03"
git-tree-sha1 = "0cc8db57cb537191b02948d4fabdc09eb7f31f98"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.3"
version = "0.5.5"

[[TranscodingStreams]]
deps = ["Random", "Test"]
Expand Down Expand Up @@ -355,13 +375,13 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.11+9"

[[Zygote]]
deps = ["AbstractFFTs", "ArrayLayouts", "DiffRules", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "1ccbfbe8930376e31752b812daa2532c723dc332"
deps = ["AbstractFFTs", "ArrayLayouts", "DiffRules", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "707ceea58e2bd0ff3077ab13a92f8355181d3ee4"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.4.13"
version = "0.4.20"

[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.0"
version = "0.2.0"
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.10.4"
version = "0.11.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
2 changes: 2 additions & 0 deletions src/data/Data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,6 @@ export Iris
include("housing.jl")
export Housing

# Deprecation shim for the varargs call form: the positional arguments are collected
# into the tuple `x` and forwarded, with the keyword arguments, to the single-argument
# `DataLoader(x; kws...)` method.
@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)

end
70 changes: 42 additions & 28 deletions src/data/dataloader.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Adapted from Knet's src/data.jl (author: Deniz Yuret)

struct DataLoader
data
struct DataLoader{D}
data::D
batchsize::Int
nobs::Int
partial::Bool
Expand All @@ -11,21 +11,20 @@ struct DataLoader
end

"""
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
DataLoader(data; batchsize=1, shuffle=false, partial=true)
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
supervised learning. The last dimension in each tensor is considered to be the observation
dimension.
Takes as input a single data tensor or a tuple of one or more such tensors.
The last dimension in each tensor is considered to be the observation dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
The original data is preserved as a tuple in the `data` field of the DataLoader.
The original data is preserved in the `data` field of the DataLoader.
Example usage:
Usage example:
Xtrain = rand(10, 100)
train_loader = DataLoader(Xtrain, batchsize=2)
Expand All @@ -37,9 +36,16 @@ Example usage:
train_loader.data # original dataset
# similar, but yielding tuples
train_loader = DataLoader((Xtrain,), batchsize=2)
for (x,) in train_loader
@assert size(x) == (10, 2)
...
end
Xtrain = rand(10, 100)
Ytrain = rand(100)
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
for epoch in 1:100
for (x, y) in train_loader
@assert size(x) == (10, 2)
Expand All @@ -52,41 +58,49 @@ Example usage:
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
"""
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))

nx = size(data[1])[end]
for i=2:length(data)
nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations"))
n = _nobs(data)
if n < batchsize
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
batchsize = n
end
if nx < batchsize
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
batchsize = nx
end
imax = partial ? nx : nx - batchsize + 1
ids = 1:min(nx, batchsize)
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
imax = partial ? n : n - batchsize + 1
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
end

getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]

@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
i >= d.imax && return nothing
if d.shuffle && i == 0
shuffle!(d.indices)
end
nexti = min(i + d.batchsize, d.nobs)
ids = d.indices[i+1:nexti]
if length(d.data) == 1
batch = getdata(d.data[1], ids)
else
batch = ((getdata(x, ids) for x in d.data)...,)
end
batch = _getobs(d.data, ids)
return (batch, nexti)
end

# Number of mini-batches produced by one full pass over `d`.
#
# With `d.partial == true` a trailing batch smaller than `d.batchsize` is counted
# (ceiling division); otherwise it is not (floor division). A loader whose
# `batchsize` is zero reports a length of 0.
function Base.length(d::DataLoader)
    # Guard the degenerate case explicitly: dividing by a zero batchsize would throw.
    d.batchsize == 0 && return 0
    # Integer division stays exact for every `Int`; no round trip through `Float64`.
    return d.partial ? cld(d.nobs, d.batchsize) : fld(d.nobs, d.batchsize)
end

# Number of observations in an array: the extent of its last dimension.
_nobs(data::AbstractArray) = last(size(data))

# Number of observations shared by every element of a tuple of data containers.
# Throws `ArgumentError` for an empty tuple and `DimensionMismatch` when the
# elements do not all report the same count.
function _nobs(data::Tuple)
    isempty(data) && throw(ArgumentError("Need at least one data input"))
    expected = _nobs(first(data))
    # Compare every remaining element against the first one, stopping at the first mismatch.
    for x in Iterators.drop(data, 1)
        if _nobs(x) != expected
            throw(DimensionMismatch("All data should contain same number of observations"))
        end
    end
    return expected
end

# Select the observations `i` of an array: index the last dimension with `i`
# and take every other dimension in full.
function _getobs(data::AbstractArray{T,N}, i) where {T,N}
    leading = ntuple(_ -> Colon(), N - 1)  # one `:` per non-observation dimension
    return data[leading..., i]
end

# Apply the same selection to each element of a tuple; the result is again a tuple.
_getobs(data::Tuple, i) = map(Base.Fix2(_getobs, i), data)

# Element type reported for iteration: the type `D` of the stored data.
# Defined on the *type*, as the iteration interface expects, so that
# `eltype(typeof(d))` is informative as well; `eltype(d)` on an instance still
# works through Base's generic fallback `eltype(x) = eltype(typeof(x))`.
Base.eltype(::Type{DataLoader{D}}) where {D} = D
2 changes: 1 addition & 1 deletion src/optimise/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,4 @@ macro epochs(n, ex)
@info "Epoch $i"
$(esc(ex))
end)
end
end
Loading

0 comments on commit 19ee886

Please sign in to comment.