Skip to content

Commit

Permalink
make Dense(x) prettier
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Feb 13, 2021
1 parent 5a27ff5 commit ae879cc
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ end
@functor Dense

function (a::Dense)(x::AbstractArray)
W, b, σ = getfield(a, :weight), getfield(a, :bias), getfield(a, )
W, b, σ = a.weight, a.bias, a.σ
sz = size(x)
x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
x = σ.(W*x .+ b)
return reshape(x, :, sz[2:end]...)
y = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
z = σ.(W*y .+ b)
return reshape(z, :, sz[2:end]...)
end

function Base.show(io::IO, l::Dense)
Expand Down

0 comments on commit ae879cc

Please sign in to comment.