
extend dataloader #1152

Merged: 3 commits into master, Jun 8, 2020
Conversation

@CarloLucibello (Member) commented Apr 29, 2020

cf. the discussion in #1149. Currently the DataLoader interface supports

  1. for x in DataLoader(X)
  2. for (x, y) in DataLoader(X, Y)

This PR adds

  1. for (x,) in DataLoader((X,))
  2. for (x, y) in DataLoader((X, Y))

(a usage sketch of the new tuple form follows below)

Edit:
the multi-argument constructor in 2. of the first list, DataLoader(X, Y), is removed in this PR
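A minimal usage sketch of the new tuple form (the array shapes, the batchsize value, and the `Flux.Data` import path are illustrative assumptions, not taken from the PR itself):

```julia
using Flux.Data: DataLoader  # defined in src/data/dataloader.jl

X = rand(Float32, 10, 100)   # 10 features, 100 observations
Y = rand(Float32, 1, 100)    # 1 target per observation

# tuple constructor: one loader over several aligned arrays
loader = DataLoader((X, Y), batchsize = 16, shuffle = true)

for (x, y) in loader
    # each iteration yields a tuple of batches, indexed on the last dimension;
    # with partial = true (the default) the final batch may be smaller
    @assert size(x, 2) == size(y, 2)
end
```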

@CarloLucibello (Member, Author):

Removed multi-arg constructor as discussed in #1149

@CarloLucibello (Member, Author):

fix #1088

@DhairyaLGandhi (Member):

There are a few empty lines added here; it would be good to check them.

DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
imax = partial ? n : n - batchsize + 1
ids = 1:min(n, batchsize)
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
Review comment (Member):

Do we end up creating a copy of the data that we pass into it? It will be annoying while using GPUs where VRAM is already at a premium.

Review comment (Member):

Maybe it should specialize on whether we can move data to the GPU in a separate thread while training is running, adding it to a buffer a couple of batches deep, so we can eliminate the time to move and retrieve data without saturating the GPU?

Review comment (Member, Author):

> Do we end up creating a copy of the data that we pass into it? It will be annoying while using GPUs where VRAM is already at a premium.

The entire dataset is kept without copying; copying happens only when indexing into it to produce a mini-batch.
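A small illustration of that point with a plain array (standalone, not the Flux internals; the shapes are made up for the example):

```julia
X = rand(Float32, 10, 100)   # full dataset; the loader keeps a reference to it

ids = 1:16                   # indices for one mini-batch
xb = X[:, ids]               # getindex allocates a fresh 10×16 batch array

# only the batch slice is copied; the dataset itself is never duplicated
@assert size(xb) == (10, 16)
```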

Review comment (Member, Author):

> Maybe it should specialize on whether we can move data to the GPU in a separate thread while training is running, adding it to a buffer a couple of batches deep, so we can eliminate the time to move and retrieve data without saturating the GPU?

This should be done at some point; a rough sketch of the idea follows.
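Not part of this PR, but one way to do the background host-to-GPU prefetching described above is a bounded Channel used as the buffer (gpu is Flux's device-transfer helper; the wrapper name and the buffer depth of 2 are made up for the example):

```julia
using Flux  # for gpu()

# Wrap an existing loader so batches are moved to the GPU on a separate task,
# keeping up to `depth` device-resident batches ready ahead of the training loop.
function gpu_prefetch(loader; depth = 2)
    return Channel(depth; spawn = true) do ch
        for batch in loader
            put!(ch, gpu(batch))   # blocks once `depth` batches are already queued
        end
    end
end

# usage: iterate the channel just like the original loader
# for (x, y) in gpu_prefetch(train_loader)
#     ...
# end
```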

src/data/dataloader.jl (review thread resolved)
@DhairyaLGandhi (Member):

There was mention on Slack about some type instability; could we clear that up?

@cossio (Contributor) commented May 4, 2020

See #1159

@CarloLucibello (Member, Author):

fix #1159

@CarloLucibello (Member, Author) commented May 4, 2020

Other changes:

    else
      gs = gradient(ps) do
        loss(d...)
      end
Review comment (Member):

Already in #1149, let's not repeat it here.

.travis.yml Outdated
@@ -16,7 +16,7 @@ notifications:
 jobs:
   include:
     - stage: "Documentation"
-      julia: 1
+      julia: 1.4
Review comment (Member):

I think it's fixed on master

@@ -56,14 +56,17 @@ function stop()
   throw(StopException())
 end

+maketuple(x) = (x,)
+maketuple(x::Tuple) = x
Review comment (Member):

We'd decided to not do this, so let's not

Review comment (Member, Author):

this is the last point of #1149 (comment)

@DhairyaLGandhi (Member):

I'd suggest keeping the Manifest changes separate, as there is a separate PR and discussion unrelated to the data loader here.

@MikeInnes (Member):

Best to keep each PR to one logical change. For this one, that should be the data loaders; the training loop can happen in #1149 and we can follow up with the Manifest change. As much as anything, the Manifest change is going to lead to more merge conflicts.

@DhairyaLGandhi (Member):

bump

return n
end

function _getobs(data::A, i) where A<:AbstractArray{T,N} where {T,N}
Review comment (Contributor):

Suggested change:
-function _getobs(data::A, i) where A<:AbstractArray{T,N} where {T,N}
+function _getobs(data::AbstractArray{T,N}, i) where {T,N}

getindex(data, ntuple(i->Colon(), N-1)..., i)
end

_getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
Review comment (Contributor):

Suggested change:
-_getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
+_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
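For reference, both forms index every element of the tuple at the same observation indices; `map` just preserves the tuple structure directly and is friendlier to inference. A standalone sketch of the suggested methods (the shapes and indices are only for illustration, not taken from the PR):

```julia
# hypothetical standalone copies of the suggested methods
_getobs(data::AbstractArray{T,N}, i) where {T,N} =
    getindex(data, ntuple(_ -> Colon(), N - 1)..., i)

_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)

X = rand(Float32, 10, 100)
Y = rand(Float32, 1, 100)

xb, yb = _getobs((X, Y), 1:16)   # xb is 10×16, yb is 1×16: each array indexed on its last dimension
```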

ids = 1:min(nx, batchsize)
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
imax = partial ? n : n - batchsize + 1
ids = 1:min(n, batchsize)
Review comment (Contributor):

Remove this line?

@CarloLucibello (Member, Author):

Reinstated the Manifest and removed the changes to train! (those should happen in #1149 instead). This should be ready to go.

Manifest.toml Outdated
@@ -384,4 +384,4 @@ version = "0.4.20"
 deps = ["MacroTools"]
 git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
 uuid = "700de1a5-db45-46bc-99cf-38207098b444"
-version = "0.2.0"
+version = "0.2.0"
\ No newline at end of file
Review comment (Member):

Trailing newline is missing here

@@ -121,4 +121,4 @@ macro epochs(n, ex)
       @info "Epoch $i"
       $(esc(ex))
     end)
-end
+end
\ No newline at end of file
Review comment (Member):

also here

@MikeInnes (Member) left a comment:

I'd like a quick cleanup of the newlines (otherwise it just adds a little noise to PRs when people have editor settings that automatically fix this stuff), but otherwise I'm happy with the API changes.

@CarloLucibello (Member, Author):

bors r+

@bors bot (Contributor) commented Jun 8, 2020

Build succeeded:

@bors bot merged commit a7bbd3d into master on Jun 8, 2020