Skip to content

ZygoteAD dot operator fails again #1595

@Red-Portal

Description

@Red-Portal

Hi, issue #1207 is active again.

Here's the code in question

# Bayesian logistic regression model used as the reproducer for this issue.
# Arguments, as used in the body:
#   X — design matrix, multiplied by β
#   y — observed responses, one per element of `s`
#   n — not referenced anywhere in the model body
#   d — dimension passed to the MvNormal prior on β
#   σ — scale shared by the priors on β and α
Turing.@model logistic_regression(X, y, n, d, σ) = begin
    # Prior on the coefficient vector (appears as ZeroMeanIsoNormal in the stack trace).
    β  ~ MvNormal(d, σ)
    # Prior on the scalar intercept.
    α  ~ Normal(0, σ)
    # Linear predictor (logits): the intercept is broadcast over X*β.
    s  = X*β .+ α
    # Broadcasted observe statement; per the stack trace this goes through
    # dot_tilde_observe -> dot_observe, which is where Zygote reports
    # "try/catch is not supported".
    y .~ Turing.BernoulliLogit.(s)
end

# Driver script: builds the model on a dataset and samples it with NUTS using
# the Zygote AD backend, which triggers the error reported below.
begin 
           # NOTE(review): `fetch_dataset` is not defined in this snippet; presumably
           # it returns a feature matrix and a response vector — confirm with reporter.
           data_x, data_y = fetch_dataset(:pima)
           # Second dimension of data_x is passed as `d` (size of β).
           n_dims = size(data_x, 2)
           # First dimension of data_x is passed as `n` (unused by the model body).
           n_data = size(data_x, 1)
           model  = logistic_regression(data_x,
                                        data_y,
                                        n_data, 
                                        n_dims,
                                        1.0)
           # Switch Turing's AD backend to Zygote (shows up as Turing.Core.ZygoteAD
           # in the stack trace).
           Turing.setadbackend(:zygote)
           # NOTE(review): presumably 100 adaptation steps and a 0.8 target
           # acceptance rate, per Turing's NUTS(n_adapts, δ) — confirm against docs.
           spl  = Turing.NUTS(100, 0.8; max_depth=8)
           # Request 200 samples; the failure occurs during gradient evaluation.
           Turing.sample(model, spl, 200) 
end

Here are the error messages

Compiling Tuple{typeof(DynamicPPL.dot_observe), DynamicPPL.SampleFromPrior, Vector{Turing.BinomialLogit{Float64, Float64}}, Vector{Float64}, DynamicPPL.ThreadSafeVarInfo{DynamicPPL.TypedVarInfo{NamedTuple{(:β, :α), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:β, Tuple{}}, Int64}, Vector{ZeroMeanIsoNormal{Tuple{Base.OneTo{Int64}}}}, Vector{AbstractPPL.VarName{:β, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:α, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:α, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Vector{Base.RefValue{Float64}}}}: try/catch is not supported.
Stacktrace:
  [1] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/reverse.jl:121
  [2] #Primal#20
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/reverse.jl:202 [inlined]
  [3] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/reverse.jl:315
  [4] _lookup_grad(T::Type)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/emit.jl:101
  [5] #s2996#1179
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:34 [inlined]
  [6] var"#s2996#1179"(T::Any, j::Any, Δ::Any)
    @ Zygote ./none:0
  [7] Pullback
    @ ~/.julia/packages/Turing/rHLGJ/src/inference/hmc.jl:526 [inlined]
  [8] (::typeof(∂(dot_observe)))(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/DynamicPPL/8Oe1A/src/context_implementations.jl:428 [inlined]
 [10] Pullback
    @ ~/.julia/packages/DynamicPPL/8Oe1A/src/context_implementations.jl:386 [inlined]
 [11] (::typeof(∂(dot_tilde)))(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/DynamicPPL/8Oe1A/src/context_implementations.jl:408 [inlined]
 [13] (::typeof(∂(dot_tilde_observe)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/Projects/KLpqVI.jl/scripts/task/logistic.jl:6 [inlined]
 [15] (::typeof(∂(#1948)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [16] macro expansion
    @ ~/.julia/packages/DynamicPPL/8Oe1A/src/model.jl:0 [inlined]
 [17] Pullback
    @ ~/.julia/packages/DynamicPPL/8Oe1A/src/model.jl:154 [inlined]
 [18] (::typeof(∂(_evaluate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [19] Pullback
    @ ~/.julia/packages/DynamicPPL/8Oe1A/src/model.jl:144 [inlined]
 [20] (::typeof(∂(evaluate_threadsafe)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [21] Pullback
    @ ~/.julia/packages/DynamicPPL/8Oe1A/src/model.jl:94 [inlined]
 [22] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [23] #180
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194 [inlined]
 [24] #1689#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [25] Pullback
    @ ~/.julia/packages/DynamicPPL/8Oe1A/src/model.jl:98 [inlined]
 [26] Pullback
    @ ~/.julia/packages/Turing/rHLGJ/src/core/ad.jl:165 [inlined]
 [27] (::typeof(∂(λ)))(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [28] (::Zygote.var"#41#42"{typeof(∂(λ))})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [29] gradient_logp(backend::Turing.Core.ZygoteAD, θ::Vector{Float64}, vi::DynamicPPL.TypedVarInfo{NamedTuple{(:β, :α), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:β, Tuple{}}, Int64}, Vector{ZeroMeanIsoNormal{Tuple{Base.OneTo{Int64}}}}, Vector{AbstractPPL.VarName{:β, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:α, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:α, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, model::DynamicPPL.Model{var"#1948#1949", (:X, :y, :n, :d, :σ), (), (), Tuple{Matrix{Float64}, Vector{Float64}, Int64, Int64, Float64}, Tuple{}}, sampler::DynamicPPL.Sampler{Turing.Inference.NUTS{Turing.Core.ZygoteAD, (), AdvancedHMC.DiagEuclideanMetric}}, context::DynamicPPL.DefaultContext)
    @ Turing.Core ~/.julia/packages/Turing/rHLGJ/src/core/ad.jl:171

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions