Skip to content

Commit c89c5d0

Browse files
committed
Add ChainRules fallback
Add Manifest do it in the right place and add blacklisting Update blacklist to include all higher order functions blacklist broken sum Update src/compiler/interface2.jl
1 parent d9474cc commit c89c5d0

File tree

5 files changed

+105
-13
lines changed

5 files changed

+105
-13
lines changed

Manifest.toml

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
2727
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
2828
version = "0.6.2"
2929

30+
[[ChainRules]]
31+
deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"]
32+
git-tree-sha1 = "0d6f9017442dc7a00f53dcc1174e4e0c2a2c7751"
33+
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
34+
version = "0.2.1"
35+
36+
[[ChainRulesCore]]
37+
git-tree-sha1 = "a493cc9352df2d99790f9f1225dfd9fbc52cd13e"
38+
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
39+
version = "0.3.0"
40+
3041
[[CommonSubexpressions]]
3142
deps = ["Test"]
3243
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
@@ -53,9 +64,9 @@ version = "4.0.0"
5364

5465
[[DataStructures]]
5566
deps = ["InteractiveUtils", "OrderedCollections"]
56-
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
67+
git-tree-sha1 = "f94423c68f2e47db0d6f626a26d4872266e0ec3d"
5768
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
58-
version = "0.17.0"
69+
version = "0.17.2"
5970

6071
[[Dates]]
6172
deps = ["Printf"]
@@ -83,15 +94,15 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
8394

8495
[[FFTW]]
8596
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
86-
git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515"
97+
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
8798
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
88-
version = "0.3.0"
99+
version = "1.0.1"
89100

90101
[[FillArrays]]
91102
deps = ["LinearAlgebra", "Random", "SparseArrays"]
92-
git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad"
103+
git-tree-sha1 = "16974065d5bfa867446d3228bc63f05a440e910b"
93104
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
94-
version = "0.6.4"
105+
version = "0.7.2"
95106

96107
[[ForwardDiff]]
97108
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
@@ -101,7 +112,7 @@ version = "0.10.3"
101112

102113
[[IRTools]]
103114
deps = ["InteractiveUtils", "MacroTools", "Test"]
104-
git-tree-sha1 = "a66befa9ebc63e465212281ac027c1526382bc4e"
115+
git-tree-sha1 = "09e4091acb2c4aac29a673fab16e0f1aa8672b2a"
105116
repo-rev = "master"
106117
repo-url = "https://github.com/MikeInnes/IRTools.jl.git"
107118
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
@@ -168,7 +179,7 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
168179
version = "0.3.7"
169180

170181
[[Pkg]]
171-
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
182+
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
172183
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
173184

174185
[[Printf]]
@@ -213,10 +224,10 @@ deps = ["LinearAlgebra", "Random"]
213224
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
214225

215226
[[SpecialFunctions]]
216-
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
217-
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
227+
deps = ["BinDeps", "BinaryProvider", "Libdl"]
228+
git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
218229
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
219-
version = "0.7.2"
230+
version = "0.8.0"
220231

221232
[[StaticArrays]]
222233
deps = ["LinearAlgebra", "Random", "Statistics"]

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
33
version = "0.3.4"
44

55
[deps]
6+
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
67
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
78
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
89
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -20,6 +21,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2021
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2122

2223
[compat]
24+
ChainRules = "0.2.1"
2325
IRTools = "0.2.3"
2426
NNlib = "0.6"
2527
ZygoteRules = "0.2"

src/Zygote.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using LinearAlgebra: copytri!, AbstractTriangular
55

66
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty
77

8+
using ChainRules: ChainRules
89
using IRTools
910
using MacroTools, Requires
1011
using MacroTools: @forward

src/compiler/interface.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ end
2727
# interface2.jl
2828

2929
# Wrappers
30-
3130
_pullback(f, args...) = _pullback(Context(), f, args...)
3231

3332
tailmemaybe(::Nothing) = nothing

src/compiler/interface2.jl

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,86 @@ using IRTools.Inner: argnames!, update!
33

44
ignore(T) = all(T -> T <: Type, T.parameters)
55

6-
@generated function _pullback(ctx::AContext, f, args...)
6+
7+
function _pullback(ctx::AContext, f, args...)
8+
if chainrules_blacklist(f, args...)
9+
# then don't even consider using ChainRules
10+
return _pullback_via_source2source(ctx, f, args...)
11+
end
12+
13+
res = ChainRules.rrule(f, args...)
14+
if res === nothing
15+
# No ChainRule defined, time to do the source tranform
16+
return _pullback_via_source2source(ctx, f, args...)
17+
else
18+
# Can just use ChainRule answer
19+
y, pb = res
20+
return y, _pullback_via_chainrules(pb)
21+
end
22+
end
23+
24+
#=="""
25+
chainrules_blacklist(f, args...,)
26+
27+
This is used to disable the use of ChainRule's definitions
28+
for particular functions/methods.
29+
30+
It is not required if a Zygote rule has already been defined directly.
31+
"""==#
32+
chainrules_blacklist(f, args...) = false
33+
34+
# ChainRules does higher-order functions badly
35+
# see https://github.com/JuliaDiff/ChainRules.jl/issues/122
36+
chainrules_blacklist(::typeof(map), args...) = true
37+
chainrules_blacklist(::typeof(broadcast), args...) = true
38+
chainrules_blacklist(::typeof(mapreduce), args...) = true
39+
chainrules_blacklist(::typeof(mapfoldl), args...) = true
40+
chainrules_blacklist(::typeof(mapfoldr), args...) = true
41+
chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true
42+
# Except for sum(abs2, xs), that is fine
43+
chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false
44+
45+
# ChainRules current Wirtinger deriviative is not compatible
46+
# reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29
47+
chainrules_blacklist(::typeof(abs), ::Complex) = true
48+
chainrules_blacklist(::typeof(abs2), ::Complex) = true
49+
chainrules_blacklist(::typeof(conj), ::Complex) = true
50+
chainrules_blacklist(::typeof(adjoint), ::Complex) = true
51+
chainrules_blacklist(::typeof(hypot), ::Complex) = true
52+
chainrules_blacklist(::typeof(angle), ::Complex) = true
53+
chainrules_blacklist(::typeof(imag), ::Complex) = true
54+
chainrules_blacklist(::typeof(real), ::Complex) = true
55+
56+
# Sum of nonarrays doesn't really work
57+
# Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124
58+
chainrules_blacklist(::typeof(sum), x) = true
59+
chainrules_blacklist(::typeof(sum), x::AbstractArray{<:Real}) = false
60+
61+
62+
#=="""
63+
_pullback_via_chainrules(pb)
64+
65+
Converts a ChainRules pullback into a Zygote pullback.
66+
`pb` should be a ChainRules pullback, as returned from the second return value of `rrule`
67+
"""==#
68+
function _pullback_via_chainrules(pb)
69+
# The less optimized version of this code is
70+
# cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...))
71+
function zback(Δs...)
72+
∂s = pb(Δs...)
73+
ntuple(length(∂s)) do ii
74+
= ∂s[ii]
75+
zextern(∂)
76+
end
77+
end
78+
end
79+
80+
zextern(x) = ChainRules.extern(x)
81+
zextern(::ChainRules.Zero) = nothing # Zygote loves calling things nothing
82+
zextern(::ChainRules.DNE) = nothing # Zygote loves calling things nothing
83+
84+
85+
@generated function _pullback_via_source2source(ctx::AContext, f, args...)
786
T = Tuple{f,args...}
887
ignore(T) && return :(f(args...), Pullback{$T}(()))
988
g = try _lookup_grad(T) catch e e end

0 commit comments

Comments
 (0)