-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Series of tests for AD #114
Changes from all commits
2234d4f
a6159e1
4aeb0e3
b6a7901
f70adc1
2ae0cd6
d88dcff
6875aee
44368fb
b3142f6
44ad0cd
07631b6
960bad2
3e620ae
24cb00d
5b2e580
0bba1a5
7f52242
f1000b3
4023365
a73133b
d586967
577518f
9d82e1c
181341e
88c6af7
aa282a1
ffefd1f
6b5ba4d
b6ddf52
5c7eb6a
a4e5bb2
686ad8c
e94973e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,14 @@ | ||
struct Delta <: Distances.PreMetric | ||
end | ||
|
||
@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T} | ||
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) where {T} | ||
@boundscheck if length(a) != length(b) | ||
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) | ||
end | ||
return a == b | ||
end | ||
|
||
Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb) | ||
|
||
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b) | ||
@inline (dist::Delta)(a::Number,b::Number) = a == b | ||
@inline (dist::Delta)(a::Number, b::Number) = a == b |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,15 @@ | ||
struct DotProduct <: Distances.PreMetric end | ||
# struct DotProduct <: Distances.UnionSemiMetric end | ||
|
||
@inline function Distances._evaluate(::DotProduct, a::AbstractVector{T}, b::AbstractVector{T}) where {T} | ||
@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) | ||
@boundscheck if length(a) != length(b) | ||
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) | ||
end | ||
return dot(a,b) | ||
end | ||
|
||
Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
|
||
@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b | ||
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b) | ||
@inline (dist::DotProduct)(a::Number,b::Number) = a * b |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,9 @@ Distances.parameters(d::Sinus) = d.r | |
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b) | ||
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r)) | ||
|
||
@inline function Distances._evaluate(d::Sinus, a::AbstractVector{T}, b::AbstractVector{T}) where {T} | ||
Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And here. |
||
|
||
@inline function Distances._evaluate(d::Sinus, a::AbstractVector, b::AbstractVector) where {T} | ||
@boundscheck if (length(a) != length(b)) || length(a) != length(d.r) | ||
throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))")) | ||
end | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
hadamard(x, y) = x .* y | ||
|
||
loggamma(x) = first(logabsgamma(x)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no reason to define There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is that really type piracy? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, |
||
|
||
# Macro for checking arguments | ||
macro check_args(K, param, cond, desc=string(cond)) | ||
quote | ||
|
@@ -124,4 +126,3 @@ function validate_dims(x::AbstractVector, y::AbstractVector) | |
)) | ||
end | ||
end | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,78 @@ | ||
## Adjoints Delta | ||
@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector) | ||
evaluate(s, x, y), Δ -> begin | ||
(nothing, nothing, nothing) | ||
end | ||
end | ||
|
||
@adjoint function pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) | ||
D = pairwise(d, X, Y; dims = dims) | ||
if dims == 1 | ||
return D, Δ -> (nothing, nothing, nothing) | ||
else | ||
return D, Δ -> (nothing, nothing, nothing) | ||
end | ||
end | ||
|
||
@adjoint function pairwise(d::Delta, X::AbstractMatrix; dims=2) | ||
D = pairwise(d, X; dims = dims) | ||
if dims == 1 | ||
return D, Δ -> (nothing, nothing) | ||
else | ||
return D, Δ -> (nothing, nothing) | ||
end | ||
end | ||
|
||
## Adjoints DotProduct | ||
@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector) | ||
dot(x, y), Δ -> begin | ||
(nothing, Δ .* y, Δ .* x) | ||
end | ||
end | ||
|
||
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2) | ||
D = pairwise(d, X, Y; dims = dims) | ||
if dims == 1 | ||
return D, Δ -> (nothing, Δ * Y, (X' * Δ)') | ||
else | ||
return D, Δ -> (nothing, (Δ * Y')', X * Δ) | ||
end | ||
end | ||
|
||
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2) | ||
D = pairwise(d, X; dims = dims) | ||
if dims == 1 | ||
return D, Δ -> (nothing, 2 * Δ * X) | ||
else | ||
return D, Δ -> (nothing, 2 * X * Δ) | ||
end | ||
end | ||
|
||
## Adjoints Sinus | ||
@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) | ||
d = (x - y) | ||
sind = sinpi.(d) | ||
val = sum(abs2, sind ./ s.r) | ||
gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) | ||
val, Δ -> begin | ||
((r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx) | ||
end | ||
end | ||
|
||
@adjoint function loggamma(x) | ||
first(logabsgamma(x)) , Δ -> (Δ .* polygamma(0, x), ) | ||
end | ||
|
||
@adjoint function kappa(κ::MaternKernel, d::Real) | ||
ν = first(κ.ν) | ||
val, grad = pullback(_matern, ν, d) | ||
return ((iszero(d) ? one(d) : val), | ||
Δ -> begin | ||
∇ = grad(Δ) | ||
return ((ν = [∇[1]],), iszero(d) ? zero(d) : ∇[2]) | ||
end) | ||
end | ||
Comment on lines
+66
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? The definition of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is needed for the edge case where There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does |
||
|
||
@adjoint function ColVecs(X::AbstractMatrix) | ||
back(Δ::NamedTuple) = (Δ.X,) | ||
back(Δ::AbstractMatrix) = (Δ,) | ||
|
@@ -22,10 +91,10 @@ end | |
return RowVecs(X), back | ||
end | ||
|
||
# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) | ||
# d = evaluate(s, x, y) | ||
# s = sum(sin.(π*(x-y))) | ||
# d, Δ -> begin | ||
# (Sinus(Δ ./ s.r), 2Δ .* cos.(x - y) * d, -2Δ .* cos.(x - y) * d) | ||
# end | ||
# end | ||
@adjoint function Base.map(t::Transform, X::ColVecs) | ||
pullback(_map, t, X) | ||
end | ||
|
||
@adjoint function Base.map(t::Transform, X::RowVecs) | ||
pullback(_map, t, X) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
[deps] | ||
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" | ||
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[compat] | ||
Distances = "0.9" | ||
FiniteDifferences = "0.10" | ||
Flux = "0.10" | ||
ForwardDiff = "0.10" | ||
Kronecker = "0.4" | ||
PDMats = "0.9" | ||
ReverseDiff = "1.2" | ||
SpecialFunctions = "0.10" | ||
Zygote = "0.4" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't specialize on the types
Ta
andTb
(see https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing-1).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Distances do this for determining the type of the allocated matrix in
pairwise!
, if this is not defined it defaults toFloat64
and therefore breaks AD for ForwardDiff and others :https://github.com/JuliaStats/Distances.jl/blob/f69f7888c92458ae671c893d079ecf5fc8d8accd/src/generic.jl#L35
https://github.com/JuliaStats/Distances.jl/blob/f69f7888c92458ae671c893d079ecf5fc8d8accd/src/generic.jl#L203
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, we don't want to restrict it to
Float64
only. My comment was just that Julia actually won't compile specialized versions of this function for different input types since you usedTa::Type
instead of::Type{Ta}
.