Skip to content

Commit 937a9fc

Browse files
committed
Fix tests, add copy
1 parent 970a090 commit 937a9fc

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/DeepBSDE.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ function DiffEqBase.solve(
153153
function neural_sde(init_cond)
154154
output_func(sol,i) = ((sol[end][1:end-1], sol[end][end]),false)
155155
function prob_func(prob,i,repeat)
156-
SDEProblem(prob.f , prob.g , init_cond, prob.tspan , prob.p ,noise_rate_prototype = prob.noise_rate_prototype)
156+
SDEProblem(prob.f , prob.g , init_cond, prob.tspan , prob.p ,noise_rate_prototype = copy(prob.noise_rate_prototype))
157157
end
158158
ensembleprob = EnsembleProblem(prob, output_func = output_func, prob_func = prob_func)
159159
sim = solve(ensembleprob,sdealg,ensemblealg, dt=dt, save_everystep = false;sensealg=DiffEqSensitivity.TrackerAdjoint(),trajectories=trajectories)

test/DeepBSDE.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ end
8080
σᵀ∇u = Flux.Chain(Dense(d+1,hls,relu),
8181
Dense(hls,hls,relu),
8282
Dense(hls,d))
83+
opt = ADAM(0.005)
8384
pdealg = DeepBSDE(u0, σᵀ∇u, opt=opt)
8485

8586
sol = solve(prob,
@@ -93,7 +94,7 @@ end
9394

9495
u_analytical(x,t) = sum(x.^2) .+ d*t
9596
analytical_sol = u_analytical(x0, tspan[end])
96-
error_l2 = rel_error_l2(res.us,analytical_sol)
97+
error_l2 = rel_error_l2(sol.us,analytical_sol)
9798
println("error_l2 = ", error_l2, "\n")
9899
@test error_l2 < 1.0
99100
end

0 commit comments

Comments
 (0)