Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

[WIP] Testing out a vmap implementation #141

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Testing out a vmap implementation
  • Loading branch information
avik-pal committed May 4, 2024
commit ecbad27a721fb0218aaeca6767ac63350581dc1f
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.8.0"
version = "1.8.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
fx_norm_new = NONLINEARSOLVE_DEFAULT_NORM(fx)^nexp

while k < maxiters
Bool(fx_norm_new ≤ (f_bar + η - γ * α_p^2 * fx_norm)) && break
all(fx_norm_new ≤ (f_bar + η - γ * α_p^2 * fx_norm)) && break

α_tp = α_p^2 * fx_norm / (fx_norm_new + (T(2) * α_p - T(1)) * fx_norm)
@bb @. x_cache = x - α_m * d

fx = __eval_f(prob, fx, x_cache)
fx_norm_new = NONLINEARSOLVE_DEFAULT_NORM(fx)^nexp

Bool(fx_norm_new ≤ (f_bar + η - γ * α_m^2 * fx_norm)) && break
all(fx_norm_new ≤ (f_bar + η - γ * α_m^2 * fx_norm)) && break

α_tm = α_m^2 * fx_norm / (fx_norm_new + (T(2) * α_m - T(1)) * fx_norm)
α_p = clamp(α_tp, τ_min * α_p, τ_max * α_p)
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,19 +321,19 @@ function check_termination(tc_cache, fx, x, xo, prob, alg)
end
function check_termination(tc_cache, fx, x, xo, prob, alg,
::AbstractNonlinearTerminationMode)
tc_cache(fx, x, xo) &&
all(tc_cache(fx, x, xo)) &&
return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
return nothing
end
function check_termination(tc_cache, fx, x, xo, prob, alg,
::AbstractSafeNonlinearTerminationMode)
tc_cache(fx, x, xo) &&
all(tc_cache(fx, x, xo)) &&
return build_solution(prob, alg, x, fx; retcode = tc_cache.retcode)
return nothing
end
function check_termination(tc_cache, fx, x, xo, prob, alg,
::AbstractSafeBestNonlinearTerminationMode)
if tc_cache(fx, x, xo)
if all(tc_cache(fx, x, xo))
if isinplace(prob)
prob.f(fx, x, prob.p)
else
Expand Down