From db1eb6bc1eea33b30f7f70e1e5b5ed3a9e7b7863 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Tue, 18 Jun 2019 10:03:17 -0700 Subject: [PATCH] Define `getindex` on rules (#57) We define `iterate` on them for convenience, and `getindex` provides a similar convenience for cases where you're not sure whether the rules resulting from a call to `rrule`/`frule` will be a `Tuple`. So instead of writing `partials isa Tuple ? partials[i] : partials`, you can now just write `partials[i]`. --- src/rules.jl | 4 ++++ test/rules.jl | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index c8e36d1ca..f0705878c 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -54,6 +54,10 @@ abstract type AbstractRule end Base.iterate(rule::AbstractRule) = (rule, nothing) Base.iterate(::AbstractRule, ::Any) = nothing +# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple +# in order to get the `i`th rule (assuming it's 1) +Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError()) + """ accumulate(Δ, rule::AbstractRule, args...) diff --git a/test/rules.jl b/test/rules.jl index 9f52550a3..6132afa4f 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -12,7 +12,7 @@ cool(x) = x + 1 @test rrx == 2 @test rr(1) == 1 end - @testset "iterating rules" begin + @testset "iterating and indexing rules" begin _, rule = frule(+, 1) i = 0 for r in rule @@ -20,6 +20,8 @@ cool(x) = x + 1 i += 1 end @test i == 1 # rules only iterate once, yielding themselves + @test rule[1] == rule + @test_throws BoundsError rule[2] end @testset "helper functions" begin # Hits fallback, since we can't update `Diagonal`s in place