Skip to content

Commit

Permalink
Remove literal_getfield usage
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Aug 3, 2024
1 parent 784dbad commit a169fcc
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 113 deletions.
57 changes: 22 additions & 35 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,25 +71,19 @@ end
function build_lgssm(f::LTISDE, x::AbstractVector, Σys::AbstractVector)
m = get_mean(f)
k = get_kernel(f)
s = Zygote.literal_getfield(f, Val(:storage))
As, as, Qs, emission_proj, x0 = lgssm_components(m, k, x, s)
As, as, Qs, emission_proj, x0 = lgssm_components(m, k, x, f.storage)
return LGSSM(
GaussMarkovModel(Forward(), As, as, Qs, x0), build_emissions(emission_proj, Σys),
)
end

function build_lgssm(ft::FiniteLTISDE)
f = Zygote.literal_getfield(ft, Val(:f))
x = Zygote.literal_getfield(ft, Val(:x))
Σys = noise_var_to_time_form(x, Zygote.literal_getfield(ft, Val(:Σy)))
return build_lgssm(f, x, Σys)
end
build_lgssm(ft::FiniteLTISDE) = build_lgssm(ft.f, ft.x, noise_var_to_time_form(ft.x, ft.Σy))

get_mean(f::LTISDE) = get_mean(Zygote.literal_getfield(f, Val(:f)))
get_mean(f::GP) = Zygote.literal_getfield(f, Val(:mean))
get_mean(f::LTISDE) = get_mean(f.f)
get_mean(f::GP) = f.mean

get_kernel(f::LTISDE) = get_kernel(Zygote.literal_getfield(f, Val(:f)))
get_kernel(f::GP) = Zygote.literal_getfield(f, Val(:kernel))
get_kernel(f::LTISDE) = get_kernel(f.f)
get_kernel(f::GP) = f.kernel

function build_emissions(
(Hs, hs)::Tuple{AbstractVector, AbstractVector}, Σs::AbstractVector,
Expand Down Expand Up @@ -332,20 +326,18 @@ end
# Scaled

function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real}
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
F, q, H = to_sde(_k, storage)
σ = sqrt(convert(eltype(storage), only(σ²)))
F, q, H = to_sde(k.kernel, storage)
σ = sqrt(convert(eltype(storage), only(k.σ²)))
return F, σ^2 * q, σ * H
end

stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage)
function stationary_distribution(k::ScaledKernel, storage::StorageType)
return stationary_distribution(k.kernel, storage)
end

function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType)
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
As, as, Qs, emission_proj, x0 = lgssm_components(_k, ts, storage_type)
σ = sqrt(convert(eltype(storage_type), only(σ²)))
As, as, Qs, emission_proj, x0 = lgssm_components(k.kernel, ts, storage_type)
σ = sqrt(convert(eltype(storage_type), only(k.σ²)))
return As, as, Qs, _scale_emission_projections(emission_proj, σ), x0
end

Expand All @@ -360,34 +352,29 @@ end
# Stretched

function to_sde(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType)
_k = Zygote.literal_getfield(k, Val(:kernel))
s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s))
F, q, H = to_sde(_k, storage)
return F * only(s), q, H
F, q, H = to_sde(k.kernel, storage)
return F * only(k.transform.s), q, H
end

stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage)
function stationary_distribution(
k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType
)
return stationary_distribution(k.kernel, storage)
end

function lgssm_components(
k::TransformedKernel{<:Kernel, <:ScaleTransform},
ts::AbstractVector,
storage_type::StorageType,
)
_k = Zygote.literal_getfield(k, Val(:kernel))
s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s))
return lgssm_components(_k, apply_stretch(s[1], ts), storage_type)
return lgssm_components(k.kernel, apply_stretch(k.transform.s[1], ts), storage_type)
end

apply_stretch(a, ts::AbstractVector{<:Real}) = a * ts

apply_stretch(a, ts::StepRangeLen) = a * ts

function apply_stretch(a, ts::RegularSpacing)
t0 = Zygote.literal_getfield(ts, Val(:t0))
Δt = Zygote.literal_getfield(ts, Val(:Δt))
N = Zygote.literal_getfield(ts, Val(:N))
return RegularSpacing(a * t0, a * Δt, N)
end
apply_stretch(a, ts::RegularSpacing) = RegularSpacing(a * ts.t0, a * ts.Δt, ts.N)

# Product

Expand Down
19 changes: 6 additions & 13 deletions src/gp/posterior_lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,12 @@ function AbstractGPs.marginals(fx::FinitePosteriorLTISDE)
model_post = replace_observation_noise_cov(posterior(model, ys), σ²s_pr_full)
return destructure(x, map(marginals, marginals(model_post))[pr_indices])
else
f = Zygote.literal_getfield(fx, Val(:f))
prior = Zygote.literal_getfield(f, Val(:prior))
x = Zygote.literal_getfield(fx, Val(:x))
data = Zygote.literal_getfield(f, Val(:data))
Σy = Zygote.literal_getfield(data, Val(:Σy))
Σy_diag = Zygote.literal_getfield(Σy, Val(:diag))
y = Zygote.literal_getfield(data, Val(:y))

Σy_new = Zygote.literal_getfield(fx, Val(:Σy))

model = build_lgssm(AbstractGPs.FiniteGP(prior, x, Σy))
Σys_new = noise_var_to_time_form(x, Σy_new)
ys = observations_to_time_form(x, y)
f = fx.f
x = fx.x
data = f.data
model = build_lgssm(AbstractGPs.FiniteGP(f.prior, x, data.Σy))
Σys_new = noise_var_to_time_form(x, fx.Σy)
ys = observations_to_time_form(x, data.y)
model_post = replace_observation_noise_cov(posterior(model, ys), Σys_new)
return destructure(x, map(marginals, marginals(model_post)))
end
Expand Down
2 changes: 1 addition & 1 deletion src/models/gauss_markov_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function is_of_storage_type(model::GaussMarkovModel, s::StorageType)
return is_of_storage_type((model.As, model.as, model.Qs, model.x0), s)
end

x0(model::GaussMarkovModel) = Zygote.literal_getfield(model, Val(:x0))
x0(model::GaussMarkovModel) = model.x0

function get_adjoint_storage(x::GaussMarkovModel, n::Int, Δx::Tangent{T,<:NamedTuple{(:A, :a, :Q)}}) where {T}
return (
Expand Down
29 changes: 9 additions & 20 deletions src/models/lgssm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,9 @@ struct LGSSM{Ttransitions<:GaussMarkovModel, Temissions<:StructArray} <: Abstrac
emissions::Temissions
end

@inline function transitions(model::LGSSM)
return Zygote.literal_getfield(model, Val(:transitions))
end
@inline transitions(model::LGSSM) = model.transitions

@inline function emissions(model::LGSSM)
return Zygote.literal_getfield(model, Val(:emissions))
end
@inline emissions(model::LGSSM) = model.emissions

@inline ordering(model::LGSSM) = ordering(transitions(model))
ChainRulesCore.@non_differentiable ordering(model)
Expand Down Expand Up @@ -58,17 +54,11 @@ struct ElementOfLGSSM{Tordering, Ttransition, Temission}
emission::Temission
end

@inline function ordering(x::ElementOfLGSSM)
return Zygote.literal_getfield(x, Val(:ordering))
end
@inline ordering(x::ElementOfLGSSM) = x.ordering

@inline function transition_dynamics(x::ElementOfLGSSM)
return Zygote.literal_getfield(x, Val(:transition))
end
@inline transition_dynamics(x::ElementOfLGSSM) = x.transition

@inline function emission_dynamics(x::ElementOfLGSSM)
return Zygote.literal_getfield(x, Val(:emission))
end
@inline emission_dynamics(x::ElementOfLGSSM) = x.emission

@inline function Base.getindex(model::LGSSM, n::Int)
return ElementOfLGSSM(ordering(model), model.transitions[n], model.emissions[n])
Expand Down Expand Up @@ -206,11 +196,10 @@ end
function posterior(prior::LGSSM, y::AbstractVector)
_check_inputs(prior, y)
new_trans, xf = _a_bit_of_posterior(prior, y)
A = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:A)), new_trans)
a = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:a)), new_trans)
Q = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:Q)), new_trans)
ems = Zygote.literal_getfield(prior, Val(:emissions))
return LGSSM(GaussMarkovModel(reverse(ordering(prior)), A, a, Q, xf), ems)
A = zygote_friendly_map(x -> x.A, new_trans)
a = zygote_friendly_map(x -> x.a, new_trans)
Q = zygote_friendly_map(x -> x.Q, new_trans)
return LGSSM(GaussMarkovModel(reverse(ordering(prior)), A, a, Q, xf), prior.emissions)
end

function _check_inputs(prior, y)
Expand Down
47 changes: 11 additions & 36 deletions src/models/linear_gaussian_conditionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,9 @@ dim_out(f::SmallOutputLGC) = size(f.A, 1)

dim_in(f::SmallOutputLGC) = size(f.A, 2)

noise_cov(f::SmallOutputLGC) = Zygote.literal_getfield(f, Val(:Q))
noise_cov(f::SmallOutputLGC) = f.Q

function get_fields(f::SmallOutputLGC)
A = Zygote.literal_getfield(f, Val(:A))
a = Zygote.literal_getfield(f, Val(:a))
Q = Zygote.literal_getfield(f, Val(:Q))
return A, a, Q
end
get_fields(f::SmallOutputLGC) = (f.A, f.a, f.Q)

function posterior_and_lml(x::Gaussian, f::SmallOutputLGC, y::AbstractVector{<:Real})
m, P = get_fields(x)
Expand Down Expand Up @@ -191,14 +186,9 @@ dim_out(f::LargeOutputLGC) = size(f.A, 1)

dim_in(f::LargeOutputLGC) = size(f.A, 2)

noise_cov(f::LargeOutputLGC) = Zygote.literal_getfield(f, Val(:Q))
noise_cov(f::LargeOutputLGC) = f.Q

function get_fields(f::LargeOutputLGC)
A = Zygote.literal_getfield(f, Val(:A))
a = Zygote.literal_getfield(f, Val(:a))
Q = Zygote.literal_getfield(f, Val(:Q))
return A, a, Q
end
get_fields(f::LargeOutputLGC) = (f.A, f.a, f.Q)

function posterior_and_lml(x::Gaussian, f::LargeOutputLGC, y::AbstractVector{<:Real})
m, _P = get_fields(x)
Expand Down Expand Up @@ -258,18 +248,12 @@ dim_out(f::ScalarOutputLGC) = 1

dim_in(f::ScalarOutputLGC) = size(f.A, 2)

function get_fields(f::ScalarOutputLGC)
A = Zygote.literal_getfield(f, Val(:A))
a = Zygote.literal_getfield(f, Val(:a))
Q = Zygote.literal_getfield(f, Val(:Q))
return A, a, Q
end
get_fields(f::ScalarOutputLGC) = (f.A, f.a, f.Q)

noise_cov(f::ScalarOutputLGC) = Zygote.literal_getfield(f, Val(:Q))
noise_cov(f::ScalarOutputLGC) = f.Q

function conditional_rand::Real, f::ScalarOutputLGC, x::AbstractVector)
A, a, Q = get_fields(f)
return (A * x + a) + sqrt(Q) * ε
return (f.A * x + f.a) + sqrt(f.Q) * ε
end

ε_randn(rng::AbstractRNG, f::ScalarOutputLGC) = randn(rng, eltype(f))
Expand Down Expand Up @@ -323,16 +307,9 @@ dim_out(f::BottleneckLGC) = dim_out(f.fan_out)

dim_in(f::BottleneckLGC) = size(f.H, 2)

noise_cov(f::BottleneckLGC) = noise_cov(Zygote.literal_getfield(f, Val(:fan_out)))
noise_cov(f::BottleneckLGC) = noise_cov(f.fan_out)

function get_fields(f::BottleneckLGC)
H = Zygote.literal_getfield(f, Val(:H))
h = Zygote.literal_getfield(f, Val(:h))
fan_out = Zygote.literal_getfield(f, Val(:fan_out))
return H, h, fan_out
end

fan_out(f::BottleneckLGC) = Zygote.literal_getfield(f, Val(:fan_out))
get_fields(f::BottleneckLGC) = (f.H, f.h, f.fan_out)

function conditional_rand::AbstractVector{<:Real}, f::BottleneckLGC, x::AbstractVector)
H, h, fan_out = get_fields(f)
Expand All @@ -348,12 +325,10 @@ function _project(x::Gaussian, f::BottleneckLGC)
return Gaussian(H * m + h, H * P * H' + ident_eps(x))
end

function predict(x::Gaussian, f::BottleneckLGC)
return predict(_project(x, f), fan_out(f))
end
predict(x::Gaussian, f::BottleneckLGC) = predict(_project(x, f), f.fan_out)

function predict_marginals(x::Gaussian, f::BottleneckLGC)
return predict_marginals(_project(x, f), fan_out(f))
return predict_marginals(_project(x, f), f.fan_out)
end

function posterior_and_lml(x::Gaussian, f::BottleneckLGC, y::AbstractVector)
Expand Down
2 changes: 1 addition & 1 deletion src/models/missings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function _fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) wh
end

function fill_in_missings::Diagonal, y::AbstractVector{<:Union{Missing, <:Real}})
Σ_diag_filled, y_filled = fill_in_missings(Zygote.literal_getfield(Σ, Val(:diag)), y)
Σ_diag_filled, y_filled = fill_in_missings(Σ.diag, y)
return Diagonal(Σ_diag_filled), y_filled
end

Expand Down
4 changes: 2 additions & 2 deletions src/space_time/rectilinear_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ struct RectilinearGrid{
xr::Txr
end

get_space(x::RectilinearGrid) = Zygote.literal_getfield(x, Val(:xl))
get_space(x::RectilinearGrid) = x.xl

get_times(x::RectilinearGrid) = Zygote.literal_getfield(x, Val(:xr))
get_times(x::RectilinearGrid) = x.xr

Base.size(X::RectilinearGrid) = (length(X.xl) * length(X.xr),)

Expand Down
4 changes: 2 additions & 2 deletions src/space_time/regular_in_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ struct RegularInTime{
vs::Tvs
end

get_space(x::RegularInTime) = Zygote.literal_getfield(x, Val(:vs))
get_space(x::RegularInTime) = x.vs

get_times(x::RegularInTime) = Zygote.literal_getfield(x, Val(:ts))
get_times(x::RegularInTime) = x.ts

Base.size(x::RegularInTime) = (sum(length, x.vs), )

Expand Down
4 changes: 2 additions & 2 deletions src/util/gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ end

dim(x::Gaussian) = length(x.m)

AbstractGPs.mean(x::Gaussian) = Zygote.literal_getfield(x, Val(:m))
AbstractGPs.mean(x::Gaussian) = x.m

AbstractGPs.cov(x::Gaussian) = Zygote.literal_getfield(x, Val(:P))
AbstractGPs.cov(x::Gaussian) = x.P

AbstractGPs.var(x::Gaussian{<:AbstractVector}) = diag(cov(x))

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ ENV["TESTING"] = "TRUE"
# Select any of this to test a particular aspect.
# To test everything, simply set GROUP to "all"
# ENV["GROUP"] = "test gp"
const GROUP = get(ENV, "GROUP", "test")
const GROUP = get(ENV, "GROUP", "all")
OUTER_GROUP = first(split(GROUP, ' '))

const TEST_TYPE_INFER = false # Test type stability over the tests
Expand Down

0 comments on commit a169fcc

Please sign in to comment.