Merge #1759
1759: Make unsqueeze type stable r=CarloLucibello a=cossio

This PR makes `Flux.unsqueeze` type stable and improves its performance.

Closes #1737.

Please see the linked issue for a comparison.
I also added some tests.
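A minimal sketch of how such a comparison can be reproduced (assumes Flux and BenchmarkTools are installed; timings and allocation counts will vary by machine and Flux version):

```julia
using Flux, Test, BenchmarkTools
using InteractiveUtils  # for @code_warntype outside the REPL

x = randn(Float32, 2, 3, 2)

# Type stability: @inferred errors if the actual return type differs
# from the return type predicted by inference.
@inferred Flux.unsqueeze(x, 2)

# Inspect inference directly; after this change the result should be a
# concrete Array{Float32, 4} rather than an abstractly typed Array.
@code_warntype Flux.unsqueeze(x, 2)

# Rough performance check; allocations drop once the call infers.
@btime Flux.unsqueeze($x, 2);
```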

Co-authored-by: cossio <j.cossio.diaz@gmail.com>
bors[bot] and cossio authored Oct 31, 2021
2 parents 69afb67 + 78dd3f6 commit ea26f45
Showing 2 changed files with 31 additions and 19 deletions.
13 changes: 8 additions & 5 deletions src/utils.jl
@@ -285,7 +285,7 @@ sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_i
"""
identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0)
Return an `Array` of size `dims` which yields an identity mapping when used as parameters in
most Flux layers. Use `gain` to scale the identity by a constant.
Often useful in the context of transfer learning, i.e when one wants to add more capacity to
@@ -297,10 +297,10 @@ Equivalent to `Base.circshift(identity(dims...), shift)`.
Some caveats: Not all layers will be identity mapping when used with this init. Exceptions
include recurrent layers, `DepthwiseConv` and normalization layers.
Also note that layers must have `input_size == output_size` for identity mapping to be
possible. When this is not the case, extra dimensions of the array are padded with zeros.
For convolutional layers, in addition to the above, the kernel sizes must also be odd and
padding must be applied so that output feature maps have the same size as input feature maps,
e.g by using [`SamePad`](@ref).
@@ -420,7 +420,10 @@ julia> Flux.unsqueeze(xs, 1)
[1, 2] [3, 4] [5, 6]
```
"""
-unsqueeze(xs::AbstractArray, dim::Integer) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+function unsqueeze(xs::AbstractArray, dim::Integer)
+  sz = ntuple(i -> i < dim ? size(xs, i) : i == dim ? 1 : size(xs, i - 1), ndims(xs) + 1)
+  return reshape(xs, sz)
+end

"""
unsqueeze(dim)
@@ -574,7 +577,7 @@ See also [`unstack`](@ref).
# Examples
```jldoctest
julia> Flux.unbatch([1 3 5 7;
2 4 6 8])
4-element Vector{Vector{Int64}}:
[1, 2]
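The gain in inference comes from how the output shape is built: the old one-liner splatted `size(xs)[1:dim-1]` and `size(xs)[dim:end]`, slices whose lengths the compiler cannot determine, so the dimensionality of the reshaped array could not be inferred; the new version asks `ntuple` for exactly `ndims(xs) + 1` entries, a length that is known from the array's type. A minimal sketch of the same idea outside Flux (`unsqueeze_shape` is a hypothetical helper introduced here only for illustration):

```julia
using Test

# The tuple length, ndims(xs) + 1, is known from the array's type, so the
# shape is a fixed-length tuple of Ints and reshape returns a concretely
# typed array.
unsqueeze_shape(xs::AbstractArray, dim::Integer) =
    ntuple(i -> i < dim ? size(xs, i) : i == dim ? 1 : size(xs, i - 1), ndims(xs) + 1)

x = randn(2, 3, 2)
unsqueeze_shape(x, 2)                        # (2, 1, 3, 2)
@inferred reshape(x, unsqueeze_shape(x, 2))  # concrete Array{Float64, 4}
```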
37 changes: 23 additions & 14 deletions test/utils.jl
@@ -1,11 +1,20 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
-            kaiming_normal, kaiming_uniform, orthogonal,
-            sparse_init, stack, unstack, Zeros, batch, unbatch
+            kaiming_normal, kaiming_uniform, orthogonal,
+            sparse_init, stack, unstack, Zeros, batch, unbatch,
+            unsqueeze
using StatsBase: var, std
using Random
using Test

@testset "unsqueeze" begin
x = randn(2, 3, 2)
@test @inferred(unsqueeze(x, 1)) == reshape(x, 1, 2, 3, 2)
@test @inferred(unsqueeze(x, 2)) == reshape(x, 2, 1, 3, 2)
@test @inferred(unsqueeze(x, 3)) == reshape(x, 2, 3, 1, 2)
@test @inferred(unsqueeze(x, 4)) == reshape(x, 2, 3, 2, 1)
end
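For reference, `@inferred` (from the Test standard library) evaluates the call and throws if the concrete type of the returned value differs from the return type predicted by inference, so the testset above fails whenever `unsqueeze` stops being type stable. A tiny standalone illustration with made-up functions, not taken from Flux:

```julia
using Test

stable(x) = x + 1
unstable(x) = x > 0 ? x : 0.0   # inferred as Union{Float64, Int64} for Int input

@inferred stable(1)                                 # passes: Int64 inferred and returned
@test_throws ErrorException @inferred(unstable(1))  # @inferred errors on the Union
```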

@testset "Throttle" begin
@testset "default behaviour" begin
a = []
@@ -178,10 +187,10 @@ end

@testset "$layer ID mapping with kernelsize $kernelsize" for layer in (Conv, ConvTranspose, CrossCor), kernelsize in (
(1,),
(3,),
(1, 3),
(3, 5),
(3, 5, 7))
nch = 3
l = layer(kernelsize, nch=>nch, init=identity_init, pad=SamePad())

@@ -333,9 +342,9 @@ end


@testset "Batching" begin
stacked_array=[ 8 9 3 5
9 6 6 9
9 1 7 2
7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@test unbatch(stacked_array) == unstacked_array
@@ -445,7 +454,7 @@ end

modules = Flux.modules(Chain(SkipConnection(
Conv((2,3), 4=>5; pad=6, stride=7),
+),
LayerNorm(8)))
@test length(modules) == 5
end
@@ -475,16 +484,16 @@ end
@testset "early stopping" begin
@testset "args & kwargs" begin
es = Flux.early_stopping((x; y = 1) -> x + y, 10; min_dist=3)

n_iter = 0
while n_iter < 99
es(-n_iter; y=-n_iter) && break
n_iter += 1
end

@test n_iter == 9
end

@testset "distance" begin
es = Flux.early_stopping(identity, 10; distance=(best_score, score) -> score - best_score)

@@ -496,7 +505,7 @@ end

@test n_iter == 99
end

@testset "init_score" begin
es = Flux.early_stopping(identity, 10; init_score=10)

