Skip to content

Commit

Permalink
Make @opt_out rrule(...) automatically qualify rrule namespace as…
Browse files Browse the repository at this point in the history
… `ChainRulesCore.rrule` (#546)
  • Loading branch information
mzgubic authored Feb 23, 2022
1 parent 3394de6 commit e3362cb
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 15 deletions.
37 changes: 24 additions & 13 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,34 +490,45 @@ Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref)
For more information see the [documentation on opting out of rules](@ref opt_out).
"""
macro opt_out(expr)
no_rule_target = _no_rule_target_rewrite!(deepcopy(expr))
no_rule_target = _target_rewrite!(deepcopy(expr), true)
rule_target = _target_rewrite!(deepcopy(expr), false)

return @strip_linenos quote
$(esc(no_rule_target)) = nothing
$(esc(expr)) = nothing
$(esc(rule_target)) = nothing
end
end

"Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`."
function _no_rule_target_rewrite!(expr::Expr)
"""
_target_rewrite!(expr::Expr, no_rule)
Rewrite method sig `expr` for `rrule` to be for `no_rrule` or `ChainRulesCore.rrule`
(with the CRC namespace qualification), depending on the `no_rule` argument.
Does the equivalent for `frule`.
"""
function _target_rewrite!(expr::Expr, no_rule)
length(expr.args) === 0 && error("Malformed method expression. $expr")
if expr.head === :call || expr.head === :where
expr.args[1] = _no_rule_target_rewrite!(expr.args[1])
expr.args[1] = _target_rewrite!(expr.args[1], no_rule)
elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore
expr = _no_rule_target_rewrite!(expr.args[end])
expr = _target_rewrite!(expr.args[end], no_rule)
else
error("Malformed method expression. $(expr)")
end
return expr
end
_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value)
function _no_rule_target_rewrite!(call_target::Symbol)
return if call_target == :rrule
:(ChainRulesCore.no_rrule)
elseif call_target == :frule
:(ChainRulesCore.no_frule)
_target_rewrite!(qt::QuoteNode, no_rule) = _target_rewrite!(qt.value, no_rule)
function _target_rewrite!(call_target::Symbol, no_rule)
return if call_target == :rrule && no_rule
:($ChainRulesCore.no_rrule)
elseif call_target == :rrule && !no_rule
:($ChainRulesCore.rrule)
elseif call_target == :frule && no_rule
:($ChainRulesCore.no_frule)
elseif call_target == :frule && !no_rule
:($ChainRulesCore.frule)
else
error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target")
error("Unexpected opt-out target. Expected frule or rrule, got: $call_target")
end
end

Expand Down
32 changes: 30 additions & 2 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ end
# workaround for https://github.com/domluna/JuliaFormatter.jl/issues/484
module IsolatedModuleForTestingScoping
# check that rules can be defined by macros without any additional imports
using ChainRulesCore: @scalar_rule, @non_differentiable
using ChainRulesCore: @scalar_rule, @non_differentiable, @opt_out

# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved
const ChainRulesCore = nothing
Expand All @@ -303,11 +303,20 @@ module IsolatedModuleForTestingScoping
my_id(x) = x
@scalar_rule(my_id(x), 1.0)

# @opt_out
first_oa(x, y) = x
@scalar_rule(first_oa(x, y), (1, 0))
# Declared without using the ChainRulesCore namespace qualification
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/545
@opt_out rrule(::typeof(first_oa), x::T, y::T) where {T<:Float16}
@opt_out frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float16}

module IsolatedSubmodule
# check that rules defined in isolated module without imports can be called
# without errors
using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output
using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id
using ChainRulesCore: no_rrule, no_frule
using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id, first_oa
using Test

@testset "@non_differentiable" begin
Expand Down Expand Up @@ -339,6 +348,25 @@ module IsolatedModuleForTestingScoping

@test derivatives_given_output(y, my_id, x) == ((1.0,),)
end

@testset "@optout" begin
# rrule
@test rrule(first_oa, Float16(3.0), Float16(4.0)) === nothing
@test !isempty(
Iterators.filter(methods(no_rrule)) do m
m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float16}
end,
)

# frule
@test frule((NoTangent(), 1, 0), first_oa, Float16(3.0), Float16(4.0)) ===
nothing
@test !isempty(
Iterators.filter(methods(no_frule)) do m
m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float16}
end,
)
end
end
end
#! format: on

2 comments on commit e3362cb

@mzgubic
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v1.12.1 already exists and points to a different commit"

Please sign in to comment.