Skip to content

Commit

Permalink
MersenneTwister (#223)
Browse files Browse the repository at this point in the history
* Nograd for MersenneTwister

* MersenneTwister frule

* Bumps patch version

* Bump patch version
  • Loading branch information
willtebbutt authored Jul 5, 2020
1 parent ce39b65 commit 697e7e4
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 1 deletion.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.3"
version = "0.7.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
3 changes: 3 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Reexport
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using LinearAlgebra
using LinearAlgebra.BLAS
using Random
using Requires
using Statistics

Expand Down Expand Up @@ -35,6 +36,8 @@ include("rulesets/LinearAlgebra/dense.jl")
include("rulesets/LinearAlgebra/structured.jl")
include("rulesets/LinearAlgebra/factorization.jl")

include("rulesets/Random/random.jl")

# Note: The following is only required because package authors sometimes do not
# declare their own rules using `ChainRulesCore.jl`. For arguably good reasons.
# So we define them here for them.
Expand Down
8 changes: 8 additions & 0 deletions src/rulesets/Random/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
frule(Δargs, ::typeof(MersenneTwister), args...) = MersenneTwister(args...), Zero()

function rrule(::typeof(MersenneTwister), args...)
function MersenneTwister_rrule(ΔΩ)
return (NO_FIELDS, map(_ -> Zero(), args)...)
end
return MersenneTwister(args...), MersenneTwister_rrule
end
22 changes: 22 additions & 0 deletions test/rulesets/Random/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@testset "random" begin
@testset "MersenneTwister" begin
@testset "no args" begin
rng, dΩ = frule((5.0,), MersenneTwister)
@test rng isa MersenneTwister
@testisa Zero

rng, pb = rrule(MersenneTwister)
@test rng isa MersenneTwister
@test first(pb(10)) isa Zero
end
@testset "unary" begin
rng, dΩ = frule((5.0, 4.0), MersenneTwister, 123)
@test rng isa MersenneTwister
@testisa Zero

rng, pb = rrule(MersenneTwister, 123)
@test rng isa MersenneTwister
@test all(map(x -> x isa Zero, pb(10)))
end
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ println("Testing ChainRules.jl")

print(" ")

@testset "Random" begin
include(joinpath("rulesets", "Random", "random.jl"))
end

print(" ")

@testset "packages" begin
include(joinpath("rulesets", "packages", "NaNMath.jl"))
include(joinpath("rulesets", "packages", "SpecialFunctions.jl"))
Expand Down

4 comments on commit 697e7e4

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17498

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.4 -m "<description of version>" 697e7e4e168015d216c86863c5c205f954766ddc
git push origin v0.7.4

@baggepinnen
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit might have caused #227

@sethaxen
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooooh, I see the problem. I'll put together a fix PR.

Please sign in to comment.