Skip to content

Commit

Permalink
Define getindex on rules (#57)
Browse files Browse the repository at this point in the history
We define `iterate` on them for convenience, and `getindex` provides a
similar convenience for cases where you're not sure whether the rules
resulting from a call to `rrule`/`frule` will be a `Tuple`. So instead
of writing `partials isa Tuple ? partials[i] : partials`, you can now
just write `partials[i]`.
  • Loading branch information
ararslan authored Jun 18, 2019
1 parent 314b08a commit db1eb6b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ abstract type AbstractRule end
Base.iterate(rule::AbstractRule) = (rule, nothing)
Base.iterate(::AbstractRule, ::Any) = nothing

# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple
# in order to get the `i`th rule (assuming it's 1)
Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError())

"""
accumulate(Δ, rule::AbstractRule, args...)
Expand Down
4 changes: 3 additions & 1 deletion test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ cool(x) = x + 1
@test rrx == 2
@test rr(1) == 1
end
@testset "iterating rules" begin
@testset "iterating and indexing rules" begin
_, rule = frule(+, 1)
i = 0
for r in rule
@test r === rule
i += 1
end
@test i == 1 # rules only iterate once, yielding themselves
@test rule[1] == rule
@test_throws BoundsError rule[2]
end
@testset "helper functions" begin
# Hits fallback, since we can't update `Diagonal`s in place
Expand Down

0 comments on commit db1eb6b

Please sign in to comment.