Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Sep 30, 2021
1 parent d4db55d commit e8e62ff
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions test/forward_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using DiffEqBase
using Test, NeuralPDE
println("Starting Soon!")
using SciMLBase
import ModelingToolkit: Interval, infimum, supremum
import ModelingToolkit: Interval

@testset "ODE" begin
@parameters x
Expand All @@ -19,17 +19,17 @@ import ModelingToolkit: Interval, infimum, supremum
domains = [x Interval(0.0,1.0)]
chain = FastChain((x,p) -> x.^2)

chain([1],Float32[])
chain([1],Float64[])
strategy_ = NeuralPDE.GridTraining(0.1)
discretization = NeuralPDE.PhysicsInformedNN(chain,strategy_)
discretization = NeuralPDE.PhysicsInformedNN(chain,strategy_;init_params = Float64[])
@named pde_system = PDESystem(eq,bcs,domains,[x],[u(x)])
prob = NeuralPDE.discretize(pde_system,discretization)

train_data =prob.f.f.loss_function.pde_loss_function.pde_loss_functions.contents[1].train_set
inner_loss =prob.f.f.loss_function.pde_loss_function.pde_loss_functions.contents[1].loss_function

dudx(x) = @. 2*x
@test inner_loss(train_data, Float32[]) dudx(train_data) rtol = 0.001
@test inner_loss(train_data, Float64[]) dudx(train_data) rtol = 1e-8
end

@testset "derivatives" begin
Expand Down

0 comments on commit e8e62ff

Please sign in to comment.