diff --git a/Project.toml b/Project.toml index 51337c18f..b378a42bd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.46.0" +version = "1.46.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 6136fa6a2..beb042cf3 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -61,10 +61,9 @@ function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...) end function rrule(::typeof(getindex), x::AbstractArray, inds...) - function getindex_pullback(dy) - nots = map(Returns(NoTangent()), inds) - return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) - end + nots = map(Returns(NoTangent()), inds) + getindex_pullback(dy) = (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) + getindex_pullback(z::AbstractZero) = (NoTangent(), z, nots...) return x[inds...], getindex_pullback end @@ -90,6 +89,7 @@ function ∇getindex(x::AbstractArray, dy, inds...) ∇getindex!(dx, dy, plain_inds...) return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules end +∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z """ _setindex_zero(x, dy, inds...) @@ -191,10 +191,9 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...) end function rrule(::typeof(view), x::AbstractArray, inds...) - function view_pullback(dy) - nots = map(Returns(NoTangent()), inds) - return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) - end + nots = map(Returns(NoTangent()), inds) + view_pullback(dy) = (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) + view_pullback(z::AbstractZero) = (NoTangent(), z, nots...) return view(x, inds...), view_pullback end diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 22aeb1748..2bcac37ca 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -189,6 +189,7 @@ @non_differentiable floatmax(::Any) @non_differentiable floatmin(::Any) @non_differentiable flush(::Any) +@non_differentiable foreach(::Any, ::Tuple{}) @non_differentiable gensym(::Symbol) @non_differentiable gensym(::String...) @@ -422,6 +423,7 @@ end @non_differentiable supertype(::Any) @non_differentiable Symbol(::Any...) @non_differentiable symlink(::AbstractString, ::AbstractString) +@non_differentiable summary(::Any) @non_differentiable take!(::Base.GenericIOBuffer) @non_differentiable take!(::IOStream) @@ -472,6 +474,7 @@ elseif isdefined(Base, :cumulative_compile_time_ns) end @non_differentiable Base.time_print(::Any...) @non_differentiable Base.OneTo(::Any...) +@non_differentiable Base.array_summary(::Any) @non_differentiable Broadcast.combine_styles(::Any...) @non_differentiable Broadcast.result_style(::Any) diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 08be133fd..6dc00faae 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -6,6 +6,8 @@ _denom(x, dims) = size(x, dims) _denom(x, dims::Colon) = length(x) _denom(x, dims::Union{Tuple, AbstractArray}) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1) +@non_differentiable _denom(::Any, ::Any) # else Zygote tries to AD unique(::Tuple) + function rrule(::typeof(mean), x::AbstractArray{<:Union{Real,Complex,AbstractArray}}; dims=:) y_sum, sum_pullback = rrule(sum, x; dims) n = _denom(x, dims)