extend dataloader #1152
Conversation
Removed multi-arg constructor as discussed in #1149
fix #1088
af8cf73 to 041cc91
There are a few empty lines added here; worth checking them.
```diff
- DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
+ imax = partial ? n : n - batchsize + 1
+ ids = 1:min(n, batchsize)
+ DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
```
Do we end up creating a copy of the data that we pass into it? It will be annoying while using GPUs where VRAM is already at a premium.
Maybe it should specialize on whether we can move data to the GPU in a separate thread while training is running, adding it to a buffer of a couple of batches' worth, so we can hide the time spent moving and retrieving data without saturating the GPU?
> Do we end up creating a copy of the data that we pass into it? It will be annoying while using GPUs where VRAM is already at a premium.

The entire dataset is kept with no copying; we copy only when indexing to produce a mini-batch.
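A minimal sketch of that behavior (array sizes are illustrative, and the module path is assumed to be `Flux.Data` as in this era of Flux):

```julia
using Flux.Data: DataLoader

X = rand(Float32, 28, 28, 60_000)        # full dataset stays where it is
loader = DataLoader(X, batchsize = 128)  # construction stores a reference, no copy

for x in loader
    # indexing materializes a fresh (28, 28, 128) batch only at this point
end
```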
> Maybe it should specialize on whether we can move data to the GPU in a separate thread while training is running, adding it to a buffer of a couple of batches' worth, so we can hide the time spent moving and retrieving data without saturating the GPU?

This should be done at some point.
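Not part of this PR, but a rough sketch of the idea under discussion, using Flux's `gpu` for device transfer and a hypothetical `prefetch` helper:

```julia
using Flux

# Buffer up to `depth` upcoming batches on the GPU from a spawned task,
# so host-to-device transfer overlaps with training on the main task.
function prefetch(loader; depth = 2)
    Channel(depth; spawn = true) do ch
        for batch in loader
            put!(ch, gpu(batch))  # transfer happens ahead of consumption
        end
    end
end

# for (x, y) in prefetch(DataLoader((X, Y), batchsize = 128))
#     # batches arrive already device-resident
# end
```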
There was mention on Slack about some type instability; could we clear that up?
See #1159
fix #1159
other changes
src/optimise/train.jl (Outdated)
```julia
else
    gs = gradient(ps) do
        loss(d...)
    end
```
Already in #1149, let's not repeat it here.
.travis.yml (Outdated)
```diff
@@ -16,7 +16,7 @@ notifications:
 jobs:
   include:
     - stage: "Documentation"
-      julia: 1
+      julia: 1.4
```
I think it's fixed on master
src/optimise/train.jl (Outdated)
```diff
@@ -56,14 +56,17 @@ function stop()
   throw(StopException())
 end

+maketuple(x) = (x,)
+maketuple(x::Tuple) = x
```
We'd decided not to do this, so let's not.
this is the last point of #1149 (comment)
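For context, the helper in the diff above normalizes a batch so the train loop can splat it into `loss(d...)` uniformly; a quick illustration:

```julia
maketuple(x) = (x,)
maketuple(x::Tuple) = x

maketuple(rand(3))          # wraps a bare batch: (x,)
maketuple((rand(3), 1:3))   # a tuple passes through unchanged
```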
I'd suggest keeping the Manifest changes separate, as there is a separate PR and discussion unrelated to the DataLoader here.
Best to keep each PR to one logical change. For this that should be the data loaders; the training loop can happen in #1149 and we can follow up with the manifest change. As much as anything, the manifest change is going to lead to more merge conflicts.
bump
src/data/dataloader.jl (Outdated)
```julia
    return n
end

function _getobs(data::A, i) where A<:AbstractArray{T,N} where {T,N}
```
```diff
- function _getobs(data::A, i) where A<:AbstractArray{T,N} where {T,N}
+ function _getobs(data::AbstractArray{T,N}, i) where {T,N}
```
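Both signatures dispatch on the same set of types; the suggested form simply drops the intermediate type parameter `A`, which the body never uses.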
src/data/dataloader.jl (Outdated)
```julia
    getindex(data, ntuple(i -> Colon(), N-1)..., i)
end

_getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
```
```diff
- _getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
+ _getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
```
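`map` over a `Tuple` returns a tuple of the same length and is friendlier to type inference than splatting a generator, which has to be materialized first.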
src/data/dataloader.jl (Outdated)
```diff
- ids = 1:min(nx, batchsize)
- DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
+ imax = partial ? n : n - batchsize + 1
+ ids = 1:min(n, batchsize)
```
Remove this line?
reinstated the Manifest and removed the changes to
Manifest.toml (Outdated)
```diff
@@ -384,4 +384,4 @@ version = "0.4.20"
 deps = ["MacroTools"]
 git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
 uuid = "700de1a5-db45-46bc-99cf-38207098b444"
-version = "0.2.0"
+version = "0.2.0"
\ No newline at end of file
```
Trailing newline is missing here
src/optimise/train.jl (Outdated)
```diff
@@ -121,4 +121,4 @@ macro epochs(n, ex)
     @info "Epoch $i"
     $(esc(ex))
   end)
-end
+end
\ No newline at end of file
```
also here
Would like a quick cleanup of the newlines (otherwise it just adds a little noise to PRs when people have editor settings that automatically fix this stuff) but otherwise happy with the API changes.
bors r+
Build succeeded: |
cfr. discussion in #1149. Currently the DataLoader interface supports

1. `for x in DataLoader(X)`
2. `for (x, y) in DataLoader(X, Y)`

This PR adds

3. `for (x,) in DataLoader((X,))`
4. `for (x, y) in DataLoader((X, Y))`

Edit: the constructor in 2. is removed in this PR.
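A short sketch of the extended interface (sizes are illustrative; observations are assumed to lie along the last dimension, with the module path and keyword names of this era of Flux):

```julia
using Flux.Data: DataLoader

X = rand(Float32, 10, 100)  # 100 observations of 10 features
Y = rand(Float32, 100)      # 100 matching targets

# Form 1: a single array yields plain batches.
for x in DataLoader(X, batchsize = 16)
    @assert size(x, 2) <= 16  # last batch may be partial
end

# Form 4 (this PR): a tuple of arrays yields tuples of matching batches.
for (x, y) in DataLoader((X, Y), batchsize = 16, shuffle = true)
    @assert size(x, 2) == length(y)
end
```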