Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add i-PI server calculator and update i-PI driver to AB v0.4 #31

Merged
merged 10 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 188 additions & 22 deletions src/ipi/ipi_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,18 @@ using Unitful
using UnitfulAtomic

export run_driver
export IPIcalculator

const hdrlen = 12
const pos_type = typeof(SVector(1., 1., 1.)u"bohr") #should be bohr

# Define context untits for better conversions
const bohr = Unitful.ContextUnits(u"bohr", u"Å")
const hartree = Unitful.ContextUnits(u"hartree", u"eV")


const pos_type = typeof( zero( SVector{3, Float64} ) * bohr )
const force_el_type = typeof( zero( SVector{3, Float64} ) * (hartree/bohr) )
const virial_type = typeof( zero( SMatrix{3, 3, Float64} ) * hartree )

function sendmsg(comm, message; nbytes=hdrlen)
@info "Sending message" message
Expand Down Expand Up @@ -51,42 +60,94 @@ function recvinit(comm)
)
end

function send_init(comm)
@info "Sending INIT"
write(comm, one(Int32))
str = "ok"
write(comm, sizeof(str))
write(comm, str)
return true
end


function recvposdata(comm)
raw_cell = read(comm, 9*8)
raw_icell = read(comm, 9*8) # drop this (inverce cell)
raw_cell = read(comm, 9*sizeof(Float64))
raw_icell = read(comm, 9*sizeof(Float64)) # drop this (inverce cell)
natoms = read(comm, Int32)
raw_pos = read(comm, 8*3*natoms)
raw_pos = read(comm, sizeof(Float64)*3*natoms)
data_cell = reinterpret(pos_type, raw_cell)
data_pos = reinterpret(pos_type, raw_pos)
@info "Position data recieved"
return (;
:cell => Vector(data_cell),
:cell => Tuple(Vector(data_cell)),
:positions => Vector(data_pos) # clean type a little bit
)
end

function send_pos_data(comm, sys)
@info "Sending position data"
box_tmp = vcat(bounding_box(sys)...)
box = (Float64 ∘ ustrip).(u"bohr", box_tmp)
write(comm, box)
write(comm, zeros(3,3) ) #inverse cell that is to be dropped
l = length(sys)
write(comm, Int32(l) )
pos = map( 1:l ) do i
SVector{3, Float64}(ustrip.(u"bohr", position(sys, i)))
end
write(comm, pos)
@info "Position data sent"
return true
end


function sendforce(comm, e::Number, forces::AbstractVector, virial::AbstractMatrix)
etype = (eltype ∘ eltype)(forces)
f_tmp = reinterpret(reshape, etype, forces)
sendmsg(comm, "FORCEREADY")
write(comm, ustrip(u"hartree", e) )
write(comm, (Float64 ∘ ustrip)(u"hartree", e) )
write(comm, Int32( length(forces) ) )
write(comm, ustrip.(u"hartree/bohr", f_tmp) )
write(comm, ustrip.(u"hartree", virial) )
write(comm, (Float64 ∘ ustrip).(u"hartree/bohr", f_tmp) )
write(comm, (Float64 ∘ ustrip).(u"hartree", virial) )

# Send single byte at end to make sure we are alive
write(comm, one(Int32) )
write(comm, zero(UInt8) )
end


function recv_force(comm)
sendmsg(comm, "GETFORCE")
mess = recvmsg(comm)
if mess == "FORCEREADY"
@info "Recieving forces"
e = read(comm, Float64)
n = read(comm, Int32)
f_raw = read(comm, sizeof(Float64)*3*n)
v_raw = read(comm, sizeof(Float64)*9)

# Reading end message that is dropped
i = read(comm, Int32)
_ = read(comm, i)

f = reinterpret(force_el_type, f_raw) |> Vector
v = reinterpret(virial_type, v_raw)[1]
return (
energy = e * hartree,
forces = f,
virial = v
)
else
error("Expected \"FORCEREADY\", but received \"$mess\"")
end
end

"""
run_driver(address, calculator, init_structure; port=31415, unixsocket=false )
run_driver(address, calculator, init_structure; port=31415, unixsocket=false, basename="/tmp/ipi_" )

Connect I-PI driver to server at given `address`. Use kword `port` (default 31415) to
specify port. If kword `unixsocket` is true, `address` is understood to be the name of the socket
and `port` option is ignored.
specify port. If kword `unixsocket` is true, `basename*address` is understood to be the name of the socket
and `port` option is ignored.

You need to give initial structure as I-PI protocol does not transfer atom symbols.
This means that, if you want to change the number of atoms or their symbols, you need
Expand All @@ -95,22 +156,21 @@ to lauch a new driver.
Calculator events are logged at info level by default. If you do not want them to be logged,
change logging status for IPI module.
"""
function run_driver(address, calc, init_structure; port=31415, unixsocket=false )
function run_driver(address, calc, init_structure; port=31415, unixsocket=false, basename="/tmp/ipi_" )
if unixsocket
comm = connect("/tmp/ipi_"*address)
comm = connect(basename*address)
else
comm = connect(address, port)
end
has_init = true # we have init structure as an input
has_data = false
data = nothing

masses = atomic_mass(init_structure)
symbols = atomic_symbol(init_structure)
anumbers = atomic_number(init_structure)
positions = position(init_structure)
cell = bounding_box(init_structure)
pbc = boundary_conditions(init_structure)
pbc = periodicity(init_structure)
masses = mass(init_structure, :)
atom_species = species(init_structure, :)
positions = position(init_structure, :)
box = bounding_box(init_structure)


while true
Expand All @@ -132,9 +192,9 @@ function run_driver(address, calc, init_structure; port=31415, unixsocket=false
elseif header == "POSDATA"
pos = recvposdata(comm)
positions = pos[:positions]
cell = pos[:cell]
@assert length(symbols) == length(positions) "received amount of position data does no match the atomic symbol data"
system = FastSystem(cell, pbc, positions, symbols, anumbers, masses)
box = pos[:cell]
@assert length(atom_species) == length(positions) "received amount of position data does no match the atomic symbol data"
system = FastSystem(box, pbc, positions, atom_species, masses)
data = AtomsCalculators.energy_forces_virial(system, calc)
has_data = true
elseif header == "GETFORCE"
Expand All @@ -151,3 +211,109 @@ function run_driver(address, calc, init_structure; port=31415, unixsocket=false

end
end



## Server specific part

"""
IPIcalculator(address=ip"127.0.0.1"; port=31415, unixsocket=false, basename="/tmp/ipi_" )

Creates i-PI https://ipi-code.org/ server that works as an AtomsCalculators compatible calculators
once i-PI driver has been connected.

By default the calculator will log protocol calls to the client. If you want to suppress these,
you need to change the logging level of IPI module.

# Args
- `address=ip"127.0.0.1"` - server address, if `unixsocket=true` is considered as unixsocket address

# Kwargs
- `basename="/tmp/ipi_"` - prefixed to address if `unixsocket=true`
- `port=31415` - network port the server is using
- `unixsocket=false` - use unixsocket for the connection
"""
mutable struct IPIcalculator{TS, TC}
server::TS
sock::TC
function IPIcalculator(address=ip"127.0.0.1"; port=31415, unixsocket=false, basename="/tmp/ipi_" )
server, sock = start_ipi_server(address; port=port, unixsocket=unixsocket, basename=basename)
new{typeof(server), typeof(sock)}(server, sock)
end
end

function start_ipi_server(address; port=31415, unixsocket=false, basename="/tmp/ipi_", tries=5 )
@info "Starting i-PI server"
server = nothing
if unixsocket
server = listen(basename*address)
else
server = listen(address, port)
end
get_connection(server; tries=tries) # returns server, socket
end

function get_connection(server; tries=5)
sock = accept(server)
i = 1
while isopen(sock) || i < tries
sendmsg(sock, "STATUS")
mess = recvmsg(sock)
if mess == "NEEDINIT"
sendmsg(sock, "INIT")
send_init(sock)
continue
elseif mess == "READY"
return server, sock
else
i += 1
close(sock)
sock = accept(server)
end
end
error("Could not form a connection to a working i-PI driver")
end


function AtomsCalculators.energy_forces_virial(sys, ipi::IPIcalculator; kwargs...)
if ! isopen(ipi.sock)
@info "reconnecting to i-PI driver"
_, sock = get_connection(ipi.server)
ipi.sock = sock
end
sendmsg(ipi.sock, "POSDATA")
send_pos_data(ipi.sock, sys)
sendmsg(ipi.sock, "STATUS")
mess = recvmsg(ipi.sock)
if mess == "HAVEDATA"
return recv_force(ipi.sock)
else
error("Expected \"HAVEDATA\", but received \"$mess\"")
end
end


AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy(sys, ipi::IPIcalculator; kwargs...)
tmp = AtomsCalculators.energy_forces_virial(sys, ipi)
return tmp.energy
end

AtomsCalculators.@generate_interface function AtomsCalculators.forces(sys, ipi::IPIcalculator; kwargs...)
tmp = AtomsCalculators.energy_forces_virial(sys, ipi)
return tmp.forces
end

AtomsCalculators.@generate_interface function AtomsCalculators.virial(sys, ipi::IPIcalculator; kwargs...)
tmp = AtomsCalculators.energy_forces_virial(sys, ipi)
return tmp.virial
end

function AtomsCalculators.energy_forces(sys, ipi::IPIcalculator; kwargs...)
tmp = AtomsCalculators.energy_forces_virial(sys, ipi)
return tmp
end



AtomsCalculators.energy_unit(::IPIcalculator) = hartree
AtomsCalculators.length_unit(::IPIcalculator) = bohr
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ using Unitful
@testset "AtomsCalculatorsUtilities.jl" begin
@testset "Calculators" begin include("test_calculators.jl") end
@testset "PairPotentials" begin include("test_pairpotential.jl") end
@testset "i-PI driver and server" begin include("test_ipi.jl") end
# @testset "FD Tests" begin include("test_fdtests.jl") end
end
52 changes: 52 additions & 0 deletions test/test_ipi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using AtomsBase
using AtomsCalculators
using AtomsCalculators.Testing
using AtomsCalculatorsUtilities.IPI
using AtomsCalculatorsUtilities.PairPotentials
using Unitful
using Base.Threads


hydrogen = isolated_system([
:H => [0.1, 0, 0.]u"Å",
:H => [0, 0, 1.]u"Å",
:H => [4., 0, 0.]u"Å",
:H => [4., 1., 0.]u"Å"
])

box = (
[10.0, 0., 0.]u"Å",
[0.0, 10., 0.]u"Å",
[0.0, 0., 10.]u"Å",
)

pbc = (false, false, false)

hydrogen = FlexibleSystem(hydrogen[:], bounding_box=box, periodicity=pbc)


V = SimplePairPotential(
x-> (x-0.9)^2-1,
1,
1,
2.0u"Å"
)


ipi_future = @spawn IPIcalculator(port=33415)
sleep(1) # we need to yeald to start the server

ipi_driver = @spawn run_driver("127.0.0.1", V, hydrogen; port=33415)
sleep(1) # we need to yeald to connect to the server

calc = fetch(ipi_future)

##

test_energy_forces_virial(hydrogen, calc)

@test AtomsCalculators.potential_energy(hydrogen, V) ≈ AtomsCalculators.potential_energy(hydrogen, calc)
f_v = AtomsCalculators.forces(hydrogen, V)
f_ipi = AtomsCalculators.forces(hydrogen, calc)
@test all( isapprox.(f_v, f_ipi) )
@test AtomsCalculators.virial(hydrogen, V) ≈ AtomsCalculators.virial(hydrogen, calc)