diff --git a/Project.toml b/Project.toml index b66961e93..e0e265c76 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.2.0" +version = "1.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index c13441de8..7b8eaa3eb 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -128,6 +128,7 @@ @non_differentiable eachline(::AbstractString) @non_differentiable eachline(::IO) @non_differentiable eachmatch(::Regex, ::AbstractString) +@non_differentiable eltype(::Type) @non_differentiable endswith(::AbstractString, ::AbstractString) @non_differentiable endswith(::AbstractString, ::Regex) @non_differentiable eof(::Any) diff --git a/src/rulesets/Core/core.jl b/src/rulesets/Core/core.jl index 9249d399d..5753e3b2e 100644 --- a/src/rulesets/Core/core.jl +++ b/src/rulesets/Core/core.jl @@ -7,5 +7,22 @@ @non_differentiable Core.isdefined(::Any, ::Any) @non_differentiable Core.:(<:)(::Any, ::Any) -@non_differentiable Core.apply_type(::Any, ::Any) +@non_differentiable Core.apply_type(::Any, ::Any...) @non_differentiable Core.typeof(::Any) + +frule((_, ẋ, _), ::typeof(typeassert), x, T) = (typeassert(x, T), ẋ) +function rrule(::typeof(typeassert), x, T) + typeassert_pullback(Δ) = (NoTangent(), Δ, NoTangent()) + return typeassert(x, T), typeassert_pullback +end + +frule((_, _, ȧ, ḃ), ::typeof(ifelse), c, a, b) = (ifelse(c, a, b), ifelse(c, ȧ, ḃ)) +function rrule(::typeof(ifelse), c, a, b) + ifelse_pullback(Δ) = (NoTangent(), NoTangent(), ifelse(c, Δ, ZeroTangent()), ifelse(c, ZeroTangent(), Δ)) + return ifelse(c, a, b), ifelse_pullback +end +# ensure type stability for numbers +function rrule(::typeof(ifelse), c, a::Number, b::Number) + ifelse_pullback(Δ) = (NoTangent(), NoTangent(), ifelse(c, Δ, zero(Δ)), ifelse(c, zero(Δ), Δ)) + return ifelse(c, a, b), ifelse_pullback +end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index fcb983b49..84bce2a77 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -107,6 +107,8 @@ end _Adjoint_mat_pullback(ȳ::Tangent, proj) = (NoTangent(), proj(ȳ.parent)) _Adjoint_mat_pullback(ȳ::AbstractVecOrMat, proj) = (NoTangent(), proj(adjoint(ȳ))) _Adjoint_mat_pullback(ȳ::AbstractThunk, proj) = return _Adjoint_mat_pullback(unthunk(ȳ), proj) +# currently needed by Diffractor (ref https://github.com/JuliaDiff/Diffractor.jl/issues/25) +_Adjoint_mat_pullback(ȳ::AbstractZero, proj) = (NoTangent(), proj(ȳ)) function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) project_A = ProjectTo(A) Adjoint_mat_pullback(ȳ) = _Adjoint_mat_pullback(ȳ, project_A) @@ -116,6 +118,8 @@ end _Adjoint_vec_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) _Adjoint_vec_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) _Adjoint_vec_pullback(ȳ::AbstractThunk) = return _Adjoint_vec_pullback(unthunk(ȳ)) +# currently needed by Diffractor (ref https://github.com/JuliaDiff/Diffractor.jl/issues/25) +_Adjoint_vec_pullback(ȳ::AbstractZero) = (NoTangent(), ȳ) function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) return Adjoint(A), _Adjoint_vec_pullback end @@ -123,6 +127,8 @@ end _adjoint_mat_pullback(ȳ::Tangent, proj) = (NoTangent(), proj(ȳ.parent)) _adjoint_mat_pullback(ȳ::AbstractVecOrMat, proj) = (NoTangent(), proj(adjoint(ȳ))) _adjoint_mat_pullback(ȳ::AbstractThunk, proj) = return _adjoint_mat_pullback(unthunk(ȳ), proj) +# currently needed by Diffractor (ref https://github.com/JuliaDiff/Diffractor.jl/issues/25) +_adjoint_mat_pullback(ȳ::AbstractZero, proj) = (NoTangent(), proj(ȳ)) function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) project_A = ProjectTo(A) adjoint_mat_pullback(ȳ) = _adjoint_mat_pullback(ȳ, project_A) @@ -132,6 +138,8 @@ end _adjoint_vec_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) _adjoint_vec_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) _adjoint_vec_pullback(ȳ::AbstractThunk) = return _adjoint_vec_pullback(unthunk(ȳ)) +# currently needed by Diffractor (ref https://github.com/JuliaDiff/Diffractor.jl/issues/25) +_adjoint_vec_pullback(ȳ::AbstractZero) = (NoTangent(), ȳ) function rrule(::typeof(adjoint), A::AbstractVector{<:Number}) return adjoint(A), _adjoint_vec_pullback end @@ -145,6 +153,8 @@ end _Transpose_mat_pullback(ȳ::Tangent, proj) = (NoTangent(), proj(ȳ.parent)) _Transpose_mat_pullback(ȳ::AbstractVecOrMat, proj) = (NoTangent(), proj(Transpose(ȳ))) _Transpose_mat_pullback(ȳ::AbstractThunk, proj) = return _Transpose_mat_pullback(unthunk(ȳ), proj) +# currently needed by Diffractor (ref https://github.com/JuliaDiff/Diffractor.jl/issues/25) +_Transpose_mat_pullback(ȳ::AbstractZero, proj) = (NoTangent(), proj(ȳ)) function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) project_A = ProjectTo(A) Transpose_mat_pullback(ȳ) = _Transpose_mat_pullback(ȳ, project_A) @@ -154,6 +164,8 @@ end _Transpose_vec_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) _Transpose_vec_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(Transpose(ȳ))) _Transpose_vec_pullback(ȳ::AbstractThunk) = return _Transpose_vec_pullback(unthunk(ȳ)) +# currently needed by Diffractor (ref https://github.com/JuliaDiff/Diffractor.jl/issues/25) +_Transpose_vec_pullback(ȳ::AbstractZero) = (NoTangent(), ȳ) function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) return Transpose(A), _Transpose_vec_pullback end @@ -161,6 +173,8 @@ end _transpose_mat_pullback(ȳ::Tangent, proj) = (NoTangent(), proj(ȳ.parent)) _transpose_mat_pullback(ȳ::AbstractVecOrMat, proj) = (NoTangent(), proj(transpose(ȳ))) _transpose_mat_pullback(ȳ::AbstractThunk, proj) = return _transpose_mat_pullback(unthunk(ȳ), proj) +# currently needed by Diffractor (ref https://github.com/JuliaDiff/Diffractor.jl/issues/25) +_transpose_mat_pullback(ȳ::AbstractZero, proj) = (NoTangent(), proj(ȳ)) function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) project_A = ProjectTo(A) transpose_mat_pullback(ȳ) = _transpose_mat_pullback(ȳ, project_A) @@ -170,6 +184,8 @@ end _transpose_vec_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) _transpose_vec_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(transpose(ȳ))) _transpose_vec_pullback(ȳ::AbstractThunk) = return _transpose_vec_pullback(unthunk(ȳ)) +# currently needed by Diffractor (ref https://github.com/JuliaDiff/Diffractor.jl/issues/25) +_transpose_vec_pullback(ȳ::AbstractZero) = (NoTangent(), ȳ) function rrule(::typeof(transpose), A::AbstractVector{<:Number}) return transpose(A), _transpose_vec_pullback end diff --git a/test/rulesets/Core/core.jl b/test/rulesets/Core/core.jl new file mode 100644 index 000000000..8f3db8955 --- /dev/null +++ b/test/rulesets/Core/core.jl @@ -0,0 +1,12 @@ +@testset "typeassert" begin + test_rrule(typeassert, 1.1, Float64) + test_frule(typeassert, 1.1, Float64) +end + +@testset "ifelse" begin + test_rrule(ifelse, true, 1.1, 2.0) + test_frule(ifelse, false, 1.1, 2.0) + + test_rrule(ifelse, true, [1.1], [2.0]; check_inferred=false) + test_frule(ifelse, false, [1.1], [2.0]; check_inferred=false) +end diff --git a/test/runtests.jl b/test/runtests.jl index 5570a1bdb..dbd473a17 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,10 @@ println("Testing ChainRules.jl") include_test("test_helpers.jl") println() @testset "rulesets" begin + @testset "Core" begin + include_test("rulesets/Core/core.jl") + end + @testset "Base" begin include_test("rulesets/Base/base.jl") include_test("rulesets/Base/fastmath_able.jl")