Skip to content

Commit

Permalink
rename differentials (#162)
Browse files Browse the repository at this point in the history
* ignore dev

* bump version, compat

* rename DoesNotExist

* rename Composite to Tangent

* rename Zero to ZeroTangent

* update docs

* docs manifest

* Update docs/Project.toml

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
  • Loading branch information
mzgubic and oxinabox authored May 26, 2021
1 parent b96f892 commit 266d6fa
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 75 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*.jl.*.cov
*.jl.mem
/Manifest.toml
dev/

# Docs:
docs/build/
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.6"
version = "0.12.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ Richardson = "708f8203-808e-40c0-ba2d-98a6953ed40d"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
ChainRulesCore = "0.9"
ChainRulesCore = "0.9.44"
Richardson = "1.2"
StaticArrays = "0.12, 1.0"
julia = "1"
Expand Down
97 changes: 79 additions & 18 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,39 +1,59 @@
# This file is machine-generated - editing it directly is not advised

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"]
git-tree-sha1 = "15081c431bb25848ad9b0d172a65794f3a3e197a"
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.24"
version = "0.9.44"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.30.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1"
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.3"
version = "0.8.4"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "a4875e0763112d6d017126f3944f4133abb342ae"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.25.5"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[FiniteDifferences]]
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
path = ".."
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.11.5"
version = "0.12.7"

[[IOCapture]]
deps = ["Logging"]
Expand All @@ -51,10 +71,22 @@ git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.1"

[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[LibGit2]]
deps = ["Printf"]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -69,30 +101,35 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[MuladdMacro]]
git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68"
uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
version = "0.2.2"
[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714"
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.15"
version = "1.1.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
Expand All @@ -111,6 +148,10 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

Expand All @@ -120,16 +161,24 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49"
git-tree-sha1 = "c635017268fd51ed944ec429bcc4ad010bcea900"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.0.1"
version = "1.2.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[UUIDs]]
Expand All @@ -138,3 +187,15 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
14 changes: 7 additions & 7 deletions src/difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ If `(y - x) / ε` is defined, then this operation is equivalent to doing that. F
where these operations aren't defined, `difference` can still be defined without commiting
type piracy while `-` and `/` cannot.
"""
difference(::Real, ::T, ::T) where {T<:Symbol} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:AbstractChar} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:AbstractString} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:Integer} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:Symbol} = NoTangent()
difference(::Real, ::T, ::T) where {T<:AbstractChar} = NoTangent()
difference(::Real, ::T, ::T) where {T<:AbstractString} = NoTangent()
difference(::Real, ::T, ::T) where {T<:Integer} = NoTangent()

difference::Real, y::T, x::T) where {T<:Number} = (y - x) / ε

difference::Real, y::T, x::T) where {T<:StridedArray} = difference.(ε, y, x)

function difference::Real, y::T, x::T) where {T<:Tuple}
return Composite{T}(difference.(ε, y, x)...)
return Tangent{T}(difference.(ε, y, x)...)
end

function difference::Real, ys::T, xs::T) where {T<:NamedTuple}
return Composite{T}(; map((y, x) -> difference(ε, y, x), ys, xs)...)
return Tangent{T}(; map((y, x) -> difference(ε, y, x), ys, xs)...)
end

function difference::Real, y::T, x::T) where {T}
Expand All @@ -38,7 +38,7 @@ function difference(ε::Real, y::T, x::T) where {T}
tangents = map(field_names) do field_name
difference(ε, getfield(y, field_name), getfield(x, field_name))
end
return Composite{T}(; NamedTuple{field_names}(tangents)...)
return Tangent{T}(; NamedTuple{field_names}(tangents)...)
else
return NO_FIELDS
end
Expand Down
18 changes: 9 additions & 9 deletions src/rand_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ Returns a randomly generated tangent vector appropriate for the primal value `x`
"""
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)

rand_tangent(rng::AbstractRNG, x::Symbol) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::AbstractChar) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::AbstractString) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent()
rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent()
rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent()

rand_tangent(rng::AbstractRNG, x::Integer) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()

rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)

Expand All @@ -20,11 +20,11 @@ rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng))
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)

function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
return Composite{T}(rand_tangent.(Ref(rng), x)...)
return Tangent{T}(rand_tangent.(Ref(rng), x)...)
end

function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
return Composite{T}(; map(x -> rand_tangent(rng, x), xs)...)
return Tangent{T}(; map(x -> rand_tangent(rng, x), xs)...)
end

function rand_tangent(rng::AbstractRNG, x::T) where {T}
Expand All @@ -37,11 +37,11 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T}
tangents = map(field_names) do field_name
rand_tangent(rng, getfield(x, field_name))
end
if all(tangent isa DoesNotExist for tangent in tangents)
if all(tangent isa NoTangent for tangent in tangents)
# if none of my fields can be perturbed then I can't be perturbed
return DoesNotExist()
return NoTangent()
else
Composite{T}(; NamedTuple{field_names}(tangents)...)
Tangent{T}(; NamedTuple{field_names}(tangents)...)
end
else
return NO_FIELDS
Expand Down
8 changes: 4 additions & 4 deletions src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,15 @@ end


# ChainRulesCore Differentials
function FiniteDifferences.to_vec(x::Composite{P}) where{P}
function FiniteDifferences.to_vec(x::Tangent{P}) where{P}
x_canon = canonicalize(x) # to be safe, fill in every field and put in primal order.
x_inner = ChainRulesCore.backing(x_canon)
x_vec, back_inner = FiniteDifferences.to_vec(x_inner)
function Composite_from_vec(y_vec)
function Tangent_from_vec(y_vec)
y_back = back_inner(y_vec)
return Composite{P, typeof(y_back)}(y_back)
return Tangent{P, typeof(y_back)}(y_back)
end
return x_vec, Composite_from_vec
return x_vec, Tangent_from_vec
end

function FiniteDifferences.to_vec(x::AbstractZero)
Expand Down
Loading

0 comments on commit 266d6fa

Please sign in to comment.