Expose api for custom datasets #1288

Closed
wants to merge 4 commits into from
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -15,7 +15,8 @@ makedocs(modules=[Flux, NNlib],
"NNlib" => "models/nnlib.md"],
"Handling Data" =>
["One-Hot Encoding" => "data/onehot.md",
"DataLoader" => "data/dataloader.md"],
"DataLoader" => "data/dataloader.md",
"Custom Dataset" => "data/dataset.md"],
"Training Models" =>
["Optimisers" => "training/optimisers.md",
"Training" => "training/training.md"],
34 changes: 34 additions & 0 deletions docs/src/data/dataset.md
@@ -0,0 +1,34 @@
# Custom Dataset

To make a custom dataset compatible with `DataLoader`,
you need to implement the following methods:

- `Flux.Data.nobs(::CustomDataset)` -- the total number of items in `CustomDataset`;
- `Flux.Data.getobs(::CustomDataset, ids)` -- how to retrieve items from the dataset for a given list of `ids`;
- `Base.eltype(::DataLoader{CustomDataset})` -- the type of the elements returned by the dataset.

Below is a dummy example of how to adapt a custom dataset
to make it compatible with `DataLoader`.

```julia
# A dummy dataset: for any requested indices it returns an array of zeros.
struct CustomDataset{T, N}
    element_size::Tuple
    total::Int
end

Base.eltype(::DataLoader{CustomDataset{T, N}}) where {T, N} = Array{T, N}

Flux.Data.nobs(d::CustomDataset) = d.total
function Flux.Data.getobs(d::CustomDataset{T, N}, i)::Array{T, N} where {T, N}
    zeros(T, d.element_size..., length(i))
end
```

And now you can use `CustomDataset` with `DataLoader`:

```julia
dataset = CustomDataset{Float32, 4}((28, 28, 1), 16)
loader = DataLoader(dataset, batchsize=4, shuffle=true)
batches = collect(loader)
```
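
Each element of `batches` is an `Array{Float32, 4}` whose last dimension is the batch dimension. As a rough sanity-check sketch, using the names defined just above:

```julia
@assert length(batches) == 4                        # 16 observations in batches of 4
@assert all(size(b) == (28, 28, 1, 4) for b in batches)
@assert eltype(loader) == Array{Float32, 4}         # from the eltype definition above
```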
18 changes: 9 additions & 9 deletions src/data/dataloader.jl
@@ -67,8 +67,8 @@ Usage example:
"""
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
- n = _nobs(data)

+ n = nobs(data)
if n < batchsize
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
batchsize = n
@@ -84,7 +84,7 @@ end
end
nexti = min(i + d.batchsize, d.nobs)
ids = d.indices[i+1:nexti]
- batch = _getobs(d.data, ids)
+ batch = getobs(d.data, ids)
return (batch, nexti)
end

@@ -93,18 +93,18 @@ function Base.length(d::DataLoader)
d.partial ? ceil(Int,n) : floor(Int,n)
end

- _nobs(data::AbstractArray) = size(data)[end]
+ nobs(data::AbstractArray) = size(data)[end]

- function _nobs(data::Union{Tuple, NamedTuple})
+ function nobs(data::Union{Tuple, NamedTuple})
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
- n = _nobs(data[1])
- if !all(x -> _nobs(x) == n, Base.tail(data))
+ n = nobs(data[1])
+ if !all(x -> nobs(x) == n, Base.tail(data))
throw(DimensionMismatch("All data should contain same number of observations"))
end
return n
end

- _getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
- _getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
+ getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
+ getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(getobs, i), data)

Base.eltype(::DataLoader{D}) where D = D
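
For reference, a quick sketch (with made-up data) of how the renamed `nobs`/`getobs` behave for the built-in array and tuple methods above, assuming they remain reachable under `Flux.Data` as the new docs page suggests:

```julia
using Flux

X = rand(Float32, 28, 28, 1, 10)   # observations along the last dimension
y = rand(Float32, 10)

Flux.Data.nobs((X, y))             # 10 -- all parts must share the observation count
size(Flux.Data.getobs(X, 1:4))     # (28, 28, 1, 4) -- slices the last dimension
Flux.Data.getobs((X, y), 1:4)      # tuple of matching slices of X and y
```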
31 changes: 31 additions & 0 deletions test/data.jl
@@ -76,6 +76,37 @@
@test norm(θ .- 1) < 1e-10
end

@testset "Dataset" begin
struct ZerosDataset{T, N}
element_size::Tuple
total::Int
end

Base.eltype(::DataLoader{ZerosDataset{T, N}}) where {T, N} = Array{T, N}
Member: Why do we need the eltype to be defined separately?

Member Author: Because by default the eltype of the DataLoader is the type of the data it holds, not the type of the data it returns. That causes type instability and issues with things like @inferred, collect, etc.

Member: I'm a little uncomfortable having to define eltypes like that; could we maybe make it part of the signature? Strictly speaking, the elements of the data loader are the mini-batches, which isn't represented by the T here.

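To make the point of that exchange concrete, here is a rough sketch, assuming the `ZerosDataset`, `nobs`, and `getobs` methods defined in this testset and the default `Base.eltype(::DataLoader{D}) where D = D` from the diff above: without the specialised method, the loader would report the dataset type while iteration actually yields arrays, so code that trusts `eltype` (such as `collect`) sees mismatched types.

```julia
# Assumes ZerosDataset plus its nobs/getobs methods from this testset.
dataset = ZerosDataset{Float32, 4}((28, 28, 1), 16)
loader  = DataLoader(dataset, batchsize=4)

# With the specialised eltype above, the declared and actual element
# types agree; with only the default definition they would not
# (ZerosDataset{Float32, 4} vs. Array{Float32, 4}).
@assert eltype(loader) == Array{Float32, 4}
@assert typeof(first(loader)) == eltype(loader)
```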

Flux.Data.nobs(d::ZerosDataset) = d.total
function Flux.Data.getobs(d::ZerosDataset{T, N}, i)::Array{T, N} where {T, N}
zeros(T, d.element_size..., length(i))
end

batch_size = 4
data_length = 16
item_size = (28, 28, 1)

dataset = ZerosDataset{Float32, 4}(item_size, data_length)
loader = DataLoader(dataset, batchsize=batch_size, shuffle=true)

@inferred first(loader)
@test length(loader) == data_length / batch_size

batches = collect(loader)
@test length(batches) == data_length / batch_size

for b in batches
@test size(b) == (item_size..., batch_size)
end
end

@testset "CMUDict" begin
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
