diff --git a/src/CuArrays.jl b/src/CuArrays.jl index 06bb01be..654c9934 100644 --- a/src/CuArrays.jl +++ b/src/CuArrays.jl @@ -4,7 +4,7 @@ using CUDAapi, CUDAdrv, CUDAnative using GPUArrays -export CuArray, CuVector, CuMatrix, CuVecOrMat, cu +export CuArray, CuVector, CuMatrix, CuVecOrMat, CuIterator, cu export CUBLAS, CUSPARSE, CUSOLVER, CUFFT, CURAND, CUDNN, CUTENSOR import LinearAlgebra @@ -93,6 +93,7 @@ include("mapreduce.jl") include("accumulate.jl") include("linalg.jl") include("nnlib.jl") +include("iterator.jl") include("deprecated.jl") diff --git a/src/iterator.jl b/src/iterator.jl new file mode 100644 index 00000000..40209d88 --- /dev/null +++ b/src/iterator.jl @@ -0,0 +1,27 @@ +""" + CuIterator(batches) + +Return a `CuIterator` that can iterate through the provided `batches` via `Base.iterate`. + +Upon each iteration, the current `batch` is adapted to the GPU (via `map(x -> adapt(CuArray, x), batch)`) +and the previous iteration is marked as freeable from GPU memory (via `unsafe_free!`). + +This abstraction is useful for batching data into GPU memory in a manner that +allows old iterations to potentially be freed (or marked as reusable) earlier +than they otherwise would via CuArray's internal polling mechanism. +""" +mutable struct CuIterator{B} + batches::B + previous::Any + CuIterator(batches) = new{typeof(batches)}(batches) +end + +function Base.iterate(c::CuIterator, state...) + item = iterate(c.batches, state...) + isdefined(c, :previous) && foreach(unsafe_free!, c.previous) + item === nothing && return nothing + batch, next_state = item + cubatch = map(x -> adapt(CuArray, x), batch) + c.previous = cubatch + return cubatch, next_state +end diff --git a/test/iterator.jl b/test/iterator.jl new file mode 100644 index 00000000..e35222dd --- /dev/null +++ b/test/iterator.jl @@ -0,0 +1,16 @@ +@testset "CuIterator" begin + batch_count = 10 + max_batch_items = 3 + max_ndims = 3 + sizes = 20:50 + rand_shape = () -> rand(sizes, rand(1:max_ndims)) + batches = [[rand(Float32, rand_shape()...) for _ in 1:rand(1:max_batch_items)] for _ in 1:batch_count] + cubatches = CuIterator(batch for batch in batches) # ensure generators are accepted + previous_cubatch = missing + for (batch, cubatch) in zip(batches, cubatches) + @test ismissing(previous_cubatch) || all(x -> x.freed, previous_cubatch) + @test batch == Array.(cubatch) + @test all(x -> x isa CuArray, cubatch) + previous_cubatch = cubatch + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ec4455fd..ccfdd6b7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,6 +60,7 @@ include("solver.jl") include("sparse_solver.jl") include("dnn.jl") include("tensor.jl") +include("iterator.jl") include("forwarddiff.jl") include("nnlib.jl")