Skip to content

Commit c263b2b

Browse files
committed
[WIP] 🐛 TrajectoryData: Add pbc
1 parent d50cf17 commit c263b2b

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

src/aiida/orm/nodes/data/array/trajectory.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
###########################################################################
99
"""AiiDA class to deal with crystal structure trajectories."""
1010

11+
from __future__ import annotations
12+
1113
import collections.abc
1214
from typing import List
1315

@@ -33,7 +35,7 @@ def __init__(self, structurelist=None, **kwargs):
3335
if structurelist is not None:
3436
self.set_structurelist(structurelist)
3537

36-
def _internal_validate(self, stepids, cells, symbols, positions, times, velocities):
38+
def _internal_validate(self, stepids, cells, symbols, positions, times, velocities, pbc):
3739
"""Internal function to validate the type and shape of the arrays. See
3840
the documentation of py:meth:`.set_trajectory` for a description of the
3941
valid shape and type of the parameters.
@@ -82,8 +84,12 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti
8284
'have shape (s,n,3), '
8385
'with s=number of steps and n=number of symbols'
8486
)
87+
if not (isinstance(pbc, (list, tuple)) and len(pbc) == 3 and all(isinstance(val, bool) for val in pbc)):
88+
raise ValueError('`pbc` must be a list/tuple of length three with boolean values.')
8589

86-
def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=None, velocities=None):
90+
def set_trajectory(
91+
self, symbols, positions, stepids=None, cells=None, times=None, velocities=None, pbc: None | list | tuple = None
92+
):
8793
r"""Store the whole trajectory, after checking that types and dimensions
8894
are correct.
8995
@@ -131,14 +137,19 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non
131137
:param velocities: if specified, must be a float array with the same
132138
dimensions of the ``positions`` array.
133139
The array contains the velocities in the atoms.
140+
:param pbc: periodic boundary conditions of the structure. Should be a list of
141+
length three with booleans indicating if the structure is periodic in that
142+
direction. The same periodic boundary conditions are set for each step.
134143
135144
.. todo :: Choose suitable units for velocities
136145
"""
137146
import numpy
138147

139-
self._internal_validate(stepids, cells, symbols, positions, times, velocities)
148+
self._internal_validate(stepids, cells, symbols, positions, times, velocities, pbc)
140149
# set symbols as attribute for easier querying
141150
self.base.attributes.set('symbols', list(symbols))
151+
pbc = pbc or [True, True, True]
152+
self.base.attributes.set('pbc', list(pbc))
142153
self.set_array('positions', positions)
143154
if stepids is not None: # use input stepids
144155
self.set_array('steps', stepids)
@@ -189,7 +200,12 @@ def set_structurelist(self, structurelist):
189200
raise ValueError('Symbol lists have to be the same for all of the supplied structures')
190201
symbols = list(symbols_first)
191202
positions = numpy.array([[list(s.position) for s in x.sites] for x in structurelist])
192-
self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions)
203+
pbc_set = {structure.pbc for structure in structurelist}
204+
if len(pbc_set) == 1:
205+
pbc = pbc_set.pop()
206+
else:
207+
raise ValueError(f'All structures should have the same `pbc`, found: {pbc_set}')
208+
self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, pbc=pbc)
193209

194210
def _validate(self):
195211
"""Verify that the required arrays are present and that their type and
@@ -264,6 +280,11 @@ def symbols(self) -> List[str]:
264280
"""
265281
return self.base.attributes.get('symbols')
266282

283+
@property
284+
def pbc(self) -> list[bool]:
285+
"""Return the list of periodic boundary conditions."""
286+
return self.base.attributes.get('pbc')
287+
267288
def get_positions(self):
268289
"""Return the array of positions, if it has already been set.
269290
@@ -338,7 +359,7 @@ def get_step_data(self, index):
338359
cell = cells[index, :, :]
339360
else:
340361
cell = None
341-
return (self.get_stepids()[index], time, cell, self.symbols, self.get_positions()[index, :, :], vel)
362+
return (self.get_stepids()[index], time, cell, self.symbols, self.get_positions()[index, :, :], vel, self.pbc)
342363

343364
def get_step_structure(self, index, custom_kinds=None):
344365
"""Return an AiiDA :py:class:`aiida.orm.nodes.data.structure.StructureData` node
@@ -364,7 +385,7 @@ def get_step_structure(self, index, custom_kinds=None):
364385
from aiida.orm.nodes.data.structure import Kind, Site, StructureData
365386

366387
# ignore step, time, and velocities
367-
_, _, cell, symbols, positions, _ = self.get_step_data(index)
388+
_, _, cell, symbols, positions, _, pbc = self.get_step_data(index)
368389

369390
if custom_kinds is not None:
370391
kind_names = []
@@ -384,7 +405,7 @@ def get_step_structure(self, index, custom_kinds=None):
384405
'passed {}, but the symbols are {}'.format(sorted(kind_names), sorted(symbols))
385406
)
386407

387-
struc = StructureData(cell=cell)
408+
struc = StructureData(cell=cell, pbc=pbc)
388409
if custom_kinds is not None:
389410
for _k in custom_kinds:
390411
struc.append_kind(_k)

0 commit comments

Comments
 (0)