Skip to content

Commit

Permalink
Merge pull request #244 from SciML/add_remake
Browse files Browse the repository at this point in the history
Function to pass training parameters to the next solver
  • Loading branch information
ChrisRackauckas authored Feb 6, 2021
2 parents 4687d63 + 18cf8be commit 72c91a5
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 35 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -42,12 +43,13 @@ Distributions = "0.23, 0.24"
Flux = "0.10.1, 0.11"
ForwardDiff = "0.10"
GalacticOptim = "0.3, 0.4"
ModelingToolkit = "4.3.4, 5"
ModelingToolkit = "5"
Optim = "1.0"
Quadrature = "1.5"
QuasiMonteCarlo = "0.2.1"
Reexport = "0.2, 1.0"
RuntimeGeneratedFunctions = "0.4, 0.5"
SciMLBase = "1.6"
StochasticDiffEq = "6.13"
Tracker = "0.2"
Zygote = "0.5, 0.6"
Expand Down
6 changes: 4 additions & 2 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ using GalacticOptim
using Quadrature
using QuasiMonteCarlo
using RuntimeGeneratedFunctions
using SciMLBase
import Tracker, Optim
import ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives

import SciMLBase: @add_kwonly
abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
"""
TerminalPDEProblem(g, f, μ, σ, x0, tspan)
Expand Down Expand Up @@ -136,6 +137,7 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
build_loss_function, get_loss_function,
generate_training_sets, get_bc_varibles, get_bounds
get_phi, get_numeric_derivative,
build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize
build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize,
remake

end # module
65 changes: 35 additions & 30 deletions src/pinns_pde_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,55 @@ Arguments:
* `derivative`: method that calculates the derivative.
"""
abstract type AbstractPINN{isinplace} <: SciMLBase.SciMLProblem end

struct PhysicsInformedNN{C,T,P,PH,DER,K}
struct PhysicsInformedNN{C,T,P,PH,DER,K} <: AbstractPINN{isinplace}
chain::C
strategy::T
init_params::P
phi::PH
derivative::DER
kwargs::K
end

function PhysicsInformedNN(chain,
strategy;
init_params = nothing,
phi = nothing,
derivative = nothing,
kwargs...)
if init_params == nothing
if chain isa AbstractArray
initθ = DiffEqFlux.initial_params.(chain)
@add_kwonly function PhysicsInformedNN{iip}(chain,
strategy,
init_params = nothing,
phi = nothing,
derivative = nothing;
kwargs...) where iip
if init_params == nothing
if chain isa AbstractArray
initθ = DiffEqFlux.initial_params.(chain)
else
initθ = DiffEqFlux.initial_params(chain)
end

else
initθ = DiffEqFlux.initial_params(chain)
initθ = init_params
end

else
initθ = init_params
end
if phi == nothing
if chain isa AbstractArray
_phi = get_phi.(chain)
else
_phi = get_phi(chain)
end
else
_phi = phi
end

if phi == nothing
if chain isa AbstractArray
_phi = get_phi.(chain)
if derivative == nothing
_derivative = get_numeric_derivative()
else
_phi = get_phi(chain)
_derivative = derivative
end
else
_phi = phi
new{typeof(chain),typeof(strategy),typeof(initθ),typeof(_phi),typeof(_derivative),typeof(kwargs)}(chain,strategy,initθ,_phi,_derivative, kwargs)
end
end
PhysicsInformedNN(chain,strategy,args...;kwargs...) = PhysicsInformedNN{true}(chain,strategy,args...;kwargs...)

if derivative == nothing
_derivative = get_numeric_derivative()
else
_derivative = derivative
end
SciMLBase.isinplace(prob::PhysicsInformedNN{iip}) where iip = iip

PhysicsInformedNN(chain,strategy,initθ,_phi,_derivative, kwargs)
end

abstract type TrainingStrategies end

Expand Down Expand Up @@ -613,7 +617,7 @@ function get_loss_function(loss_functions, bounds, strategy::QuadratureTraining;
return loss
end
function symbolic_discretize(pde_system::PDESystem, discretization::PhysicsInformedNN)
eqs = pde_system.eq
eqs = pde_system.eqs
bcs = pde_system.bcs

domains = pde_system.domain
Expand All @@ -638,9 +642,10 @@ function symbolic_discretize(pde_system::PDESystem, discretization::PhysicsInfor
bc_indvars = bc_indvar) for (bc,bc_indvar) in zip(bcs,bc_indvars)]
symbolic_pde_loss_function,symbolic_bc_loss_functions
end

# Convert a PDE problem into an OptimizationProblem
function DiffEqBase.discretize(pde_system::PDESystem, discretization::PhysicsInformedNN)
eqs = pde_system.eq
eqs = pde_system.eqs
bcs = pde_system.bcs

domains = pde_system.domain
Expand Down
9 changes: 7 additions & 2 deletions test/NNPDE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ discretization = NeuralPDE.PhysicsInformedNN(chain,
strategy;
init_params = nothing,
phi = nothing,
autodiff=false,
derivative = nothing,
)

pde_system = PDESystem(eq,bcs,domains,[θ],[u])
prob = NeuralPDE.discretize(pde_system,discretization)

res = GalacticOptim.solve(prob, ADAM(0.1); cb = cb, maxiters=1000)
prob2 = remake(prob,u0=res.minimizer)
res = GalacticOptim.solve(prob2, ADAM(0.001); cb = cb, maxiters=1000)
phi = discretization.phi

analytic_sol_func(t) = exp(-(t^2)/2)/(1+t+t^3) + t^2
Expand Down Expand Up @@ -417,7 +418,11 @@ discretization = NeuralPDE.PhysicsInformedNN(chain,
pde_system = PDESystem(eq,bcs,domains,[x],[p])
prob = NeuralPDE.discretize(pde_system,discretization)

res = GalacticOptim.solve(prob,Optim.BFGS(); cb = cb, maxiters=1000)
res = GalacticOptim.solve(prob,Optim.BFGS(); cb = cb, maxiters=800)
discretization2 = remake(discretization; strategy = NeuralPDE.GridTraining(dx/5), init_params =res.minimizer)
prob = NeuralPDE.discretize(pde_system,discretization2)

res = GalacticOptim.solve(prob,Optim.BFGS();cb=cb,maxiters=100)
phi = discretization.phi

analytic_sol_func(x) = 28*exp((1/(2*^2))*(2*α*x^2 - β*x^4))
Expand Down

0 comments on commit 72c91a5

Please sign in to comment.