Skip to content

Commit

Permalink
remove multi-arg constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Apr 29, 2020
1 parent 3b3213c commit 041cc91
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 21 deletions.
4 changes: 4 additions & 0 deletions src/data/Data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export CMUDict, cmudict

deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)


function download_and_verify(url, path, hash)
tmppath = tempname()
download(url, tmppath)
Expand Down Expand Up @@ -51,4 +52,7 @@ export Iris
include("housing.jl")
export Housing


@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)

end
23 changes: 5 additions & 18 deletions src/data/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,18 @@ struct DataLoader
end

"""
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
DataLoader(data::Tuple; ...)
DataLoader(data; batchsize=1, shuffle=false, partial=true)
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input one or more data tensors (e.g. X in unsupervised learning, X and Y in
supervised learning,) or a tuple of such tensors.
Takes as input a data tensors or a tuple of one or more such tensors.
The last dimension in each tensor is considered to be the observation dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
The original data is preserved as a tuple in the `data` field of the DataLoader.
The original data is preserved in the `data` field of the DataLoader.
Example usage:
Expand All @@ -36,16 +34,9 @@ Example usage:
...
end
train_loader = DataLoader(Xtrain, batchsize=2)
# iterate over 50 mini-batches of size 2
for x in train_loader
@assert size(x) == (10, 2)
...
end
train_loader.data # original dataset
# similar but yelding tuples
# similar but yielding tuples
train_loader = DataLoader((Xtrain,), batchsize=2)
for (x,) in train_loader
@assert size(x) == (10, 2)
Expand All @@ -54,8 +45,6 @@ Example usage:
Xtrain = rand(10, 100)
Ytrain = rand(100)
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
# or equivalently
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
for epoch in 1:100
for (x, y) in train_loader
Expand All @@ -82,8 +71,6 @@ function DataLoader(data; batchsize=1, shuffle=false, partial=true)
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
end

DataLoader(data...; kws...) = DataLoader(data; kws...)

@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
i >= d.imax && return nothing
if d.shuffle && i == 0
Expand Down
2 changes: 1 addition & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
@deprecate param(x) x
@deprecate data(x) x
@deprecate data(x) x
10 changes: 8 additions & 2 deletions test/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]

d = DataLoader(X, Y, batchsize=2)
d = DataLoader((X,), batchsize=2, partial=false)
batches = collect(d)
@test length(batches) == 2
@test batches[1] == (X[:,1:2],)
@test batches[2] == (X[:,3:4],)

d = DataLoader((X, Y), batchsize=2)
batches = collect(d)
@test length(batches) == 3
@test length(batches[1]) == 2
Expand All @@ -41,7 +47,7 @@
X = ones(2, 10)
Y = fill(2, 10)
loss(x, y) = sum((y - x'*θ).^2)
d = DataLoader(X, Y)
d = DataLoader((X, Y))
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
@test norm.- 1) < 1e-10
end
Expand Down

0 comments on commit 041cc91

Please sign in to comment.