Skip to content

Commit 0886582

Browse files
authored
Merge pull request #54 from mcabbott/tri
Add `ifelse` & `muladd` rules
2 parents df90af7 + 4a849ea commit 0886582

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DiffRules"
22
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
3-
version = "1.0.2"
3+
version = "1.1.0"
44

55
[deps]
66
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"

src/rules.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,15 @@
6161
@define_diffrule Base.deg2rad(x) = :( π / 180 )
6262
@define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 )
6363
@define_diffrule Base.rad2deg(x) = :( 180 / π )
64+
6465
@define_diffrule SpecialFunctions.gamma(x) =
6566
:( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) )
6667
@define_diffrule SpecialFunctions.loggamma(x) =
6768
:( SpecialFunctions.digamma($x) )
69+
70+
@define_diffrule Base.identity(x) = :( 1 )
71+
@define_diffrule Base.conj(x) = :( 1 )
72+
@define_diffrule Base.adjoint(x) = :( 1 )
6873
@define_diffrule Base.transpose(x) = :( 1 )
6974
@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) )
7075

@@ -88,12 +93,22 @@ else
8893
@define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) )
8994
end
9095
@define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) )
96+
@define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) )
97+
9198
@define_diffrule Base.mod(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) )
9299
@define_diffrule Base.rem(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN)) )
93100
@define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN
94101
@define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) )
95102
@define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) )
96103

104+
# trinary #
105+
#---------#
106+
107+
@define_diffrule Base.muladd(x, y, z) = :($y), :($x), :(one($z))
108+
@define_diffrule Base.fma(x, y, z) = :($y), :($x), :(one($z))
109+
110+
@define_diffrule Base.ifelse(p, x, y) = false, :($p), :(!$p)
111+
97112
####################
98113
# SpecialFunctions #
99114
####################

test/runtests.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function finitediff(f, x)
1616
end
1717

1818

19-
non_numeric_arg_functions = [(:Base, :rem2pi, 2)]
19+
non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)]
2020

2121
for (M, f, arity) in DiffRules.diffrules()
2222
(M, f, arity) non_numeric_arg_functions && continue
@@ -46,6 +46,22 @@ for (M, f, arity) in DiffRules.diffrules()
4646
@test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05)
4747
end
4848
end
49+
elseif arity == 3
50+
@test DiffRules.hasdiffrule(M, f, 3)
51+
derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo)
52+
@eval begin
53+
foo, bar, goo = randn(3)
54+
dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3])
55+
if !(isnan(dx))
56+
@test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05)
57+
end
58+
if !(isnan(dy))
59+
@test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05)
60+
end
61+
if !(isnan(dz))
62+
@test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05)
63+
end
64+
end
4965
end
5066
end
5167

@@ -62,3 +78,17 @@ for xtype in [:Float64, :BigFloat, :Int64]
6278
end
6379
end
6480
end
81+
82+
# Test ifelse separately as first argument is boolean
83+
@test DiffRules.hasdiffrule(:Base, :ifelse, 3)
84+
derivs = DiffRules.diffrule(:Base, :ifelse, :foo, :bar, :goo)
85+
for cond in [true, false]
86+
@eval begin
87+
foo = $cond
88+
bar, gee = randn(2)
89+
dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3])
90+
@test isapprox(dy, finitediff(y -> ifelse(foo, y, goo), bar), rtol=0.05)
91+
@test isapprox(dz, finitediff(z -> ifelse(foo, bar, z), goo), rtol=0.05)
92+
end
93+
end
94+

0 commit comments

Comments
 (0)