Skip to content

Commit

Permalink
restructure ALRoutine
Browse files Browse the repository at this point in the history
  • Loading branch information
joannajzou committed May 23, 2024
1 parent 855e3f8 commit e23537b
Showing 1 changed file with 43 additions and 75 deletions.
118 changes: 43 additions & 75 deletions src/activelearning/activelearning.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import PotentialLearning: SubsetSelector, LearningProblem
export
ActiveLearnRoutine,
active_learn!,
Expand All @@ -20,7 +21,6 @@ Struct containing all parameters required to trigger retraining during simulatio
- `error_hist :: Dict` : dictionary of error metrics to record over simulation
# Keyword Arguments
- `burnin :: Integer = 0` : number of burn-in simulation steps before triggering training
- `update_func :: Function = update_sys` : function for updating the training and simulation systems
- `train_func :: Function = train_potential_e!` : function defining training objective for `mlip`
- `train_steps :: Vector{<:Integer} = [1]` : vector of simulation steps at which training is triggered
Expand All @@ -30,45 +30,45 @@ Struct containing all parameters required to trigger retraining during simulatio
"""

mutable struct ActiveLearnRoutine
# NOTE(review): this listing comes from a rendered diff, so two field sets appear
# together below (`ref` is even declared twice). The keyword constructor in this file
# builds the struct from exactly seven values: ref, mlip, trainset, triggers, ss, lp
# and a Dict of active-learning data -- presumably the fields after the blank line
# plus `ref` and `mlip`; confirm against the actual file.

# reference interaction; the untyped declaration drops the Union annotation
ref :: Union{GeneralInteraction, PairwiseInteraction}
ref # :: Union{GeneralInteraction, PairwiseInteraction}
# machine-learned interaction; its `params` seed the parameter history
mlip :: MLInteraction
# training systems handed to `update_func` / `train_func` during active learning
sys_train :: Vector{<:System}
# NOTE(review): role of `eval_int` is not visible in this listing -- confirm
eval_int :: Integrator
# single trigger queried through `trigger_activated` in `active_learn!`
trigger :: ActiveLearningTrigger
# dictionary of error metrics recorded over the simulation ("rmse_e", "rmse_f" are appended to)
error_hist :: Dict
# number of burn-in simulation steps before training may be triggered
burnin :: Integer
# function for updating the training and simulation systems
update_func :: Function
# function defining the training objective for `mlip`
train_func :: Function
# simulation steps at which training was triggered
train_steps :: Vector{<:Integer}
# parameter vectors recorded after each (re)training
param_hist :: Vector{<:Vector}

# training systems; the constructor passes them to the descriptor routines
trainset :: Vector{<:System}
# collection of triggers (intended element type noted in the trailing comment)
triggers :: Tuple # <: ActiveLearningTrigger
# subset selector type imported from PotentialLearning
ss :: SubsetSelector
# learning problem type imported from PotentialLearning
lp :: LearningProblem
# active-learning bookkeeping; the constructor fills it with the keys
# "train_steps", "param_hist" and "trigger_eval"
data :: Union{Nothing, Dict}
end

function ActiveLearnRoutine(
# NOTE(review): this span comes from a rendered diff and interleaves two versions of
# the constructor: a positional signature (ref, mlip, sys_train, eval_int, trigger,
# error_hist; burnin, update_func, train_func) and a keyword-only signature
# (ref, mlip, trainset, triggers, ss, lp). Only one `end` closes the span, so it is
# not valid Julia as shown; the comments below describe each part separately.
ref::Union{GeneralInteraction, PairwiseInteraction},
mlip::MLInteraction,
sys_train::Vector{<:System},
eval_int::Integrator,
trigger::ActiveLearningTrigger,
error_hist::Dict;
burnin::Integer = 0,
update_func::Function = update_sys,
train_func::Function = train_potential_e!,
kwargs...
# keyword-only signature; `ref` is accepted without a type annotation here
function ActiveLearnRoutine(;
ref,
mlip::MLInteraction,
trainset::Vector{<:System},
triggers::Tuple,
ss::SubsetSelector,
lp::LearningProblem,
kwargs...
)

# histories start with step 1 and the current parameters of `mlip`
train_steps = [1]
param_hist = [mlip.params]

# wrap the user-supplied functions so the extra `kwargs` are forwarded on every call
update(sim, sys, ref) = update_func(sim, sys, ref; kwargs...)
train(sys, sys_train, ref) = train_func(sys, sys_train, ref; kwargs...)
# positional form: 11 values in the order ref, mlip, sys_train, eval_int, trigger,
# error_hist, burnin, update, train, train_steps, param_hist
return ActiveLearnRoutine(ref, mlip, sys_train, eval_int, trigger, error_hist, burnin, update, train, train_steps, param_hist)
# populate trainset with descriptors
# NOTE(review): the return values are discarded, so these calls presumably mutate
# `trainset` in place -- verify against their definitions
compute_local_descriptors(trainset, mlip)
compute_force_descriptors(trainset, mlip)

# initialize AL data
# NOTE(review): `[]` is an untyped (Any) vector; `kwargs` is not referenced on this path
aldata = Dict(
"train_steps" => [1],
"param_hist" => [mlip.params],
"trigger_eval" => [],
)

# keyword form: 7 values in the order ref, mlip, trainset, triggers, ss, lp, aldata
return ActiveLearnRoutine(ref, mlip, trainset, triggers, ss, lp, aldata)
end






"""
active_learn!(sys::Union{System, Vector{<:System}}, sim::Simulator, n_steps::Integer, sys_train::Vector{<:System}, ref::Union{GeneralInteraction, PairwiseInteraction}, trigger::Union{Bool, ActiveLearningTrigger}; n_threads::Integer=Threads.nthreads(), burnin::Integer=100, run_loggers=true)
active_learn!(sys::Union{System, Vector{<:System}}, sim::Simulator, n_steps::Integer, sys_train::Vector{<:System}, ref::Union{GeneralInteraction, PairwiseInteraction}, trigger::Union{Bool, ActiveLearningTrigger}; n_threads::Integer=Threads.nthreads(), run_loggers=true)
Performs online active learning by molecular dynamics simulation defined in `sim`, using the retraining criterion defined in `trigger`.
Expand All @@ -80,7 +80,6 @@ Performs online active learning by molecular dynamics simulation defined in `sim
- `ref::Union{GeneralInteraction, PairwiseInteraction}` : interaction for computing reference values
- `trigger::Union{Bool, ActiveLearningTrigger}` : trigger which instantiates retraining
- `n_threads::Integer=Threads.nthreads()` : number of threads
- `burnin::Integer=100` : number of burn-in steps before active learning routine
- `run_loggers=true` : Bool for running loggers
"""
Expand All @@ -93,15 +92,11 @@ function active_learn!(sys::System,
run_loggers=true,
rng=Random.GLOBAL_RNG,
)
sys.coords = wrap_coords.(sys.coords, (sys.boundary,))
!iszero(sim.remove_CM_motion) && remove_CM_motion!(sys)
neighbors = find_neighbors(sys, sys.neighbor_finder; n_threads=n_threads)
run_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
sys, neighbors = initialize_sim!(sys; n_threads=n_threads, run_loggers=run_loggers)
compute_error_metrics!(al)

ct = 0

for step_n in 1:n_steps
ct += 1
neighbors, ksd = simulation_step!(sys,
sim,
step_n,
Expand All @@ -115,14 +110,13 @@ function active_learn!(sys::System,
sim.sys_fix = reduce(vcat, [sim.sys_fix[2:end], sys_new])

# online active learning
if trigger_activated(al.trigger; step_n=step_n, ksd=ksd) && ct >= al.burnin
if trigger_activated(al.trigger; step_n=step_n, ksd=ksd)
al.sys_train = al.update_func(sim, sys, al.sys_train)
al.train_func(sys, al.sys_train, al.ref) # retrain potential
al.mlip = sys.general_inters[1]
append!(al.train_steps, step_n)
append!(al.param_hist, [sys.general_inters[1].params])
compute_error_metrics!(al)
ct = 0 # reset counter
end
end
return al
Expand Down Expand Up @@ -152,9 +146,7 @@ function active_learn!(ens::Vector{<:System},
run_loggers!(sys, nb, 0, run_loggers; n_threads=n_threads)
end

ct = 0
for step_n in 1:n_steps
ct += 1
nb_ens, ksd, bwd[step_n] = simulation_step!(ens,
nb_ens,
sim,
Expand All @@ -165,15 +157,14 @@ function active_learn!(ens::Vector{<:System},
)

# online active learning
if trigger_activated(al.trigger; step_n=step_n, ens_old=al.sys_train, ens_new=ens, ksd=ksd) && ct >= al.burnin
if trigger_activated(al.trigger; step_n=step_n, ens_old=al.sys_train, ens_new=ens, ksd=ksd)
println("train on step $step_n")
al.sys_train = al.update_func(sim, ens, al.sys_train)
al.train_func(ens, al.sys_train, al.ref) # retrain potential
al.mlip = ens[1].general_inters[1]
append!(al.train_steps, step_n)
append!(al.param_hist, [ens[1].general_inters[1].params])
compute_error_metrics!(al)
ct = 0 # reset counter
end
end

Expand All @@ -195,9 +186,7 @@ function active_learn!(sys::System,
run_loggers!(sys, neighbors, 0, run_loggers; n_threads=n_threads)
compute_error_metrics!(al)

ct = 0
for step_n in 1:n_steps
ct += 1
neighbors = simulation_step!(sys,
sim,
step_n,
Expand All @@ -208,15 +197,14 @@ function active_learn!(sys::System,
)

# online active learning
if trigger_activated(al.trigger; ens_old=al.sys_train, sys_new=sys, step_n=step_n) && ct >= al.burnin
if trigger_activated(al.trigger; ens_old=al.sys_train, sys_new=sys, step_n=step_n)
println("train on step $step_n")
al.sys_train = al.update_func(sim, sys, al.sys_train)
al.train_func(sys, al.sys_train, al.ref) # retrain potential
al.mlip = sys.general_inters[1]
append!(al.train_steps, step_n)
append!(al.param_hist, [sys.general_inters[1].params])
compute_error_metrics!(al)
ct = 0 # reset counter
end
end
return al
Expand Down Expand Up @@ -284,34 +272,7 @@ function update_sys(sim::StochasticSVGD,
end


"""
    remove_loggers(sys::System)

Redefine the system `sys` without loggers: return a new `System` built from the
`atoms`, `coords`, `boundary` and `general_inters` of `sys`. No other field is passed
to the constructor.
"""
function remove_loggers(sys::System)
    # only these fields are forwarded; loggers are deliberately left out
    kept = (
        atoms = sys.atoms,
        coords = sys.coords,
        boundary = sys.boundary,
        general_inters = sys.general_inters,
    )
    return System(; kept...)
end

"""
    remove_loggers(ens::Vector{<:System})

Redefine every system of the ensemble `ens` without loggers. Return a new vector with
one rebuilt `System` per element of `ens`, in the same order; `ens` itself is not
modified.
"""
function remove_loggers(ens::Vector{<:System})
    # delegate to the single-system method so the list of forwarded fields
    # (atoms, coords, boundary, general_inters) is defined in one place only
    return [remove_loggers(sys) for sys in ens]
end

"""
compute_error_metrics!(al::ActiveLearnRoutine)
Expand All @@ -326,3 +287,10 @@ function compute_error_metrics!(al::ActiveLearnRoutine)
append!(al.error_hist["rmse_e"], r_e)
append!(al.error_hist["rmse_f"], r_f)
end



include("distributions.jl")
include("ensembles.jl")
include("kernels.jl")
include("training.jl")

0 comments on commit e23537b

Please sign in to comment.