Bilinear: build the bias and validate the weight rank in an inner constructor.

The inner constructor throws an ArgumentError unless the weight array has
3 dimensions, and always routes the bias through `create_bias`. The tests
cover a Float16 weight array and the constructor's error paths.

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 5ab86902f8..5e5f52dfd5 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -287,7 +287,7 @@ end
     Bilinear(W::AbstractArray, [bias, σ])
 
 Creates a Bilinear layer, which operates on two inputs at the same time.
-It its output, given vectors `x`, `y` is another vector `z` with,
+Its output, given vectors `x` & `y`, is another vector `z` with,
 for all `i ∈ 1:out`:
 
     z[i] = σ(x' * W[i,:,:] * y + bias[i])
@@ -325,21 +325,22 @@ julia> sc(x) |> size
 (3, 32)
 ```
 """
-struct Bilinear{A,B,S}
+struct Bilinear{F,A,B}
   weight::A
   bias::B
-  σ::S
+  σ::F
+  function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
+    ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
+    b = create_bias(W, bias, size(W,1))
+    new{F,A,typeof(b)}(W, b, σ)
+  end
 end
 
 @functor Bilinear
 
-Bilinear(weight::AbstractArray, bias = true) = Bilinear(weight, create_bias(weight, bias, size(weight,1)), identity)
-
 function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
-                  init = glorot_uniform, bias = true)
-  W = init(out, in1, in2)
-  b = create_bias(W, bias, out)
-  return Bilinear(W, b, σ)
+                  init = glorot_uniform, bias = true)
+  Bilinear(init(out, in1, in2), bias, σ)
 end
 
 function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 7047e095bc..d76056806a 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -161,12 +161,17 @@ import Flux: activations
       b2 = Flux.Bilinear(randn(3,4,5), false)
       @test b2.bias == Flux.Zeros()
 
-      b3 = Flux.Bilinear(randn(3,4,5), true, tanh)
+      b3 = Flux.Bilinear(randn(Float16, 3,4,5), true, tanh)
       @test b3.σ == tanh
+      @test b3.bias isa Vector{Float16}
       @test size(b3(rand(4), rand(5))) == (3,)
 
       b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros)
       @test b4.bias isa Vector{Float32}
+
+      @test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array
+      @test_throws ArgumentError Flux.Bilinear(rand(3,4), false, tanh)
+      @test_throws DimensionMismatch Flux.Bilinear(rand(3,4,5), rand(6), tanh) # wrong length bias
     end
   end
 