
Consistency in the type behavior of restructure #95

Open
@ChrisRackauckas


This was discovered in SciML/NeuralPDE.jl#533 as an issue that only showed itself as an incorrect gradient: the primal pass of what was being trained was in Float64, the reverse pass gave a Float64, the loss function printout gave a Float64, and everything looked fine, except that the Flux neural network was mysteriously "a bit more janky", in that it had a much higher probability of failing CI tests for a reason nobody could figure out for 5 months. Finally it was discovered that parts of the gradient were calculated in Float32 because the Flux.Chain had Float32 parameters inside it. This showcased that re(p) does not "always" respect the types of p.

But it doesn't "always" respect the types of the Flux.Chain either. For example, for a standard Flux.Chain of Dense layers with Float32 parameters, you get (a minimal script for checking the Float64 case is sketched after the lists below):

  • re(p::Vector{Float64}) computes in Float32
  • re(p::CuVector{Float32}) computes on the GPU in Float32
  • re(p::Vector{Dual}) computes with Dual numbers
  • re(p::Vector{ComplexF32}) computes with Float32

And now let's have some fun:

  • re(p::CuVector{Float64}) computes ???. My guess is CuVector{Float32}?
  • re(p::ReverseDiff.TrackedArray) computes ??? My guess is Array{TrackedReal{Float32}}?
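
For concreteness, here is a minimal sketch (not from the linked issue; layer sizes and input are made up) of checking the Float64 case with Flux's destructure:

```julia
using Flux  # Flux.destructure (implemented in Optimisers.jl)

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))  # Float32 parameters by default
p32, re = Flux.destructure(model)                  # p32 :: Vector{Float32}

p64 = Float64.(p32)          # promote the flat parameter vector to Float64
m64 = re(p64)                # reconstruct from Float64 parameters...
@show eltype(m64[1].weight)  # ...and, per the first bullet above, get Float32 weights back

x = rand(Float32, 2)
@show eltype(m64(x))         # so the forward pass computes in Float32
```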

I understand that this isn't intended behavior and comes from some quirks of ProjectTo, which exposes some (IMO odd) behaviors of a ChainRules internal to users who are likely not experts in the autodiff system.

Now, the problem I have with this is that discovering the behavior is rather hard, because if you do anything other than the simplest "just use the neural network", almost no case will expose to the user that this behavior exists. For example,

  • (p[end] .* re(p))::typeof(p)
  • (p[end] .+ re(p))::typeof(p)
  • ...

all hold in the examples I described, because the type demotion is countered by the type promotion applied by essentially any other computation that involves quantities with eltype(p). Thus, unless re(p) is the only operation used (in which case you probably don't need restructure/destructure at all), some other operation in the primal will mask the demotion and your forward pass will look like it computed using typeof(p). It will only present itself to a user in the gradient pass.
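
Continuing the sketch above (same made-up model, parameters, and input), the masking looks like:

```julia
inner  = re(p64)(x)         # internally computed in Float32
masked = p64[end] .* inner  # one Float64 factor promotes the visible result back
@show eltype(inner) eltype(masked)  # Float32 vs Float64: the primal "looks" like Float64
```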

Thus I understand @mcabbott's reasoning behind saying it's not a gradient correctness issue (since it's correctly calculating the gradients of the object that is actually reconstructed), but I have now traced many different cases that I thought were just "Flux janky behavior" and "I don't know why FastChain works here but Flux.Chain doesn't" back to this same behavior. It may not be a gradient correctness issue, but in the downstream libraries where I have found it, it only presents itself as one: it only really exposes itself if you dig into a seemingly incorrect gradient, and if it quacks like 🦆?

I understand that this behavior is now documented, but I'm not sure a behavior that presents itself like that is sufficiently handled by documentation alone, because it's hard to even figure out that something is going wrong without investigating the gradient calculation.

What could be done?

I would propose that we should just make the behavior undeniably straightforward and consistent. Either always make re(p) compute using values of typeof(p), or make it so it always computes using the values from the original Flux.Chain. Either choice is an easily explainable and predictable behavior. This middle ground is not easy to explain or predict.

Always matching p is the more predictable behavior in the Julia ecosystem. If you stick a complex number in as the initial condition of an ODE solver, as the initial guess for a value in Optim, as the starting point for IterativeSolvers or NLsolve, etc., any generic code that I can think of will carry out the computation in the type that p provides. In many cases generic codes will just error if they can't handle it, but they try to compute using p. Non-generic codes immediately throw method errors describing what the allowed inputs are. I cannot think of another example in the Julia ecosystem where the "computation type" of f(p) matches neither p nor a fixed type, but instead matches the internal types of the fields of f, and even then only sometimes, matching p other times.

If it always matches the Flux.Chain, at least that would be clearly visible, since when you call it on a CuArray you get an Array back, and you think: oh, I see how this works, if I want the GPU then I |> gpu the chain, because it doesn't convert to p. Got it. With the current behavior, you see that re(p) works on the GPU, so okay, why not just do re(p::Array{Float64}) as a quick way to convert to Float64? And if you think like that, you get burned.

The other behavior could be to throw an error in any case where a type conversion is necessary. If you want re(p::Array{Float64}) to work, go back and |> f64 the neural network. Now, this will cause some issues with making libraries work, but it's a nice (overly) safe option that would ensure there are no surprises.
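
A hypothetical sketch of what that strict option could look like (strict_restructure is a made-up name, not an existing Optimisers.jl function):

```julia
using Flux

function strict_restructure(model, p)
    p0, re = Flux.destructure(model)
    eltype(p) == eltype(p0) || throw(ArgumentError(
        "got parameters with eltype $(eltype(p)) but the model stores $(eltype(p0)); " *
        "convert the model (e.g. `|> f64`) or the parameters first"))
    return re(p)
end
```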

Or, as @ToucheSir suggested, maybe these are two different functions, or two different options, and you should be required to choose which behavior you want: some kind of re(p, Optimisers.NoConvert()) and re(p, Optimisers.Convert()).
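
To make that concrete, here is a purely hypothetical sketch of how such options could dispatch; Convert, NoConvert, and rebuild are invented names, not part of Optimisers.jl or Flux:

```julia
using Flux

struct Convert end     # compute with typeof(p): conversions follow p
struct NoConvert end   # compute with the original model's stored types

# Convert: bring the model's parameter arrays to eltype(p) first, so that the
# reconstruction has nothing to demote.
function rebuild(model, p, ::Convert)
    m = Flux.fmap(x -> x isa AbstractArray{<:Number} ? convert.(eltype(p), x) : x, model)
    _, re = Flux.destructure(m)
    return re(p)
end

# NoConvert: keep the original model's types by converting p to the stored eltype
# before rebuilding (only the eltype is handled in this sketch; CPU/GPU is left out).
function rebuild(model, p, ::NoConvert)
    p0, re = Flux.destructure(model)
    return re(convert.(eltype(p0), p))
end
```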

Those 4 behaviors would be clear and easily predictable. I think the only option I would be adamantly against is the current behavior.
