diff --git a/Project.toml b/Project.toml index 57b6aa85f..3187acf10 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.3" +version = "0.7.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 392080e5d..2cce49be6 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -6,6 +6,7 @@ using Reexport using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using LinearAlgebra using LinearAlgebra.BLAS +using Random using Requires using Statistics @@ -35,6 +36,8 @@ include("rulesets/LinearAlgebra/dense.jl") include("rulesets/LinearAlgebra/structured.jl") include("rulesets/LinearAlgebra/factorization.jl") +include("rulesets/Random/random.jl") + # Note: The following is only required because package authors sometimes do not # declare their own rules using `ChainRulesCore.jl`. For arguably good reasons. # So we define them here for them. diff --git a/src/rulesets/Random/random.jl b/src/rulesets/Random/random.jl new file mode 100644 index 000000000..36d83ad9e --- /dev/null +++ b/src/rulesets/Random/random.jl @@ -0,0 +1,8 @@ +frule(Δargs, ::typeof(MersenneTwister), args...) = MersenneTwister(args...), Zero() + +function rrule(::typeof(MersenneTwister), args...) + function MersenneTwister_rrule(ΔΩ) + return (NO_FIELDS, map(_ -> Zero(), args)...) + end + return MersenneTwister(args...), MersenneTwister_rrule +end diff --git a/test/rulesets/Random/random.jl b/test/rulesets/Random/random.jl new file mode 100644 index 000000000..dec6c4ffe --- /dev/null +++ b/test/rulesets/Random/random.jl @@ -0,0 +1,22 @@ +@testset "random" begin + @testset "MersenneTwister" begin + @testset "no args" begin + rng, dΩ = frule((5.0,), MersenneTwister) + @test rng isa MersenneTwister + @test dΩ isa Zero + + rng, pb = rrule(MersenneTwister) + @test rng isa MersenneTwister + @test first(pb(10)) isa Zero + end + @testset "unary" begin + rng, dΩ = frule((5.0, 4.0), MersenneTwister, 123) + @test rng isa MersenneTwister + @test dΩ isa Zero + + rng, pb = rrule(MersenneTwister, 123) + @test rng isa MersenneTwister + @test all(map(x -> x isa Zero, pb(10))) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ec1a89353..88ea855f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,12 @@ println("Testing ChainRules.jl") print(" ") + @testset "Random" begin + include(joinpath("rulesets", "Random", "random.jl")) + end + + print(" ") + @testset "packages" begin include(joinpath("rulesets", "packages", "NaNMath.jl")) include(joinpath("rulesets", "packages", "SpecialFunctions.jl"))