
Adding support for folding RNNs over 3d arrays #1686

Merged: 15 commits, Sep 14, 2021

Conversation

@mkschleg (Contributor) commented Aug 2, 2021

From #1678, adding a Recur-like interface for a folded operation over 3-dimensional arrays. This is how many users familiar with PyTorch and TensorFlow expect RNNs to work, and there seems to be some desire for this feature per the discussion in #1671 and from @jeremiedb. It will also streamline the push towards supporting the CuDNN versions of RNNs/GRUs/LSTMs, since this is the data layout that API expects.

I did a barebones implementation to add support so we can start iterating on the API.

There are several questions that I have lingering with this interface:

  • Should we support different modes where we return all or only the last hidden state? Is there a better way to do the concat of the hidden states?
  • What kind of tests should we have? Just follow what we currently do for RNNs/LSTMs/GRUs?
  • For the CPU version, does it make sense not to specialize on the different RNN types? We might be able to take more advantage of BLAS if we specialized on, say, Folded{GRU}.
  • Do we want to force the temporal dimension to be the 2nd?
  • Do we want this to be stateful? (i.e. allow the user to change what the starting hidden state is rather than state0).

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • API changes require approval from a committer (different from the author, if applicable)

@mkschleg (Contributor, author):
I've reimplemented using mapslices for now. As @ToucheSir mentioned, mapslices is not guaranteed to iterate left-to-right in the future, so an alternative would likely be better. I've left it for now since the implementation is a single function of three lines; it looks quite nice.
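Roughly, the mapslices version looks something like this (a sketch with a hypothetical name; the committed code may differ):

using Flux

# Sketch of a mapslices-based fold over the 3rd (time) dimension.
# Note: mapslices does not guarantee left-to-right iteration order, which
# matters for a stateful Recur and is why this was later replaced.
function fold_mapslices(m::Flux.Recur, x::AbstractArray{T, 3}) where T
    mapslices(xt -> m(xt), x; dims=(1, 2))
end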

@jeremiedb (Contributor):
Regarding the dimensions, my understanding is that the default for fused RNN operations (e.g. CUDNN) is to operate on input laid out as [features, batch_size, seq_length]. See for example Knet (same with other frameworks). As such, I'd suggest defaulting to that behavior and assuming the time dimension is the 3rd.
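For example (illustrative only):

# features × batch × time: 1024 features, batch of 32, 16 timesteps
x = rand(Float32, 1024, 32, 16)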

mkschleg and others added 2 commits August 17, 2021 09:29
Code review suggestions

Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
@mkschleg (Contributor, author):
I agree with moving the time index to the 3rd dimension.

@mkschleg (Contributor, author) commented Aug 21, 2021

Was able to play with this again.

We can get rid of the splat from Flux.stack by using reduce or foldl over hcat. Unfortunately, foldl doesn't take advantage of reduce's optimizations for hcat and vcat. But maybe reduce here is fine? I wasn't sure whether associativity would matter for hcat (i.e. if reduce regroups the calls differently in the future, will reduce(hcat, x) produce the arrays in a different order?).

You can see the results below for CPU and GPU. Overall the runtimes and memory consumption follow reduce < stack < foldl, so reduce is the best.

The one optimization that could help even more is if there were a way to reduce directly into the final matrix instead of having to cat. I'm not sure whether that kind of reduction operator exists and has an adjoint written for it, but I would love to know so I can test it. I guess mapslices exists for this, but it doesn't work with GPUs, doesn't appear to have an adjoint written for it, and isn't guaranteed to preserve left-to-right ordering (as per the above).

setup

using Flux, BenchmarkTools

function recur_stack(m::Flux.Recur, x::AbstractArray{T, 3}) where T
    h = [m(x[:, :, i]) for i in 1:size(x, 3)]   # run the RNN over each timestep slice
    Flux.stack(h, 3)                            # splat-based concatenation along dim 3
end

function recur_foldl(m::Flux.Recur, x::AbstractArray{T, 3}) where T
    h = [m(x[:, :, i]) for i in 1:size(x, 3)]
    sze = size(h[1])
    reshape(foldl(hcat, h), sze[1], sze[2], length(h))   # pairwise hcat, no reduce specialization
end

function recur_reduce(m::Flux.Recur, x::AbstractArray{T, 3}) where T
    h = [m(x[:, :, i]) for i in 1:size(x, 3)]
    sze = size(h[1])
    reshape(reduce(hcat, h), sze[1], sze[2], length(h))  # uses Base's optimized reduce(hcat, ...)
end

function get_grad_stack(rnn, x)
    grads = gradient(Flux.params(rnn)) do 
        sum(recur_stack(rnn, x))
    end
end
# get_grad_foldl and get_grad_reduce are defined analogously

rnn = RNN(1024, 512)
rnn_gpu = RNN(1024, 512) |> gpu

x = cat([rand(Float32, 1024, 32) for t in 1:16]...; dims=3);
x_gpu = cat([rand(Float32, 1024, 32) for t in 1:16]...; dims=3) |> gpu;

Forward pass CPU

@benchmark begin; Flux.reset!(rnn); recur_stack(rnn, x); end
BenchmarkTools.Trial: 698 samples with 1 evaluation.
 Range (min … max):  6.909 ms …   9.340 ms  ┊ GC (min … max): 0.00% … 24.63%
 Time  (median):     6.949 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.158 ms ± 662.046 μs  ┊ GC (mean ± σ):  2.82% ±  6.99%
 Memory estimate: 5.96 MiB, allocs estimate: 378.

@benchmark begin; Flux.reset!(rnn); recur_foldl(rnn, x); end
BenchmarkTools.Trial: 424 samples with 1 evaluation.
 Range (min … max):  11.058 ms … 17.287 ms  ┊ GC (min … max): 0.00% … 18.60%
 Time  (median):     11.127 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   11.789 ms ±  1.346 ms  ┊ GC (mean ± σ):  5.39% ±  8.89%
 Memory estimate: 13.38 MiB, allocs estimate: 160.

@benchmark begin; Flux.reset!(rnn); recur_reduce(rnn, x); end
BenchmarkTools.Trial: 715 samples with 1 evaluation.
 Range (min … max):  6.745 ms …   9.305 ms  ┊ GC (min … max): 0.00% … 24.83%
 Time  (median):     6.784 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   6.989 ms ± 644.995 μs  ┊ GC (mean ± σ):  2.83% ±  6.98%
 Memory estimate: 5.94 MiB, allocs estimate: 132.

Gradient CPU

julia> @benchmark begin; Flux.reset!(rnn); get_grad_stack(rnn, x); end
BenchmarkTools.Trial: 107 samples with 1 evaluation.
 Range (min … max):  44.002 ms … 50.351 ms  ┊ GC (min … max):  7.44% … 18.70%
 Time  (median):     47.295 ms              ┊ GC (median):    13.60%
 Time  (mean ± σ):   46.977 ms ±  1.217 ms  ┊ GC (mean ± σ):  12.79% ±  2.37%
 Memory estimate: 168.14 MiB, allocs estimate: 5231.

@benchmark begin; Flux.reset!(rnn); get_grad_foldl(rnn, x); end
BenchmarkTools.Trial: 141 samples with 1 evaluation.
 Range (min … max):  35.397 ms …  37.066 ms  ┊ GC (min … max): 13.22% … 12.84%
 Time  (median):     35.511 ms               ┊ GC (median):    13.43%
 Time  (mean ± σ):   35.544 ms ± 200.277 μs  ┊ GC (mean ± σ):  13.41% ±  0.08%
 Memory estimate: 175.61 MiB, allocs estimate: 6127.

julia> @benchmark begin; Flux.reset!(rnn); get_grad_reduce(rnn, x); end
BenchmarkTools.Trial: 143 samples with 1 evaluation.
 Range (min … max):  32.100 ms … 40.431 ms  ┊ GC (min … max):  7.43% … 11.60%
 Time  (median):     34.985 ms              ┊ GC (median):    13.44%
 Time  (mean ± σ):   35.152 ms ±  1.503 ms  ┊ GC (mean ± σ):  12.99% ±  1.94%
 Memory estimate: 168.10 MiB, allocs estimate: 4483.

GPUs

@benchmark begin; Flux.reset!(rnn_gpu); recur_stack(rnn_gpu, x_gpu); end
BenchmarkTools.Trial: 5328 samples with 1 evaluation.
 Range (min … max):  793.573 μs … 23.339 ms  ┊ GC (min … max): 0.00% … 31.66%
 Time  (median):     811.077 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   934.020 μs ±  1.641 ms  ┊ GC (mean ± σ):  4.30% ±  2.35%
 Memory estimate: 158.67 KiB, allocs estimate: 3602.

@benchmark begin; Flux.reset!(rnn_gpu); recur_foldl(rnn_gpu, x_gpu); end
BenchmarkTools.Trial: 4919 samples with 1 evaluation.
 Range (min … max):  844.082 μs … 13.373 ms  ┊ GC (min … max): 0.00% … 35.85%
 Time  (median):     859.915 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.013 ms ±  1.351 ms  ┊ GC (mean ± σ):  5.88% ±  4.07%
 Memory estimate: 160.36 KiB, allocs estimate: 3821.

@benchmark begin; Flux.reset!(rnn_gpu); recur_reduce(rnn_gpu, x_gpu); end
BenchmarkTools.Trial: 5580 samples with 1 evaluation.
 Range (min … max):  762.537 μs … 22.104 ms  ┊ GC (min … max): 0.00% … 31.36%
 Time  (median):     777.677 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   892.435 μs ±  1.527 ms  ┊ GC (mean ± σ):  4.13% ±  2.32%
 Memory estimate: 141.77 KiB, allocs estimate: 3275.

GPU Gradients

julia> @benchmark begin; Flux.reset!(rnn_gpu); get_grad_stack(rnn_gpu, x_gpu); end
BenchmarkTools.Trial: 1070 samples with 1 evaluation.
 Range (min … max):  3.710 ms … 9.461 ms  ┊ GC (min … max):  0.00% … 33.69%
 Time  (median):     3.782 ms             ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.670 ms ± 1.989 ms  ┊ GC (mean ± σ):  10.64% ± 12.24%
 Memory estimate: 536.33 KiB, allocs estimate: 13396.

@benchmark begin; Flux.reset!(rnn_gpu); get_grad_foldl(rnn_gpu, x_gpu); end
BenchmarkTools.Trial: 974 samples with 1 evaluation.
 Range (min … max):  4.079 ms … 10.506 ms  ┊ GC (min … max):  0.00% … 32.17%
 Time  (median):     4.142 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   5.129 ms ±  2.108 ms  ┊ GC (mean ± σ):  10.86% ± 12.39%
 Memory estimate: 584.59 KiB, allocs estimate: 15051.

@benchmark begin; Flux.reset!(rnn_gpu); get_grad_reduce(rnn_gpu, x_gpu); end
BenchmarkTools.Trial: 1095 samples with 1 evaluation.
 Range (min … max):  3.630 ms … 9.804 ms  ┊ GC (min … max):  0.00% … 33.95%
 Time  (median):     3.687 ms             ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.564 ms ± 1.942 ms  ┊ GC (mean ± σ):  10.68% ± 12.29%
 Memory estimate: 492.34 KiB, allocs estimate: 12522.

@mkschleg (Contributor, author) commented Aug 22, 2021

I believe we can get away with using reduce: it preserves the left-to-right order of the elements, so all we need is for hcat to be associative (the grouping of the calls doesn't change the result), i.e.

julia> x1, x2, x3 = rand(10), rand(10), rand(10);
julia> all(hcat(hcat(x1, x2), x3) .== hcat(x1, hcat(x2, x3)))
true

@ToucheSir (Member):
> The one optimization that could help even more is if there is a way to reduce directly into the final matrix instead of having to cat.

reduce([hv]cat) should do this already, implementation in https://github.com/JuliaLang/julia/blob/master/base/abstractarray.jl#L1562. Order of evaluation shouldn't matter here since the RNN has already been run through the inputs in order via the comprehension above.

If we can assume that the network will output the same sized output at each timestep, then another possibility is constructing a Buffer of shape (outputsize..., timesteps) and writing to it in a loop. I haven't tested the performance of this method, however.
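Something along these lines (an untested sketch; recur_buffer is a hypothetical name, matching the label used in the benchmarks further down):

using Flux, Zygote

function recur_buffer(m::Flux.Recur, x::AbstractArray{T, 3}) where T
    y1 = m(view(x, :, :, 1))                          # first step fixes the output size
    out = Zygote.Buffer(y1, size(y1)..., size(x, 3))  # shape (outputsize..., timesteps)
    out[:, :, 1] = y1
    for i in 2:size(x, 3)
        out[:, :, i] = m(view(x, :, :, i))
    end
    copy(out)                                         # freeze the Buffer so gradients can flow
end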

@mkschleg (Contributor, author):
I just tested that, and it is a bit slower than the stack version (on a different computer, so I'm not going to put up results as they would be confusing). I think the reduce would be the best option with the current implementation.

@mkschleg (Contributor, author):
Given the above experimentation, I've changed the implementation to a reduce plus a reshape, both because it doesn't splat (so it should only compile once even if the sequence length changes) and because it has better memory and compute-time properties.

Aside: I wonder if we should also change the Flux.stack implementation to use reduce, but that is another issue.
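E.g., something like this (a sketch for stacking along a new trailing dimension only, not Flux's implementation):

# Splat-free stack along a new trailing dimension
stack_reduce(xs::AbstractVector{<:AbstractArray}) =
    reshape(reduce(hcat, vec.(xs)), size(xs[1])..., length(xs))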

@mkschleg (Contributor, author):
I'll start working on tests + docs now. I also want to make sure the gradients are consistent.

@mkschleg (Contributor, author) commented Aug 24, 2021

I just added tests for BPTT and size checking (on CPU), and added a GPU test following the one-hot encoding test (I still need to test this a bit more). I also added documentation for the folding, following on from what was there previously.

Finally, while digging through the tests, I noticed that the RNN tests use broadcasting instead of the newer for-loop practice.

@ToucheSir (Member) left a comment:
Thanks again for your work on this! I won't have an opportunity to do a full review until next week, but if someone else gets to it beforehand I think this is good to go.

mkschleg and others added 2 commits August 24, 2021 22:12
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
@CarloLucibello (Member):
looks good! add a news entry?

function (m::Recur)(x::AbstractArray{T, 3}) where T
  h = [m(x[:, :, i]) for i in 1:size(x, 3)]
  sze = size(h[1])
  reshape(reduce(hcat, h), sze[1], sze[2], length(h))
Member:
I think this would create a bunch of intermediate arrays. Could we also check the benchmarks with larger batch sizes and models?

Contributor (author):
What would you want to compare against? The Zygote.buffer version? And what are we looking for? Memory issues?

Member:
The only additional intermediate generated here is the vector h holding the outputs. From my own benchmarking with Buffer, this doesn't appear to add much, if any, memory overhead, and it is faster to compute and doesn't depend on Zygote out of the box.

@ToucheSir (Member), Aug 30, 2021:
To clarify, the only way we can further reduce intermediate allocations is to make the RNNCell itself operate in-place. Since that's not possible at the moment, we eat an extra one for each output regardless. An extra 8 bytes for each additional timestep is insignificant in comparison, especially since Buffer usage appears to induce extra allocations of its own.

Member:
Versus the stack implementation for longer sequences. And yes, I want to understand the memory impact in the GPU case specifically. Also, would it make sense to avoid the copy by passing a view of x?

@ToucheSir (Member), Aug 30, 2021:
I forgot to mention it, but I was using view and it definitely helps! Here are the benchmarks from that:

Sequence length stress test: RNN(1, 1), batch size 1, sequence length 1024:

CPU:
f = recur_stack
  18.194 ms (17966 allocations: 15.17 MiB)
  68.377 ms (302842 allocations: 39.03 MiB)
f = recur_foldl
  620.649 μs (4100 allocations: 2.43 MiB)
  73.480 ms (358217 allocations: 78.00 MiB)
f = recur_reduce
  346.497 μs (3078 allocations: 300.42 KiB)
  47.053 ms (259184 allocations: 22.33 MiB)
f = recur_buffer
  341.854 μs (3075 allocations: 292.19 KiB)
  49.909 ms (279707 allocations: 20.87 MiB)

GPU:
f = recur_stack
  51.815 ms (161334 allocations: 22.73 MiB)
  213.182 ms (852717 allocations: 55.73 MiB)
f = recur_foldl
  38.556 ms (182209 allocations: 8.91 MiB)
  221.603 ms (964286 allocations: 60.72 MiB)
f = recur_reduce
  32.980 ms (142349 allocations: 7.51 MiB)
  190.225 ms (799857 allocations: 38.89 MiB)
f = recur_buffer
  36.472 ms (143886 allocations: 7.48 MiB)
  219.450 ms (836290 allocations: 38.26 MiB)

A more realistic scenario: RNN(32, 32), batch size 64, sequence length 128:

CPU:
f = recur_stack
  3.319 ms (2077 allocations: 4.23 MiB)
  41.513 ms (38493 allocations: 269.27 MiB)
f = recur_foldl
  7.595 ms (642 allocations: 67.54 MiB)
  48.827 ms (45359 allocations: 333.60 MiB)
f = recur_reduce
  2.766 ms (391 allocations: 4.04 MiB)
  37.872 ms (33394 allocations: 268.86 MiB)
f = recur_buffer
  2.866 ms (388 allocations: 4.04 MiB)
  41.226 ms (35456 allocations: 270.68 MiB)

GPU:
f = recur_stack
  3.959 ms (20388 allocations: 1.18 MiB)
  26.319 ms (116239 allocations: 5.66 MiB)
f = recur_foldl
  4.394 ms (23233 allocations: 1.12 MiB)
  28.567 ms (130096 allocations: 6.20 MiB)
f = recur_reduce
  3.701 ms (18189 allocations: 967.55 KiB)
  25.282 ms (109855 allocations: 5.22 MiB)
f = recur_buffer
  4.042 ms (18187 allocations: 960.36 KiB)
  27.730 ms (110834 allocations: 5.05 MiB)

reduce(hcat) is the fastest in every configuration, and at worst only uses marginally more memory than Buffer. I suspect the pathological behaviour of the stack forward pass on longer sequences has something to do with the splat.
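The view-based variant looks roughly like this (a sketch; same setup as the earlier benchmark code):

using Flux

function recur_reduce(m::Flux.Recur, x::AbstractArray{T, 3}) where T
    h = [m(view(x, :, :, i)) for i in 1:size(x, 3)]   # view avoids copying each timestep slice
    sze = size(h[1])
    reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end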

Contributor (author):
Thanks @ToucheSir! I got really busy so wasn't able to work on this. I think the reduce(hcat) + view is the best option right now.

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
@DhairyaLGandhi (Member):
I'd be curious to hear @mkschleg's view on whether we should be more conservative about choosing what a specific dimension means in this way for future models.

Also ditto for @sdobber

@@ -67,6 +81,12 @@ end

flip(f, xs) = reverse(f.(reverse(xs)))

function (m::Recur)(x::AbstractArray{T, 3}) where T
Member:
Should 3 => N here, such that we only consider the last dimension as time, leaving modellers to decide what the other dimensions mean for different tasks? It shouldn't be a very difficult change, and it would help generalise our approach, similar to how conv works.
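i.e. roughly something like this (illustrative sketch only, with a hypothetical helper name):

# Treat the last dimension of any N-dimensional input as time
function fold_over_time(m, x::AbstractArray{T, N}) where {T, N}
    h = [m(copy(selectdim(x, N, i))) for i in 1:size(x, N)]
    reshape(reduce(hcat, vec.(h)), size(h[1])..., length(h))
end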

Member:
That dispatch would be ideal, but the problem is that it would override (m::Recur)(x::AbstractArray{T, 1}) and (m::Recur)(x::AbstractArray{T, 2}), i.e. the single timestep cases. Even (m::Recur)(x::AbstractArray{T, 3}) can be suspect if Recur is wrapping a model that expects 3+1D inputs (e.g. Recur(Chain(Conv((3,), ...), GlobalMeanPool(), RNNCell(...))). That scenario is less common, however, as I think users would be more likely to write Chain(Conv((3,), ...), GlobalMeanPool(), RNN(...)).

Member:
AbstractVector/Matrix dispatches can take care of that. The usage question is orthogonal; however, I'd imagine having the ability to write both models is a more pressing concern.

Member:
My hope was to avoid having 3 dispatches with the exact same implementation (even if it's in a helper function). As you mentioned, the longer-term solution is to add in more appropriate functionality for this use case (i.e. not trying to overload Recur even more).

Member:
We need just one kernel and can reshape the inputs as necessary.

Member:
We need at least the following:

  1. (m::Recur)(x) for non-array inputs
  2. (m::Recur)(x::AbstractArray{T, N}) for this PR, which overrides 1. for sequences.
  3. (m::Recur)(x::AbstractArray{T, 1}) and (m::Recur)(x::AbstractArray{T, 2}), which use the same code as 1. but exist solely to undo the override in 2. for individual sequence elements + the current array-per-timestep pattern.

Now to be fair, this is only 2x2 lines of duplicated code at worst for case 3. However, it does show the perils of trying to guess what the user wants instead of creating more explicit and appropriate abstractions like we've discussed in #1678. For example, these dispatches will not work if someone wants to apply a Convolutional LSTM across individual timesteps.

Just to be clear, I do not want to hold up this PR. Quite the contrary: this is a great stopgap that should allow us to start prototyping efficient CuDNN forwarding while working on #1678. Embracing that, and not trying to expand the scope beyond said forwarding, is advantageous because it means users won't hit unavoidable edge cases that only exist because Recur is trying to handle 2 conflicting usage patterns.

Contributor (author):
I think no matter what we do, we will need a more complicated dispatch structure for Recur to support different assumptions about inputs (e.g. ConvLSTMs) alongside the folding version. But maybe we can do this by dispatching on Recur{ConvLSTM} and being clear in the documentation about how this works. That would be a little unfortunate, but the amount of code duplication is minimal and the default is pretty clear. I don't see how to design past this while having Recur be responsible for both folding and per-step application.

I think we can assume the final index is the temporal index (which seems reasonable given how arrays are stored in Julia). I agree with @ToucheSir that this is a stop-gap and we can use it to iterate.
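With that assumption, usage would look something like this (illustrative sketch):

using Flux
m = RNN(8, 16)
x = rand(Float32, 8, 32, 10)   # features × batch × time
y = m(x)                       # size(y) == (16, 32, 10): hidden state at every timestep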

@mkschleg (Contributor, author) commented Sep 5, 2021

> I'd be curious to hear @mkschleg's view on whether we should be more conservative about choosing what a specific dimension means in this way for future models.
>
> Also ditto for @sdobber

I think we need to be very conservative, but there may be certain assumptions we have to make. Unfortunately, I think what we are doing here might break some assumptions in the batch norm implementation. So we likely need to come up with strategies across Flux for what different dimensions may mean and how we can navigate that space. That means we probably need a design document of some kind. This could likely be solved with the TimeDistributed layer in #1678.

@mkschleg (Contributor, author) commented Sep 5, 2021

I think this might highlight some issues with Chain. Sometimes I feel we want different behavior based on which layers are in a chain and what the input structure is. I'm not sure this is something we can solve, and we should probably just push people towards building their own model structs in these instances.

Just as an example: in my research, I sometimes find myself searching through the chain to find certain layers so I can do a dynamic dispatch. I'm not sure there is a way around this, given that I need different input structures and different behavior in my updates depending on which layers are in my chain (I do weird things, though, so maybe not a good example).

@ToucheSir (Member):
There's always the option of allowing the user to specify which dimensions are important up-front. That of course would have to be the responsibility of a layer other than Recur, so I don't think we need to work it out for this PR.

@mkschleg (Contributor, author) commented Sep 6, 2021

Sounds good. Are there any lingering questions w/ this PR? I think it would be beneficial to get this through so we can start working on CuDNN paths.

@jeremiedb (Contributor):
Just bumping to get this PR merged.
I think it would be preferable to have a new feature that might be imperfect (from the API perspective; I don't mean compromising on correctness of the results) but that will be exposed to actual users' experience and feedback, rather than trying to anticipate the perfect fit.

@ToucheSir (Member) commented Sep 14, 2021

Sorry, I didn't realize this had already received an approval. Let's get it merged. Edit: well, this is embarrassing; I think I mixed Flux up with another repo where I have write access. If nobody gets to it in 24 hrs, I'd say ping someone who does.

@DhairyaLGandhi (Member):
bors r+

@bors bot (Contributor) commented Sep 14, 2021

Build succeeded:

@bors bors bot merged commit f1dbc97 into FluxML:master Sep 14, 2021