Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ChainRules AbstractZero types #786

Merged
merged 63 commits into from
Oct 9, 2020

Conversation

mzgubic
Copy link
Collaborator

@mzgubic mzgubic commented Sep 9, 2020

WIP for #603. Depends on this branch FluxML/ZygoteRules.jl#10

The work can be split into three parts:

  • support AbstractZero
  • support Composite types
  • support thunks

This PR deals with the first part

@mzgubic
Copy link
Collaborator Author

mzgubic commented Sep 9, 2020

Current state leaves three types of errors in tests:

  • try/catch is not supported.
  • nested AD hitting identity(::Tuple) pullback
  • Non-differentiable function bitcast

and the deprecation warnings which will be gone once the Zygote adjoints are changed to use Zero() instead of nothing

src/compiler/interface.jl Outdated Show resolved Hide resolved
src/compiler/interface2.jl Outdated Show resolved Hide resolved
src/lib/lib.jl Outdated Show resolved Hide resolved
src/lib/lib.jl Outdated Show resolved Hide resolved
src/lib/lib.jl Outdated Show resolved Hide resolved
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
@oxinabox
Copy link
Member

oxinabox commented Sep 9, 2020

try/catch is not supported.

This is interesting, I don't know what might have changed in the code that introduced a try-catch

nested AD hitting identity(::Tuple) pullback

This has causes problems before.

@mzgubic
Copy link
Collaborator Author

mzgubic commented Oct 2, 2020

Three issues remain:

two similar errors [1], [2] coming from accum:
@test sin'''(1.0) == -cos(1.0)
(from features.jl:171)
a3, pb3 = Zygote.pullback(h, 1)
(from chainrules.jl:157)

legacytype warning:
@test gradtest(X -> sum(x -> x^2, X), randn(10))
(from gradcheck.jl:138)

difftype warning (needs to be turned on in adjoint.jl):
gradient(x -> sum(Float32.(x)), [1.0])
(from precompile.jl: 22)

[1]

Features: Error During Test at /Users/mzgubic/Projects/Zygote.jl/test/features.jl:171
  Test threw exception
  Expression: (((sin')')')(1.0) == -(cos(1.0))
  MethodError: no method matching +(::Tuple{Zero,NamedTuple{(:back,),Tuple{NamedTuple{(:back,),Tuple{NamedTuple{(:x,),Tuple{Float64}}}}}}}, ::Tuple{Zero,NamedTuple{(:back,),Tuple{NamedTuple{(:back,),Tuple{NamedTuple{(:x,),Tuple{Float64}}}}}}})
  Closest candidates are:
    +(::Any, ::Any, !Matched::Any, !Matched::Any...) at operators.jl:529
    +(!Matched::ChainRulesCore.DoesNotExist, ::Any) at /Users/mzgubic/.julia/packages/ChainRulesCore/y9QnK/src/differential_arithmetic.jl:22
    +(!Matched::ChainRulesCore.One, ::Any) at /Users/mzgubic/.julia/packages/ChainRulesCore/y9QnK/src/differential_arithmetic.jl:83
    ...
  Stacktrace:
   [1] macro expansion at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0 [inlined]
   [2] _pullback(::Zygote.Context, ::typeof(+), ::Tuple{Zero,NamedTuple{(:back,),Tuple{NamedTuple{(:back,),Tuple{NamedTuple{(:x,),Tuple{Float64}}}}}}}, ::Tuple{Zero,NamedTuple{(:back,),Tuple{NamedTuple{(:back,),Tuple{NamedTuple{(:x,),Tuple{Float64}}}}}}}) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:13
   [3] accum at /Users/mzgubic/Projects/Zygote.jl/src/lib/lib.jl:8 [inlined]
   [4] _pullback(::Zygote.Context, ::typeof(Zygote.accum), ::Tuple{Zero,NamedTuple{(:back,),Tuple{NamedTuple{(:back,),Tuple{NamedTuple{(:x,),Tuple{Float64}}}}}}}, ::Zero) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [5] gradient at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:53 [inlined]
   [6] _pullback(::Zygote.Context, ::typeof(∂(gradient)), ::Tuple{Float64}) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [7] #47 at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:57 [inlined]
   [8] _pullback(::Zygote.Context, ::typeof(∂(#47)), ::Float64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [9] #45 at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:45 [inlined]
   [10] _pullback(::Zygote.Context, ::Zygote.var"#45#46"{typeof(∂(#47))}, ::Float64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [11] gradient at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:54 [inlined]
   [12] _pullback(::Zygote.Context, ::typeof(gradient), ::Zygote.var"#47#48"{typeof(sin)}, ::Float64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [13] #47 at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:57 [inlined]
   [14] _pullback(::Zygote.Context, ::Zygote.var"#47#48"{Zygote.var"#47#48"{typeof(sin)}}, ::Float64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [15] _pullback(::Function, ::Float64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:38
   [16] pullback(::Function, ::Float64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:44
   [17] gradient(::Function, ::Float64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:53
   [18] (::Zygote.var"#47#48"{Zygote.var"#47#48"{Zygote.var"#47#48"{typeof(sin)}}})(::Float64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:57
   [19] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/features.jl:171
   [20] include at ./boot.jl:328 [inlined]
   [21] include_relative(::Module, ::String) at ./loading.jl:1105
   [22] include(::Module, ::String) at ./Base.jl:31
   [23] include(::String) at ./client.jl:424
   [24] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/runtests.jl:20
   [25] top-level scope at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1107
   [26] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/runtests.jl:20
   [27] top-level scope at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1107
   [28] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/runtests.jl:7

[2]

nested AD hitting identity(::Tuple) pullback: Error During Test at /Users/mzgubic/Projects/Zygote.jl/test/chainrules.jl:143
  Got exception outside of a @test
  MethodError: no method matching +(::NamedTuple{(:x, :y),Tuple{Int64,Zero}}, ::NamedTuple{(:x, :y),Tuple{Int64,Zero}})
  Closest candidates are:
    +(::Any, ::Any, !Matched::Any, !Matched::Any...) at operators.jl:529
    +(!Matched::DoesNotExist, ::Any) at /Users/mzgubic/.julia/packages/ChainRulesCore/y9QnK/src/differential_arithmetic.jl:22
    +(!Matched::One, ::Any) at /Users/mzgubic/.julia/packages/ChainRulesCore/y9QnK/src/differential_arithmetic.jl:83
    ...
  Stacktrace:
   [1] macro expansion at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0 [inlined]
   [2] _pullback(::Zygote.Context, ::typeof(+), ::NamedTuple{(:x, :y),Tuple{Int64,Zero}}, ::NamedTuple{(:x, :y),Tuple{Int64,Zero}}) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:13
   [3] accum at /Users/mzgubic/Projects/Zygote.jl/src/lib/lib.jl:8 [inlined]
   [4] _pullback(::Zygote.Context, ::typeof(Zygote.accum), ::NamedTuple{(:x, :y),Tuple{Int64,Zero}}, ::Zero) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [5] _pullback(::Zygote.Context, ::typeof(∂(λ)), ::Tuple{Zero,Zero,Int64}) at /Users/mzgubic/.julia/packages/ChainRules/N02le/src/rulesets/Base/fastmath_able.jl:188
   [6] ZBack at /Users/mzgubic/Projects/Zygote.jl/src/compiler/chainrules.jl:77 [inlined]
   [7] _pullback(::Zygote.Context, ::typeof(∂(λ)), ::Tuple{Zero,Zero,Int64}) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [8] f at /Users/mzgubic/Projects/Zygote.jl/test/chainrules.jl:147 [inlined]
   [9] _pullback(::Zygote.Context, ::typeof(∂(λ)), ::Tuple{Zero,Int64}) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [10] #45 at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:45 [inlined]
   [11] _pullback(::Zygote.Context, ::typeof(∂(λ)), ::Int64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [12] g at /Users/mzgubic/Projects/Zygote.jl/test/chainrules.jl:151 [inlined]
   [13] _pullback(::Zygote.Context, ::typeof(∂(g)), ::Int64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [14] #45 at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:45 [inlined]
   [15] _pullback(::Zygote.Context, ::Zygote.var"#45#46"{typeof(∂(g))}, ::Int64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [16] h at /Users/mzgubic/Projects/Zygote.jl/test/chainrules.jl:158 [inlined]
   [17] _pullback(::Zygote.Context, ::var"#h#184"{var"#g#183"{var"#f#182"}}, ::Int64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface2.jl:0
   [18] _pullback(::Function, ::Int64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:38
   [19] pullback(::Function, ::Int64) at /Users/mzgubic/Projects/Zygote.jl/src/compiler/interface.jl:44
   [20] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/chainrules.jl:165
   [21] top-level scope at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1107
   [22] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/chainrules.jl:147
   [23] top-level scope at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1107
   [24] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/chainrules.jl:5
   [25] include at ./boot.jl:328 [inlined]
   [26] include_relative(::Module, ::String) at ./loading.jl:1105
   [27] include(::Module, ::String) at ./Base.jl:31
   [28] include(::String) at ./client.jl:424
   [29] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/runtests.jl:32
   [30] top-level scope at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1107
   [31] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/runtests.jl:32
   [32] top-level scope at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1107
   [33] top-level scope at /Users/mzgubic/Projects/Zygote.jl/test/runtests.jl:7
   [34] include at ./boot.jl:328 [inlined]
   [35] include_relative(::Module, ::String) at ./loading.jl:1105
   [36] include(::Module, ::String) at ./Base.jl:31
   [37] include(::String) at ./client.jl:424
   [38] top-level scope at none:6
   [39] eval(::Module, ::Any) at ./boot.jl:330
   [40] exec_options(::Base.JLOptions) at ./client.jl:263
   [41] _start() at ./client.jl:460

src/lib/lib.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member

oxinabox commented Oct 2, 2020

two similar errors [1], [2] coming from accum:

I don't know the cause, I suspect it is something that Zygote was suposed to remove prior to that point.
But those errors will be fixed by changing to Composite
so I think we can treat them as non-blocking and revisit them once we have composite in also

Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Broadly looking good, just a few points to go over.

src/Zygote.jl Outdated Show resolved Hide resolved
src/compiler/emit.jl Show resolved Hide resolved
src/compiler/interface.jl Outdated Show resolved Hide resolved
src/lib/broadcast.jl Outdated Show resolved Hide resolved
src/lib/broadcast.jl Outdated Show resolved Hide resolved
test/runtests.jl Show resolved Hide resolved
@oxinabox oxinabox merged commit c276d83 into FluxML:chainrules_types Oct 9, 2020
@DhairyaLGandhi
Copy link
Member

Are there some benchmarks to see the realized output bump in performance here?

@oxinabox
Copy link
Member

oxinabox commented Oct 9, 2020

Let's wait til we have all the changes in.
Note that this was only a PR to the chainrules_types branch,
where we are staging the 3 or 4 seperate PRs to bring everything over.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants