diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d63c707e7..a77b16059 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -490,34 +490,45 @@ Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref) For more information see the [documentation on opting out of rules](@ref opt_out). """ macro opt_out(expr) - no_rule_target = _no_rule_target_rewrite!(deepcopy(expr)) + no_rule_target = _target_rewrite!(deepcopy(expr), true) + rule_target = _target_rewrite!(deepcopy(expr), false) return @strip_linenos quote $(esc(no_rule_target)) = nothing - $(esc(expr)) = nothing + $(esc(rule_target)) = nothing end end -"Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." -function _no_rule_target_rewrite!(expr::Expr) +""" + _target_rewrite!(expr::Expr, no_rule) + +Rewrite method sig `expr` for `rrule` to be for `no_rrule` or `ChainRulesCore.rrule` +(with the CRC namespace qualification), depending on the `no_rule` argument. +Does the equivalent for `frule`. +""" +function _target_rewrite!(expr::Expr, no_rule) length(expr.args) === 0 && error("Malformed method expression. $expr") if expr.head === :call || expr.head === :where - expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) + expr.args[1] = _target_rewrite!(expr.args[1], no_rule) elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore - expr = _no_rule_target_rewrite!(expr.args[end]) + expr = _target_rewrite!(expr.args[end], no_rule) else error("Malformed method expression. $(expr)") end return expr end -_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value) -function _no_rule_target_rewrite!(call_target::Symbol) - return if call_target == :rrule - :(ChainRulesCore.no_rrule) - elseif call_target == :frule - :(ChainRulesCore.no_frule) +_target_rewrite!(qt::QuoteNode, no_rule) = _target_rewrite!(qt.value, no_rule) +function _target_rewrite!(call_target::Symbol, no_rule) + return if call_target == :rrule && no_rule + :($ChainRulesCore.no_rrule) + elseif call_target == :rrule && !no_rule + :($ChainRulesCore.rrule) + elseif call_target == :frule && no_rule + :($ChainRulesCore.no_frule) + elseif call_target == :frule && !no_rule + :($ChainRulesCore.frule) else - error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target") + error("Unexpected opt-out target. Expected frule or rrule, got: $call_target") end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index ec9549e7e..5a177566d 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -285,7 +285,7 @@ end # workaround for https://github.com/domluna/JuliaFormatter.jl/issues/484 module IsolatedModuleForTestingScoping # check that rules can be defined by macros without any additional imports - using ChainRulesCore: @scalar_rule, @non_differentiable + using ChainRulesCore: @scalar_rule, @non_differentiable, @opt_out # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved const ChainRulesCore = nothing @@ -303,11 +303,20 @@ module IsolatedModuleForTestingScoping my_id(x) = x @scalar_rule(my_id(x), 1.0) + # @opt_out + first_oa(x, y) = x + @scalar_rule(first_oa(x, y), (1, 0)) + # Declared without using the ChainRulesCore namespace qualification + # see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/545 + @opt_out rrule(::typeof(first_oa), x::T, y::T) where {T<:Float16} + @opt_out frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float16} + module IsolatedSubmodule # check that rules defined in isolated module without imports can be called # without errors using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output - using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id + using ChainRulesCore: no_rrule, no_frule + using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id, first_oa using Test @testset "@non_differentiable" begin @@ -339,6 +348,25 @@ module IsolatedModuleForTestingScoping @test derivatives_given_output(y, my_id, x) == ((1.0,),) end + + @testset "@optout" begin + # rrule + @test rrule(first_oa, Float16(3.0), Float16(4.0)) === nothing + @test !isempty( + Iterators.filter(methods(no_rrule)) do m + m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float16} + end, + ) + + # frule + @test frule((NoTangent(), 1, 0), first_oa, Float16(3.0), Float16(4.0)) === + nothing + @test !isempty( + Iterators.filter(methods(no_frule)) do m + m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float16} + end, + ) + end end end #! format: on