From 041cc9144d7eb2ad8ec638579a3eed24bb222381 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 29 Apr 2020 10:31:43 +0200 Subject: [PATCH] remove multi-arg constructor --- src/data/Data.jl | 4 ++++ src/data/dataloader.jl | 23 +++++------------------ src/deprecations.jl | 2 +- test/data.jl | 10 ++++++++-- 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/src/data/Data.jl b/src/data/Data.jl index 16a025a7e0..3c69c35ed8 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -10,6 +10,7 @@ export CMUDict, cmudict deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...) + function download_and_verify(url, path, hash) tmppath = tempname() download(url, tmppath) @@ -51,4 +52,7 @@ export Iris include("housing.jl") export Housing + +@deprecate DataLoader(x...; kws...) DataLoader(x; kws...) + end diff --git a/src/data/dataloader.jl b/src/data/dataloader.jl index e4732fcaf2..c20a2be9fd 100644 --- a/src/data/dataloader.jl +++ b/src/data/dataloader.jl @@ -11,20 +11,18 @@ struct DataLoader end """ - DataLoader(data...; batchsize=1, shuffle=false, partial=true) - DataLoader(data::Tuple; ...) - + DataLoader(data; batchsize=1, shuffle=false, partial=true) + An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations (except possibly the last one). -Takes as input one or more data tensors (e.g. X in unsupervised learning, X and Y in -supervised learning,) or a tuple of such tensors. +Takes as input a data tensors or a tuple of one or more such tensors. The last dimension in each tensor is considered to be the observation dimension. If `shuffle=true`, shuffles the observations each time iterations are re-started. If `partial=false`, drops the last mini-batch if it is smaller than the batchsize. -The original data is preserved as a tuple in the `data` field of the DataLoader. +The original data is preserved in the `data` field of the DataLoader. Example usage: @@ -36,16 +34,9 @@ Example usage: ... end - train_loader = DataLoader(Xtrain, batchsize=2) - # iterate over 50 mini-batches of size 2 - for x in train_loader - @assert size(x) == (10, 2) - ... - end - train_loader.data # original dataset - # similar but yelding tuples + # similar but yielding tuples train_loader = DataLoader((Xtrain,), batchsize=2) for (x,) in train_loader @assert size(x) == (10, 2) @@ -54,8 +45,6 @@ Example usage: Xtrain = rand(10, 100) Ytrain = rand(100) - train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true) - # or equivalently train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true) for epoch in 1:100 for (x, y) in train_loader @@ -82,8 +71,6 @@ function DataLoader(data; batchsize=1, shuffle=false, partial=true) DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle) end -DataLoader(data...; kws...) = DataLoader(data; kws...) - @propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize] i >= d.imax && return nothing if d.shuffle && i == 0 diff --git a/src/deprecations.jl b/src/deprecations.jl index ccaac27aaf..bf3f17d95c 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,2 +1,2 @@ @deprecate param(x) x -@deprecate data(x) x +@deprecate data(x) x \ No newline at end of file diff --git a/test/data.jl b/test/data.jl index 2049232313..79dae0da72 100644 --- a/test/data.jl +++ b/test/data.jl @@ -15,7 +15,13 @@ @test batches[1] == X[:,1:2] @test batches[2] == X[:,3:4] - d = DataLoader(X, Y, batchsize=2) + d = DataLoader((X,), batchsize=2, partial=false) + batches = collect(d) + @test length(batches) == 2 + @test batches[1] == (X[:,1:2],) + @test batches[2] == (X[:,3:4],) + + d = DataLoader((X, Y), batchsize=2) batches = collect(d) @test length(batches) == 3 @test length(batches[1]) == 2 @@ -41,7 +47,7 @@ X = ones(2, 10) Y = fill(2, 10) loss(x, y) = sum((y - x'*θ).^2) - d = DataLoader(X, Y) + d = DataLoader((X, Y)) Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1)) @test norm(θ .- 1) < 1e-10 end