Skip to content

Commit

Permalink
add observable value to TriggerLogger
Browse files Browse the repository at this point in the history
record value of trigger.eval at current time step
  • Loading branch information
joannajzou committed May 21, 2024
1 parent e25d60d commit 1df4946
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/loggers/triggerlogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,44 @@ export TriggerLogger
A logger which holds a record of evaluations of the trigger function for active learning.
# Arguments
- `trigger::ActiveLearningTrigger` : trigger function.
- `n_steps::Int` : time step interval at which the trigger function is evaluated.
- `history::Vector{T}` : record of the trigger function evaluation.
- `trigger::ActiveLearningTrigger` : trigger function.
- `observable::T` : value of the trigger function of type `T`.
- `n_steps::Int` : time step interval at which the trigger function is evaluated.
- `history::Vector{T}` : record of the trigger function evaluation.
"""
mutable struct TriggerLogger{A, T}
trigger::A
observable::T
n_steps::Int
history::Vector{T}
end


function TriggerLogger(trigger::ActiveLearningTrigger, T::DataType, n_steps::Integer)
return TriggerLogger{typeof(trigger), T}(trigger, n_steps, T[])
return TriggerLogger{typeof(trigger), T}(trigger, T[], n_steps, T[])
end
TriggerLogger(trigger::ActiveLearningTrigger, n_steps::Integer) = TriggerLogger(trigger, Float64, n_steps)


Base.values(logger::TriggerLogger) = logger.history


function log_property!(logger::TriggerLogger, s::System, neighbors=nothing,
step_n::Integer=0; n_threads::Integer=Threads.nthreads(), kwargs...)

obs = logger.trigger.eval(s)
logger.observable = obs

if (step_n % logger.n_steps) == 0
if typeof(logger.trigger) <: Union{UpperThreshold, LowerThreshold, MaxVol}
obs = logger.trigger.eval(; kwargs...)
if typeof(logger.trigger) <: Union{Bool, TimeInterval}
return
else
push!(logger.history, obs)
end
end
end


Base.values(logger::TriggerLogger) = logger.history


function Base.show(io::IO, fl::TriggerLogger)
print(io, "TriggerLogger{", eltype(fl.trigger), ", ", eltype(eltype(values(fl))), "} with n_steps ",
fl.n_steps, ", ", length(values(fl)), " frames recorded for ",
Expand Down

0 comments on commit 1df4946

Please sign in to comment.