Skip to content

Commit

Permalink
basic calculators
Browse files Browse the repository at this point in the history
  • Loading branch information
tjjarvinen committed Mar 29, 2024
1 parent fc01ac7 commit 0781381
Show file tree
Hide file tree
Showing 7 changed files with 433 additions and 2 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ authors = ["Teemu Järvinen <teemu.j.jarvinen@gmail.com> and contributors"]
version = "0.1.0-DEV"

[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"

[compat]
AtomsBase = "0.3"
AtomsCalculators = "0.1"
Folds = "0.2.10"
julia = "1.9"

[extras]
Expand Down
8 changes: 7 additions & 1 deletion src/AtomsUtilityCalculators.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
module AtomsUtilityCalculators

# Write your package code here.
using AtomsCalculators
using AtomsBase


include("combination_calculator.jl")
include("reporting_calculator.jl")
include("subsystem_calculator.jl")

end
122 changes: 122 additions & 0 deletions src/combination_calculator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
using Folds

export CombinationCalculator


"""
generate_keywords
This function is called when `CombinationCalculator` is used.
Default implementation will only pass keywords forward.
The call type is AtomsBase system first then all calculators and kwargs.
This will allow you to extend based on calculator type.
# Example
```julia
function AtomsUtilityCalculators.generate_keywords(sys, pp1::PairPotential, pp2::PairPotential; kwargs...)
if cutoff_radius(pp1) ≈ cutoff_radius(pp2)
nlist = PairList(sys, cutoff_radius(pp1))
return (; :nlist => nlist, kwargs...)
else
return kwargs
end
end
```
will check that PairPotentials have same cutoff radius.
Then calculates pairlist and passes it forward as a keyword.
"""
generate_keywords(sys, calculators...; kwargs...) = kwargs


"""
CombinationCalculator{N}
You can combine several calculators to one calculator with this.
Giving keyword argument `executor=SequentialEx()` toggles on multithdeaded execution
of calculators. Using `executor=DistributedEx()` executes calculators using multiprocessing.
Other use case is editing keywords that are passed on the calculators.
E.g. you can generate new keyword argument that is then passed to all calculators.
This allows you to share e.g. a pairlist between calculators.
To control what keywords are passed you need to extend `generate_keywords` function.
# Fields
- calculators::NTuple{N,Any} : NTuple that holds calculators
- executor::Any : Transducers executor used to execute calculation - default SequentialEx
- keywords::Function : function used to generate keywords for calculators
# Creation
```julia
CombinationCalculator( calc1, calc2, ...; executor=SequentialEx())
```
"""
mutable struct CombinationCalculator{N, T} # Mutable struct so that calculators can mutate themself
calculators::NTuple{N,Any}
executor::Any
keywords::Function
function CombinationCalculator(calculators...; executor=SequentialEx(), keyword_generator=nothing)
kgen = something(keyword_generator, generate_keywords)
new{length(calculators), typeof(kgen)}(calculators, executor, kgen)
end
end

function Base.show(io::IO, ::MIME"text/plain", calc::CombinationCalculator)
print(io, "CombinationCalculator - ", length(calc) , " calculators")
end

Base.length(cc::CombinationCalculator) = length(cc.calculators)

Base.getindex(cc::CombinationCalculator, i) = cc.calculators[i]
Base.lastindex(cc::CombinationCalculator) = length(cc)
Base.firstindex(cc::CombinationCalculator) = 1


AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy(sys, calc::CombinationCalculator; kwargs...)
new_kwargs = calc.keywords(sys, calc.calculators...; kwargs...)
return Folds.sum( calc.calculators ) do c
AtomsCalculators.potential_energy(sys, c; new_kwargs...)
end
end

# We don't use AtomsCalculators.@generate_interface here
# as we want special version for forces!
function AtomsCalculators.forces(sys, calc::CombinationCalculator; kwargs...)
new_kwargs = calc.keywords(sys, calc.calculators...; kwargs...)
return Folds.sum( calc.calculators ) do c
AtomsCalculators.forces(sys, c; new_kwargs...)
end
end


function AtomsCalculators.calculate( ::AtomsCalculators.Forces, sys, calc::CombinationCalculator; kwargs...)
f = AtomsCalculators.forces(sys, calc; kwargs...)
return (; :forces => f)
end


function AtomsCalculators.forces!(f, sys, calc::CombinationCalculator; kwargs...)
new_kwargs = calc.keywords(sys, calc.calculators...; kwargs...)

# Nonallocating forces is only truly nonallocating when sequential
foreach( calc.calculators ) do cal
AtomsCalculators.forces!(f, sys, cal; new_kwargs...)
end
return f
end


AtomsCalculators.@generate_interface function AtomsCalculators.virial(sys, calc::CombinationCalculator; kwargs...)
new_kwargs = calc.keywords(sys, calc.calculators...; kwargs...)
return Folds.sum( calc.calculators ) do c
AtomsCalculators.virial(sys, c; new_kwargs...)
end
end
142 changes: 142 additions & 0 deletions src/reporting_calculator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@

export ReportingCalculator

"""
generate_message(sys, calculator, calc_result; kwargs...) = calc_result
This is the default function that is called when `ReportingCalculator` collects
a message. Extending this allows you to control what is reported.
This function is ment to allow setting of global stetting. If you want to
set reporting function for an individual case, give `ReportingCalculator` keyword
`message_function=my_report` where `my_report` is function that returns your message.
If function returns `nothing` the message is ignored. You can use this to control
when message is sent.
"""
generate_message(sys, calculator, calc_result; kwargs...) = calc_result


"""
ReportingCalculator{T, TC, TF}
`ReportingCalculator` collects information during calculation
and sent it to a `Channel` that can be read.
# Fields
- `calculator::T` : caculator used in calculations
- `channel::Channel{TC}` : `Channel` where message is put
- `message::TF` : function that generates the message
# Creation
```julia
rcalc = ReportingCalculator(calculator, Channel(32))
rcalc = ReportingCalculator(calculator, Channel(32); message_function=my_message_function)
```
When `message_function` is omitted, `generate_message` function is used. See it for more details on how to control generated messages.
You can access the channel by calling calculator directly with `fetch` or `take!`.
"""
mutable struct ReportingCalculator{T}
calculator::T
channel::AbstractChannel
message::Function
function ReportingCalculator(
calc,
channel::AbstractChannel=Channel();
message_function=nothing
)
message = something(message_function, generate_message)
new{typeof(calc)}(calc, channel, message)
end
end


function Base.show(io::IO, ::MIME"text/plain", calc::ReportingCalculator)
print(io, "ReportingCalculator")
end

Base.fetch(rcalc::ReportingCalculator) = fetch(rcalc.channel)
Base.take!(rcalc::ReportingCalculator) = take!(rcalc.channel)

AtomsCalculators.zero_forces(sys, calc::ReportingCalculator) = AtomsCalculators.zero_forces(sys, calc.calculator)
AtomsCalculators.promote_force_type(sys, calc::ReportingCalculator) = AtomsCalculators.promote_force_type(sys, calc.calculator)


function AtomsCalculators.potential_energy(
sys,
calc::ReportingCalculator;
kwargs...
)
e = AtomsCalculators.potential_energy(sys, calc.calculator; kwargs...)
mess = calc.message(sys, calc.calculator, e; kwargs...)
if ! isnothing(mess)
put!(calc.channel, mess)
end
return e
end


function AtomsCalculators.virial(
sys,
calc::ReportingCalculator;
kwargs...
)
v = AtomsCalculators.virial(sys, calc.calculator; kwargs...)
mess = calc.message(sys, calc.calculator, v; kwargs...)
if ! isnothing(mess)
put!(calc.channel, mess)
end
return v
end


function AtomsCalculators.forces(
sys,
calc::ReportingCalculator;
kwargs...
)
f = AtomsCalculators.forces(sys, calc.calculator; kwargs...)
mess = calc.message(sys, calc.calculator, f; kwargs...)
if ! isnothing(mess)
put!(calc.channel, mess)
end
return f
end


function AtomsCalculators.forces!(
f,
sys,
calc::ReportingCalculator;
kwargs...
)
fout = AtomsCalculators.forces!(f, sys, calc.calculator; kwargs...)
mess = calc.message(sys, calc.calculator, fout; kwargs...)
if ! isnothing(mess)
put!(calc.channel, mess)
end
return fout
end


function AtomsCalculators.calculate(
calc_method::Union{
AtomsCalculators.Energy,
AtomsCalculators.Forces,
AtomsCalculators.Virial
},
sys,
calc::ReportingCalculator;
kwargs...
)
tmp = AtomsCalculators.calculate(calc_method, sys, calc.calculator; kwargs...)
mess = calc.message(sys, calc.calculator, tmp; kwargs...)
if ! isnothing(mess)
put!(calc.channel, mess)
end
return tmp
end
69 changes: 69 additions & 0 deletions src/subsystem_calculator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@

export SubSystemCalculator

"""
SubSystemCalculator{T, TC}
Submits subsystem to given calculator.
The purpose of this calculator is that you can split system to smaller
system that can then be calculated with e.g. with different methods.
One possible use case here is QM/MM calculations where you can split
QM system out.
The structrure is mutable to allow mutable calculators.
# Fields
- `calculator::T` : calculator which is used for the subsystem calculation
- `subsys::TC` : definition of subsystem like array of indices - has to be iterable
"""
mutable struct SubSystemCalculator{T, TC} # Mutable struct so that calculator can mutate inself
calculator::T
subsys::TC
function SubSystemCalculator(calc, subsys)
@assert applicable(first, subsys) "subsys is not iterable"
new{typeof(calc), typeof(subsys)}(calc, subsys)
end
end

function Base.show(io::IO, ::MIME"text/plain", calc::SubSystemCalculator)
print(io, "SubSystemCalculator - subsystem size = ", length(calc.subsys))
end

AtomsCalculators.zero_forces(sys, calc::SubSystemCalculator) = AtomsCalculators.zero_forces(sys, calc.calculator)
AtomsCalculators.promote_force_type(sys, calc::SubSystemCalculator) = AtomsCalculators.promote_force_type(sys, calc.calculator)


function _generate_subsys(sys, calc::SubSystemCalculator)
@assert length(sys) >= length(calc.subsys)
sub_atoms = [ sys[i] for i in calc.subsys ]
sub_sys = FlexibleSystem(
sub_atoms;
[ k => sys[k] for k in keys(sys) ]...
)
return sub_sys
end


AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy(sys, calc::SubSystemCalculator; kwargs...)
sub_sys = _generate_subsys(sys, calc)
return AtomsCalculators.potential_energy(sub_sys, calc.calculator; kwargs...)
end


AtomsCalculators.@generate_interface function AtomsCalculators.forces!(f, sys, calc::SubSystemCalculator; kwargs...)
@assert length(f) == length(sys)
sub_sys = _generate_subsys(sys, calc)
tmp_f = AtomsCalculators.zero_forces(sub_sys, calc)
AtomsCalculators.forces!(tmp_f, sub_sys, calc.calculator; kwargs...)
#TODO this wont work for GPU Arrays
for (i, val) in zip(calc.subsys, tmp_f)
f[i] += val
end
return f
end

AtomsCalculators.@generate_interface function AtomsCalculators.virial(sys, calc::SubSystemCalculator; kwargs...)
sub_sys = _generate_subsys(sys, calc)
return AtomsCalculators.virial(sub_sys, calc.calculator; kwargs...)
end
6 changes: 6 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Loading

0 comments on commit 0781381

Please sign in to comment.