Skip to content

Commit a74f86d

Browse files
committed
use foldl less often
1 parent 058d59d commit a74f86d

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

src/layers/basic.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,17 @@ end
4343

4444
functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
4545

46-
(c::Chain)(x) = foldl((y,f) -> f(y), (x, c.layers...))
46+
function (c::Chain)(x)
47+
if order() < 2
48+
foldl((y,f) -> f(y), (x, c.layers...))
49+
else
50+
# This hand-written foldl causes high latency
51+
applychain(Tuple(c.layers), x)
52+
end
53+
end
54+
55+
applychain(::Tuple{}, x) = x
56+
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
4757

4858
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
4959
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =

src/utils.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,3 +890,31 @@ function plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)
890890

891891
return patience(is_plateau, width)
892892
end
893+
894+
895+
"""
896+
order()
897+
898+
Returns `1` inside a call to `Zygote.gradient`, `2` inside nested such calls.
899+
900+
# Examples
901+
```jldoctest; setup = :(using Flux, Zygote)
902+
julia> Flux.order()
903+
0
904+
905+
julia> gradient(x -> (@show(Flux.order()); x^3), 1)
906+
Flux.order() = 1
907+
(3.0,)
908+
909+
julia> gradient(y -> gradient(x -> (@show(Flux.order()); x^3), y)[1], 1)
910+
Flux.order() = 2
911+
(6.0,)
912+
913+
julia> Zygote.hessian(x -> (@show(Flux.order()); x^3), 1) # uses ForwardDiff over Zygote
914+
Flux.order() = 1
915+
6
916+
```
917+
"""
918+
order(::Val{n} = Val(0)) where {n} = n
919+
920+
Zygote.@adjoint order(::Val{n}) where {n} = order(Val(n+1)), Returns(nothing)

0 commit comments

Comments
 (0)