Skip to content

Commit

Permalink
update convert, adapt
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Oct 9, 2024
1 parent 2df6e34 commit e74955d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 18 deletions.
6 changes: 2 additions & 4 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,7 @@ NN OUTPUT AT t,θ ~ phi(t,θ).
function (f::LogTargetDensity{C, S})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, S}
θ = vector_to_parameters(θ, f.init_params)
θ_ = ComponentArrays.getdata(θ)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
t_ = convert.(eltypeθ, adapt(typeθ, t'))
y, st = f.chain(t_, θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand All @@ -365,8 +364,7 @@ end
function (f::LogTargetDensity{C, S})(t::Number,
θ) where {C <: Lux.AbstractExplicitLayer, S}
θ = vector_to_parameters(θ, f.init_params)
θ_ = ComponentArrays.getdata(θ)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
t_ = convert.(eltypeθ, adapt(typeθ, [t]))
y, st = f.chain(t_, θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand Down
12 changes: 4 additions & 8 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ end

function (f::ODEPhi{C, T, U})(t::Number,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
θ_ = ComponentArrays.getdata.depvar)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
t_ = convert.(eltypeθ, adapt(typeθ, [t]))
y, st = f.chain(t_, θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand All @@ -129,17 +128,15 @@ end
function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
# Batch via data as row vectors
θ_ = ComponentArrays.getdata.depvar)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
t_ = convert.(eltypeθ, adapt(typeθ, t'))
y, st = f.chain(t_, θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
end

function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
θ_ = ComponentArrays.getdata.depvar)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
t_ = convert.(eltypeθ, adapt(typeθ, [t]))
y, st = f.chain(t_, θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand All @@ -149,8 +146,7 @@ end
function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, T, U}
# Batch via data as row vectors
θ_ = ComponentArrays.getdata.depvar)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
t_ = convert.(eltypeθ, adapt(typeθ, t'))
y, st = f.chain(t_, θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand Down
9 changes: 3 additions & 6 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -505,17 +505,15 @@ mutable struct Phi{C, S}
end

function (f::Phi{<:Lux.AbstractExplicitLayer})(x::Number, θ)
θ_ = ComponentArrays.getdata(θ)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
x_ = convert.(eltypeθ, adapt(typeθ, [x]))
y, st = f.f(x_, θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
y
end

function (f::Phi{<:Lux.AbstractExplicitLayer})(x::AbstractArray, θ)
θ_ = ComponentArrays.getdata(θ)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
x_ = convert.(eltypeθ, adapt(typeθ, x))
y, st = f.f(x_, θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand All @@ -528,8 +526,7 @@ end

# the method to calculate the derivative
function numeric_derivative(phi, u, x, εs, order, θ)
θ_ = ComponentArrays.getdata(θ)
eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))

ε = εs[order]
_epsilon = inv(first(ε[ε .!= zero(ε)]))
Expand Down

0 comments on commit e74955d

Please sign in to comment.