Skip to content

Commit

Permalink
Merge pull request #1249 from mzgubic/mz/deprecate_nograd
Browse files Browse the repository at this point in the history
deprecate `Zygote.@nograd`
  • Loading branch information
ToucheSir authored Jul 30, 2022
2 parents 5ffbd43 + af434d6 commit cb59b6c
Show file tree
Hide file tree
Showing 11 changed files with 23 additions and 35 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.36.2"
ChainRules = "1.37"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
Expand Down
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ include("profiler/Profile.jl")
end

@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" begin
@nograd Colors.ColorTypes._parameter_upper_bound
@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
end

using InteractiveUtils
Expand Down
16 changes: 16 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,19 @@ macro ignore(ex)
$(esc(ex))
end)
end

using MacroTools: @q

macro nograd(ex)
Base.depwarn(
"`Zygote.@nograd myfunc` is deprecated, use `ChainRulesCore.@non_differentiable myfunc(::Any...)` instead.",
:nograd
)
isexpr(ex, :tuple) || (ex = Expr(:tuple, ex))
blk = @q begin end
for f in ex.args
back = MacroTools.@q _ -> ($__source__; nothing)
push!(blk.args, :(@inline Zygote._pullback(::Context, ::Core.Typeof($(esc(f))), args...) = $(esc(f))(args...), $back))
end
return blk
end
2 changes: 1 addition & 1 deletion src/forward/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end
# TODO figure out why this made a test fail
zerolike(x::Union{Module,Type}) = nothing

# TODO: `@nograd` and `@linear`
# TODO: `@non_differentiable` and `@linear`

@tangent zerolike(x) = zerolike(x), _ -> zerolike(x)
@tangent one(x::Number) = one(x), _ -> zero(x)
Expand Down
7 changes: 0 additions & 7 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ using Distributed: pmap, AbstractWorkerPool
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)

@nograd ones, zeros, Base.OneTo, Colon(), one, zero, sizehint!, count

@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,)

@adjoint collect(x::Tuple) = collect(x), dy -> (Tuple(dy),)
Expand Down Expand Up @@ -233,11 +231,6 @@ end
end
end

for t in subtypes(AbstractWorkerPool)
@nograd t
end
@nograd workers

function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
y, b = ∇map(cx, g.f, g.iter)
back(::Nothing) = nothing
Expand Down
4 changes: 0 additions & 4 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ end

# Channels

@nograd Channel

grad_mut(ch::Channel) = Channel(ch.sz_max)

@adjoint! function put!(ch::Channel, x)
Expand Down Expand Up @@ -157,8 +155,6 @@ end

@adjoint Base.nameof(x::UnionAll) = nameof(x), _ -> (nothing,)

@nograd typeintersect

# Base.Fix1 and Base.Fix2: https://github.com/FluxML/Zygote.jl/issues/957
@adjoint function (g::Base.Fix1)(y)
f = g.f
Expand Down
2 changes: 1 addition & 1 deletion src/lib/buffer.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
grad_mut(b::Buffer) = fill!(similar(b.data, Any), nothing)
grad_mut(b::Buffer{T}) where T<:Number = fill!(similar(b.data, float(T)), 0)

@nograd Buffer
@non_differentiable Buffer(::Any...)

@adjoint function getindex(b::Buffer, i...)
b[i...], function (Δ)
Expand Down
12 changes: 0 additions & 12 deletions src/lib/grad.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
using MacroTools: @q

macro nograd(ex)
isexpr(ex, :tuple) || (ex = Expr(:tuple, ex))
blk = @q begin end
for f in ex.args
back = MacroTools.@q _ -> ($__source__; nothing)
push!(blk.args, :(@inline Zygote._pullback(::Context, ::Core.Typeof($(esc(f))), args...) = $(esc(f))(args...), $back))
end
return blk
end

macro which(ex)
@capture(ex, f_(args__)) || error("Zygote.@which f(args...)")
:(InteractiveUtils.@which adjoint(Context(), $(esc(f)), $(esc.(args)...)))
Expand Down
2 changes: 0 additions & 2 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ function accum(x::RefValue, y::RefValue)
end

# Core functions
@nograd eps, Base.eval, Core.TypeVar, Core.UnionAll, Symbol

@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)

@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing
Expand Down
3 changes: 0 additions & 3 deletions src/lib/number.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@

@nograd floor, ceil, trunc, round, div

@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
Base.literal_pow(^,x,Val(p)),
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)
Expand Down
6 changes: 3 additions & 3 deletions test/lib/number.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testset "nograds" begin
@test gradient(floor, 1) === nothing
@test gradient(ceil, 1) === nothing
@test gradient(round, 1) === nothing
@test gradient(floor, 1) === (0.0,)
@test gradient(ceil, 1) === (0.0,)
@test gradient(round, 1) === (0.0,)
@test gradient(hash, 1) === nothing
@test gradient(div, 1, 2) === nothing
end #testset

0 comments on commit cb59b6c

Please sign in to comment.