
Issue with CRF loss function #1087

Open · opus111 opened this issue Mar 17, 2020 · 13 comments

@opus111 commented Mar 17, 2020

Here is a file that reproduces the problem. The code is copied from the TextAnalysis package and slightly altered for Flux 0.10; the version on GitHub works with Flux 0.9.

```julia
#=
This code is copied from the CRF implementation in TextAnalysis.

The current version on GitHub works with Flux 0.9:

https://github.com/JuliaText/TextAnalysis.jl/tree/master/src/CRF
=#

using Flux

log_sum_exp(z) = log_sum_exp(z, maximum(z, dims = 1))
log_sum_exp(z, m) = log.(sum(exp.(z .- m), dims = 1)) .+ m

mutable struct CRF{S}
    W::S   # Transition scores
    n::Int # Number of labels
end

# n + 2 rows/columns account for the added start and stop labels.
function CRF(n::Integer)
    W = rand(Float32, n + 2, n + 2)
    W[:, n + 1] .= -10000 # No transitions into the start label
    W[n + 2, :] .= -10000 # No transitions out of the stop label
    return CRF(W, n)
end

Flux.@functor CRF (W,)

preds_first(c::CRF, y) = c.W[c.n + 1, Flux.onecold(y, 1:length(y))]
preds_last(c::CRF, y) = c.W[Flux.onecold(y, 1:length(y)), c.n + 2]
preds_single(c::CRF, y, y_prev) = c.W[Flux.onecold(y_prev, 1:length(y_prev)), Flux.onecold(y, 1:length(y))]

# Log partition function, computed with the forward algorithm.
function forward_score(c::CRF, x, init_α)
    forward_var = log_sum_exp((c.W .+ transpose(x[1])) .+ init_α)
    for i in 2:length(x)
        forward_var = log_sum_exp((c.W .+ transpose(x[i])) .+ transpose(forward_var))
    end
    fs = log_sum_exp(c.W[:, c.n + 2] + transpose(forward_var))
    return fs[1]
end

# Unnormalized log-score of one particular label sequence.
function score_sequence(c::CRF, x, label_seq)
    score = preds_first(c, label_seq[1]) + Flux.onecold(label_seq[1], x[1])
    for i in 2:length(label_seq)
        score += preds_single(c, label_seq[i], label_seq[i-1]) +
                 Flux.onecold(label_seq[i], x[i])
    end
    return score + preds_last(c, label_seq[end])
end

# Negative log-likelihood: log Z(x) minus the score of the gold sequence.
crf_loss(c::CRF, x, label_seq, init_α) = forward_score(c, x, init_α) - score_sequence(c, x, label_seq)

label_count = 10
seq_length = 5
crf = CRF(label_count - 2)
init_α = fill(-10000.0, label_count)
init_α[label_count - 1] = 0.0
label_seq = [Flux.onehot(i, 1:label_count) for i in 1:seq_length]
x = [rand(label_count) for _ in 1:seq_length]
print("crf_loss=$(crf_loss(crf, x, label_seq, init_α))")
# The gradient call below triggers the DimensionMismatch error reported in the next comment.
print("gradient(crf_loss)=$(gradient(() -> crf_loss(crf, x, label_seq, init_α)))")
```

@DhairyaLGandhi (Member)

Oops.

Could you please add the stacktrace you see as well? It'd make it easier to spot where the issue is.

@opus111 (Author) commented Mar 17, 2020

```
crf_loss=11.085941941312553
ERROR: LoadError: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
[1] check_broadcast_shape(::Tuple{}, ::Tuple{Base.OneTo{Int64}}) at ./broadcast.jl:506
[2] check_broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}) at ./broadcast.jl:509
[3] check_broadcast_axes at ./broadcast.jl:511 [inlined]
[4] check_broadcast_axes at ./broadcast.jl:515 [inlined]
[5] instantiate at ./broadcast.jl:259 [inlined]
[6] materialize! at ./broadcast.jl:822 [inlined]
[7] (::Zygote.var"#1023#1025"{Array{Float32,2},Tuple{Colon,Int64}})(::Array{Float64,2}) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/lib/array.jl:47
[8] (::Zygote.var"#2707#back#1019"{Zygote.var"#1023#1025"{Array{Float32,2},Tuple{Colon,Int64}}})(::Array{Float64,2}) at /Users/peter.wolf/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[9] forward_score at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:29 [inlined]
[10] (::typeof(∂(forward_score)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
[11] crf_loss at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:42 [inlined]
[12] (::typeof(∂(crf_loss)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
[13] #65 at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53 [inlined]
[14] (::typeof(∂(#65)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
[15] (::Zygote.var"#38#39"{typeof(∂(#65))})(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface.jl:36
[16] gradient(::Function) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface.jl:45
[17] top-level scope at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53
[18] include at ./boot.jl:328 [inlined]
[19] include_relative(::Module, ::String) at ./loading.jl:1105
[20] include(::Module, ::String) at ./Base.jl:31
[21] include(::String) at ./client.jl:424
[22] top-level scope at none:0
in expression starting at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53
```

@mcabbott (Member) commented Mar 17, 2020

So `/src/lib/array.jl:47` is `∇getindex`; adding some `@show` statements to that, here's what it gets before crashing:

```
typeof(x) = Array{Float32,2}
size(x) = (10, 10)
typeof(dy) = Array{Float64,2}
size(dy) = (10, 1)
inds = (Colon(), 10)
size(dxv) = (10,)
```

You can make the error go away by defining `Zygote._droplike(dy::AbstractMatrix, dxv::AbstractVector) = vec(dy)`; this function is from FluxML/Zygote.jl#499. But I'm not certain that's a good idea in general, and it might be worth understanding what causes this.
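For reference, a minimal sketch of that workaround, applied before taking the gradient (note that `_droplike` is an internal Zygote function, not public API):

```julia
using Zygote

# Workaround sketch from the comment above: let a matrix cotangent be
# flattened to match a vector view. Internal API; may mask other shape bugs.
Zygote._droplike(dy::AbstractMatrix, dxv::AbstractVector) = vec(dy)
```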

(Note also, as an aside, that the gradient has a different element type (`Float64`) from the `Float32` parameters, which is a common performance bug; ref #1031.)
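A minimal sketch of avoiding that element-type mismatch in the repro script (an assumption on my part: the `Float64`s come from `rand(label_count)` and `fill(-10000.0, ...)` defaulting to `Float64`):

```julia
# Keep the inputs and the initial forward variable in Float32 to match c.W.
x = [rand(Float32, label_count) for _ in 1:seq_length]
init_α = fill(-10000f0, label_count)
init_α[label_count - 1] = 0f0
```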

@opus111 (Author) commented Mar 17, 2020

Thank you for responding so quickly :-D

Yes, adding `_droplike` does seem to make the issue go away, at least in my short test. I'm working on a bigger test now.

Please let me know if I can be helpful in any way

@opus111 (Author) commented Mar 17, 2020

@mcabbott sorry, but I don't think the fix works. It does run without complaint (and quickly).
However, in my application I do not get models with the same performance. In fact, they are terrible, so I don't think Zygote is producing the correct answer. Is there a good way to compare the results of Flux 0.9 and Flux 0.10?
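One way to check, as a sketch (using FiniteDifferences.jl as an independent reference is an assumption here, not something from this thread), is to compare Zygote's gradient against finite differences on the same loss:

```julia
using Zygote, FiniteDifferences

f(W) = sum(abs2, W)        # hypothetical stand-in for the CRF loss as a function of c.W
W = rand(Float32, 10, 10)

g_ad = Zygote.gradient(f, W)[1]
g_fd = FiniteDifferences.grad(central_fdm(5, 1), f, W)[1]
@assert isapprox(g_ad, g_fd; rtol = 1e-3)
```

If the two agree here but the Flux 0.9 and 0.10 models still differ, the discrepancy is more likely in the training setup than in the gradient itself.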

@opus111 (Author) commented Mar 17, 2020

I am going to create some known trivial input for the CRF training and compare the output for Flux 0.9 and 0.10. With the same hyperparameters the results should be pretty similar. Will report back.
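A sketch of what "identical starting conditions" might look like (assumptions: seeding the global RNG before building the model, and that the RNG streams match across the two Julia/package environments):

```julia
using Random, Flux

Random.seed!(42)                                # same seed under Flux 0.9 and 0.10
crf = CRF(label_count - 2)                      # identical initial c.W in both runs
x = [rand(label_count) for _ in 1:seq_length]   # identical inputs too
```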

@mcabbott (Member)

No smarter ideas, that sounds like the right course.

@opus111 (Author) commented Mar 23, 2020

Good morning Michael. I am pleased to report that when trained with identical starting conditions, the CRF exactly matches the results of Flux 0.9/Tracker. Unfortunately, I am building a CRF/LSTM model, and when the CRF is placed on top of another layer it does not train properly. It runs, but does not produce good models. Since I am new to Flux 0.10, I assume it is my bug, and Tomas Pevny has offered to look at my code. However, is it possible that your suggested fix could affect lower layers in a DNN?
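A hedged diagnostic for that question (the names `loss`, `ps`, and `lstm` are taken from the snippet in the later comment): check whether any gradient actually flows into the layers below the CRF.

```julia
# xs, ys: one training example from the dataset.
# Inspect the implicit-parameter gradients for the lower LSTM layer.
grads = gradient(() -> loss(xs, ys), ps)
for p in params(lstm)
    g = grads[p]
    println(g === nothing ? "no gradient reached this parameter" :
            "gradient norm = $(sqrt(sum(abs2, g)))")
end
```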

opus111 closed this as completed Mar 23, 2020
@opus111 (Author) commented Mar 31, 2020

Hello Michael, happy April Fool's. Unfortunately, this is not a joke. I now have two versions of the CRF on top of other layers that behave the same way: the CRF works on its own, but the lower layers do not train properly. As in the first example, the weights change, but the loss does not decrease. The second example is the CRF test from TextAnalysis.jl ported to Flux 0.10. Here is the code. Let me know if you want the full branch of the TextAnalysis port to Flux 0.10.

```julia
# num_labels, num_features, X, and Y are defined elsewhere in the test.
LSTM_STATE_SIZE = 5
d_out = Dense(LSTM_STATE_SIZE, num_labels + 2)
lstm = LSTM(num_features, LSTM_STATE_SIZE)
m(x) = d_out.(lstm.(x))

c = CRF(num_labels)
init_α = fill(-10000, (stop(c), 1))
init_α[start(c)] = 0

loss(xs, ys) = crf_loss(c, m(xs), ys, init_α)

opt = Descent(0.01)
data = zip(X, Y)
ps = params(lstm, d_out, c)

function train()
    for d in data
        reset!(lstm)
        grads = gradient(() -> loss(d[1], d[2]), ps)
        Flux.Optimise.update!(opt, ps, grads)
    end
end

function find_loss(d)
    reset!(lstm)
    loss(d[1], d[2])
end

l1 = sum([find_loss(d) for d in data])
dense_param_1 = deepcopy(d_out.W)
lstm_param_1 = deepcopy(lstm.cell.Wh)
crf_param_1 = deepcopy(c.W)

for i in 1:10
    train()
end

dense_param_2 = deepcopy(d_out.W)
lstm_param_2 = deepcopy(lstm.cell.Wh)
crf_param_2 = deepcopy(c.W)
l2 = sum([find_loss(d) for d in data])

# The loss should decrease, and every layer's weights should change.
@test l1 > l2
@test dense_param_1 != dense_param_2
@test lstm_param_1 != lstm_param_2
@test crf_param_1 != crf_param_2
```

@opus111 (Author) commented Mar 31, 2020

Reopening...

@darsnack (Member)

@opus111 can you check if this issue has been resolved on master?

@darsnack (Member) commented Dec 7, 2020

Bump @opus111, I think you said this can be closed?

@opus111 (Author) commented Dec 8, 2020 via email
