
Commit eceaf3e

fix merge conflicts

MJ10 committed Feb 5, 2019
1 parent fc0304a
Showing 5 changed files with 40 additions and 6 deletions.
src/layers/conv.jl (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.

 Data should be stored in WHCN order. In other words, a 100×100 RGB image would
-be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
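The docstring fix makes the WHCN convention explicit: even a single image carries a trailing batch dimension, so one 100×100 RGB image is a four-dimensional `100×100×3×1` array. A minimal sketch of what that implies (layer shape and sizes invented for illustration):

    using Flux

    # 2×2 filter, 3 input channels (RGB) => 16 output channels
    c = Conv((2, 2), 3 => 16, relu)

    img   = rand(Float32, 100, 100, 3, 1)   # one image: width × height × channels × batch
    batch = rand(Float32, 100, 100, 3, 50)  # a batch of 50 images, same layout

    size(c(img))    # (99, 99, 16, 1) with no padding and stride 1
    size(c(batch))  # (99, 99, 16, 50)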
src/layers/recurrent.jl (8 changes: 4 additions & 4 deletions)

@@ -84,7 +84,7 @@ end
 RNNCell(in::Integer, out::Integer, σ = tanh;
         init = glorot_uniform) =
   RNNCell(σ, param(init(out, in)), param(init(out, out)),
-          param(zeros(out)), param(init(out)))
+          param(init(out)), param(zeros(out)))

 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@@ -122,8 +122,8 @@ end

 function LSTMCell(in::Integer, out::Integer;
                   init = glorot_uniform)
-  cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
-                  param(init(out)), param(init(out)))
+  cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
+                  param(zeros(out)), param(zeros(out)))
   cell.b.data[gate(out, 2)] .= 1
   return cell
 end
@@ -169,7 +169,7 @@ end

 GRUCell(in, out; init = glorot_uniform) =
   GRUCell(param(init(out*3, in)), param(init(out*3, out)),
-          param(zeros(out*3)), param(init(out)))
+          param(init(out*3)), param(zeros(out)))

 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
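In all three cells the fix swaps which constructor arguments are randomly initialised and which are zeroed: the trainable bias now receives `init(...)` while the initial hidden state (and, for the LSTM, the cell state) starts at zero. A hedged sketch for `RNNCell`, whose struct definition lies outside this diff (field order inferred from the constructor):

    rnn = RNNCell(10, 5)         # Wi is 5×10, Wh is 5×5
    # after this commit: rnn.b is random (glorot_uniform), rnn.h is all zeros
    h, y = rnn(rnn.h, rand(10))  # one step from the zero initial state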
src/layers/stateless.jl (6 changes: 6 additions & 0 deletions)

@@ -44,8 +44,14 @@ logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
 Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns.
 """
+<<<<<<< HEAD
+function normalise(x::AbstractArray, dims::Int=1)
+  μ′ = mean(x, dims = dims)
+  σ′ = std(x, dims = dims, mean = μ′)
+=======
 function normalise(x::AbstractVecOrMat)
   μ′ = mean(x, dims = 1)
   σ′ = std(x, dims = 1, mean = μ′, corrected=false)
+>>>>>>> 940b1e6dbfd95ed3f64daf0a75edae713d25e7c1
   return (x .- μ′) ./ σ′
 end
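Note that this hunk adds the conflict markers themselves: as committed, `src/layers/stateless.jl` contains an unresolved `<<<<<<< HEAD` / `=======` / `>>>>>>>` block and will not parse until a follow-up resolves it. One plausible resolution, keeping the `dims` argument from HEAD and `corrected=false` from the incoming branch (a sketch, not what this commit does):

    using Statistics: mean, std

    function normalise(x::AbstractArray, dims::Int = 1)
      μ′ = mean(x, dims = dims)
      σ′ = std(x, dims = dims, mean = μ′, corrected = false)
      return (x .- μ′) ./ σ′
    end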
src/onehot.jl (23 changes: 22 additions & 1 deletion)

@@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)

 Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix

+Base.getindex(xs::OneHotVector, ::Colon) = xs
+
 A::AbstractMatrix * b::OneHotVector = A[:, b.ix]

 struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
@@ -22,6 +24,21 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
 Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
 Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])

+Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j]
+
+Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs
+
+# handle special case for when we want the entire column
+function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray}
+  res = similar(xs, size(xs, 1), 1)
+  if length(ot) == size(xs, 1)
+    res = xs[:,i]
+  else
+    res = xs[1:length(ot),i]
+  end
+  res
+end
+
 A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]

 Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
@@ -54,13 +71,17 @@ end
 onehotbatch(ls, labels, unk...) =
   OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

+Base.argmax(xs::OneHotVector) = xs.ix
+
 onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]

 onecold(y::AbstractMatrix, labels...) =
   dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)

+onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data)
+
 function argmax(xs...)
-  Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
+  Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
   return onecold(xs...)
 end
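A hedged usage sketch of the new indexing and `argmax`/`onecold` methods (labels chosen arbitrarily):

    using Flux: onehot, onehotbatch, onecold

    v = onehot(:b, [:a, :b, :c])   # 3-element OneHotVector with v.ix == 2
    v[:] === v                     # true: the new Colon method returns the vector itself
    argmax(v)                      # 2, read straight off v.ix by the new Base.argmax method

    m = onehotbatch([:b, :a, :c], [:a, :b, :c])  # 3×3 OneHotMatrix
    m[:, :] === m                  # true: a full slice is now a no-op
    m[1:2, 3]                      # Bool entries from rows 1:2 of column 3
    onecold(m, [:a, :b, :c])       # [:b, :a, :c], via the new OneHotMatrix method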
test/cuda/cuda.jl (7 changes: 7 additions & 0 deletions)

@@ -38,6 +38,13 @@ Flux.back!(sum(l))

 end

+@testset "onecold gpu" begin
+  x = zeros(Float32, 10, 3) |> gpu;
+  y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
+  res = Flux.onecold(x) .== Flux.onecold(y)
+  @test res isa CuArray
+end
+
 if CuArrays.libcudnn != nothing
   @info "Testing Flux/CUDNN"
   include("cudnn.jl")
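A hedged reading of what the new test asserts, with the values worked out by hand: `x` is all zeros, so `Flux.onecold(x)` picks the first maximum of each column, giving `[1, 1, 1]`; `y` one-hot encodes `ones(3)` against the labels `1:10`, so `Flux.onecold(y)` is also `[1, 1, 1]`. The comparison is then three `true`s, and the point of the `@test` is that the result stays on the GPU as a `CuArray` rather than falling back to a CPU `Array`.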
