Skip to content

Commit

Permalink
use inner constructor for Bilinear, more like Dense
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Feb 13, 2021
1 parent 644ec8c commit 374a976
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
19 changes: 10 additions & 9 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ end
Bilinear(W::AbstractArray, [bias, σ])
Creates a Bilinear layer, which operates on two inputs at the same time.
It its output, given vectors `x`, `y` is another vector `z` with,
Its output, given vectors `x` & `y`, is another vector `z` with,
for all `i ∈ 1:out`:
z[i] = σ(x' * W[i,:,:] * y + bias[i])
Expand Down Expand Up @@ -325,21 +325,22 @@ julia> sc(x) |> size
(3, 32)
```
"""
struct Bilinear{A,B,S}
struct Bilinear{F,A,B}
weight::A
bias::B
σ::S
σ::F
function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
b = create_bias(W, bias, size(W,1))
new{F,A,typeof(b)}(W, b, σ)
end
end

@functor Bilinear

Bilinear(weight::AbstractArray, bias = true) = Bilinear(weight, create_bias(weight, bias, size(weight,1)), identity)

function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
init = glorot_uniform, bias = true)
W = init(out, in1, in2)
b = create_bias(W, bias, out)
return Bilinear(W, b, σ)
init = glorot_uniform, bias = true)
Bilinear(init(out, in1, in2), bias, σ)
end

function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
Expand Down
7 changes: 6 additions & 1 deletion test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,17 @@ import Flux: activations
b2 = Flux.Bilinear(randn(3,4,5), false)
@test b2.bias == Flux.Zeros()

b3 = Flux.Bilinear(randn(3,4,5), true, tanh)
b3 = Flux.Bilinear(randn(Float16, 3,4,5), true, tanh)
@test b3.σ == tanh
@test b2.bias isa Vector{Float16}
@test size(b3(rand(4), rand(5))) == (3,)

b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros)
@test b4.bias isa Vector{Float32}

@test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array
@test_throws ArgumentError Flux.Bilinear(rand(3,4), false, tanh)
@test_throws DimensionMismatch Flux.Bilinear(rand(3,4,5), rand(6), tanh) # wrong length bias
end
end

Expand Down

0 comments on commit 374a976

Please sign in to comment.