
Restructure on Dense no longer plays nicely with alternative types #1556

Closed
@ChrisRackauckas

Description

MWE:

using ReverseDiff, Flux
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p1,re = Flux.destructure(ann)
function f(p)
    re(p[1:41])([1.0,2.0])
end

ReverseDiff.gradient(f, p1)

This throws:
MethodError: ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}(::ForwardDiff.Dual{ForwardDiff.Tag{ReverseDiff.var"#105#107"{DataType, Tuple{}, Val{(1,)}}, Float32}, Float32, 1}) is ambiguous. Candidates:
  (T::Type{var"#s31"} where var"#s31"<:Real)(x::ForwardDiff.Dual) in Tracker at C:\Users\accou\.julia\packages\Tracker\YNNTM\src\lib\real.jl:111
  ReverseDiff.TrackedReal{V, D, O}(value) where {V, D, O} in ReverseDiff at C:\Users\accou\.julia\packages\ReverseDiff\iHmB4\src\tracked.jl:56
Possible fix, define
  ReverseDiff.TrackedReal{V, D, O}(::ForwardDiff.Dual) where {V, D, O}
macro expansion at broadcast.jl:126 [inlined]
splatcall at broadcast.jl:111 [inlined]
(::ReverseDiff.var"#105#107"{DataType, Tuple{}, Val{(1,)}})(s::StaticArrays.SVector{1, ForwardDiff.Dual{ForwardDiff.Tag{ReverseDiff.var"#105#107"{DataType, Tuple{}, Val{(1,)}}, Float32}, Float32, 1}}) at broadcast.jl:150
static_dual_eval at apiutils.jl:32 [inlined]
vector_mode_gradient! at gradient.jl:125 [inlined]
gradient! at gradient.jl:48 [inlined]
(::ReverseDiff.var"#df#106"{DataType, DiffResults.ImmutableDiffResult{1, Float32, Tuple{StaticArrays.SVector{1, Float32}}}, Tuple{}, Val{(1,)}})(x::Float32) at broadcast.jl:148
_broadcast_getindex_evalf at broadcast.jl:648 [inlined]
_broadcast_getindex at broadcast.jl:621 [inlined]
getindex at broadcast.jl:575 [inlined]
copy at broadcast.jl:922 [inlined]
materialize at broadcast.jl:883 [inlined]
broadcast(f::ReverseDiff.var"#df#106"{DataType, DiffResults.ImmutableDiffResult{1, Float32, Tuple{StaticArrays.SVector{1, Float32}}}, Tuple{}, Val{(1,)}}, As::Vector{Float32}) at broadcast.jl:821
∇broadcast at broadcast.jl:154 [inlined]
copy(_bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Tuple{Base.OneTo{Int64}}, Type{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}}, Tuple{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}) at broadcast.jl:94
materialize(bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, Type{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}}, Tuple{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}) at broadcast.jl:883
broadcast at broadcast.jl:821 [inlined]
create_bias(weights::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, bias::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, dims::Int64) at utils.jl:397
Dense(W::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, bias::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, σ::typeof(tanh)) at basic.jl:117
(::Flux.var"#126#127")(y::NamedTuple{(:weight, :bias, :σ), Tuple{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, typeof(tanh)}}) at functor.jl:18
fmap1(f::Function, x::Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}) at functor.jl:49
fmap(f::Flux.var"#54#55"{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}, x::Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}; exclude::typeof(Functors.isleaf), cache::IdDict{Any, Any}) at functor.jl:56

The stack trace is a bit scrambled and gets cut off, but the error originates in Flux's create_bias (utils.jl:397 in the trace):

function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
  size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
  if eltype(bias) == eltype(weights)
    return bias
  else
    @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
    return broadcast(eltype(weights), bias)
  end
end
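
The problem is that broadcast(eltype(weights), bias) calls the element type as a constructor on each element. Inside ReverseDiff's broadcast machinery the elements arrive as ForwardDiff.Duals, and two loaded packages each supply a matching constructor method (Tracker's blanket (T::Type{<:Real})(x::ForwardDiff.Dual) at real.jl:111 and ReverseDiff's TrackedReal{V,D,O}(value) at tracked.jl:56), so dispatch is ambiguous, which is exactly the MethodError above. Here is a minimal sketch of the same ambiguity with stand-in types (Dual2 and Tracked2 are hypothetical, not the real packages' types):

# Stand-ins for ForwardDiff.Dual and ReverseDiff.TrackedReal:
struct Dual2 <: Real
    x::Float64
end

struct Tracked2{V<:Real} <: Real
    value::V  # the default Tracked2{V}(value) constructor plays the role of tracked.jl:56
end

# A blanket constructor in the spirit of Tracker's real.jl:111:
(::Type{T})(d::Dual2) where {T<:Real} = T(d.x)

Tracked2{Float64}.([Dual2(1.0)])
# MethodError: Tracked2{Float64}(::Dual2) is ambiguous: neither the blanket
# method nor the default constructor is more specific, just like the trace above.

# convert dispatches separately from construction, so one convert method
# resolves the call unambiguously:
Base.convert(::Type{Tracked2{V}}, d::Dual2) where {V} = Tracked2{V}(V(d.x))
convert.(Tracked2{Float64}, [Dual2(1.0)])  # 1-element Vector{Tracked2{Float64}}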

Calling a type as a function is generally not the right way to convert element types; convert is the interface packages actually extend for this. A better way to handle it is:

function Flux.create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
  size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
  if eltype(bias) == eltype(weights)
    return bias
  else
    @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
    return convert.((eltype(weights),), bias)
  end
end
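
With that method evaluated (it overrides Flux's own create_bias, which is fine for testing the hypothesis), the broadcast goes through convert dispatch instead of constructor dispatch, so re-running the MWE should get past the ambiguity, though it still emits the conversion warning:

ReverseDiff.gradient(f, p1)  # should no longer hit the constructor ambiguity above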

However, that's only a partial fix: restructuring with tracked parameters is a legitimate use case and shouldn't warn at all, so the bigger question is why execution reaches that branch in the first place (the relevant frames are cut off from the stack trace).
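
For what it's worth, the types in the trace hint at why the eltype check fails at all (an inference from the trace, not verified against ReverseDiff's source): TrackedReal carries the originating array's type as its last parameter, so the weight matrix and the bias vector report different eltypes even though both wrap Float32:

eltype(weights)  # TrackedReal{Float32, Float32, TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}
eltype(bias)     # TrackedReal{Float32, Float32, TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}
# These differ, so create_bias takes the warn-and-convert branch.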
