88###########################################################################
99"""AiiDA class to deal with crystal structure trajectories."""
1010
11+ from __future__ import annotations
12+
1113import collections .abc
1214from typing import List
1315
@@ -33,7 +35,7 @@ def __init__(self, structurelist=None, **kwargs):
3335 if structurelist is not None :
3436 self .set_structurelist (structurelist )
3537
36- def _internal_validate (self , stepids , cells , symbols , positions , times , velocities ):
38+ def _internal_validate (self , stepids , cells , symbols , positions , times , velocities , pbc ):
3739 """Internal function to validate the type and shape of the arrays. See
3840 the documentation of py:meth:`.set_trajectory` for a description of the
3941 valid shape and type of the parameters.
@@ -82,8 +84,12 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti
8284 'have shape (s,n,3), '
8385 'with s=number of steps and n=number of symbols'
8486 )
87+ if not (isinstance (pbc , (list , tuple )) and len (pbc ) == 3 and all (isinstance (val , bool ) for val in pbc )):
88+ raise ValueError ('`pbc` must be a list/tuple of length three with boolean values.' )
8589
86- def set_trajectory (self , symbols , positions , stepids = None , cells = None , times = None , velocities = None ):
90+ def set_trajectory (
91+ self , symbols , positions , stepids = None , cells = None , times = None , velocities = None , pbc : None | list | tuple = None
92+ ):
8793 r"""Store the whole trajectory, after checking that types and dimensions
8894 are correct.
8995
@@ -131,14 +137,20 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non
131137 :param velocities: if specified, must be a float array with the same
132138 dimensions of the ``positions`` array.
133139 The array contains the velocities in the atoms.
140+ :param pbc: periodic boundary conditions of the structure. Should be a list of
141+ length three with booleans indicating if the structure is periodic in that
142+ direction. The same periodic boundary conditions are set for each step.
134143
135144 .. todo :: Choose suitable units for velocities
136145 """
137146 import numpy
138147
139- self ._internal_validate (stepids , cells , symbols , positions , times , velocities )
140- # set symbols as attribute for easier querying
148+ pbc = pbc or [True , True , True ]
149+
150+ self ._internal_validate (stepids , cells , symbols , positions , times , velocities , pbc )
151+ # set symbols/pbc as attributes for easier querying
141152 self .base .attributes .set ('symbols' , list (symbols ))
153+ self .base .attributes .set ('pbc' , list (pbc ))
142154 self .set_array ('positions' , positions )
143155 if stepids is not None : # use input stepids
144156 self .set_array ('steps' , stepids )
@@ -189,7 +201,12 @@ def set_structurelist(self, structurelist):
189201 raise ValueError ('Symbol lists have to be the same for all of the supplied structures' )
190202 symbols = list (symbols_first )
191203 positions = numpy .array ([[list (s .position ) for s in x .sites ] for x in structurelist ])
192- self .set_trajectory (stepids = stepids , cells = cells , symbols = symbols , positions = positions )
204+ pbc_set = {structure .pbc for structure in structurelist }
205+ if len (pbc_set ) == 1 :
206+ pbc = pbc_set .pop ()
207+ else :
208+ raise ValueError (f'All structures should have the same `pbc`, found: { pbc_set } ' )
209+ self .set_trajectory (stepids = stepids , cells = cells , symbols = symbols , positions = positions , pbc = pbc )
193210
194211 def _validate (self ):
195212 """Verify that the required arrays are present and that their type and
@@ -206,6 +223,7 @@ def _validate(self):
206223 self .get_positions (),
207224 self .get_times (),
208225 self .get_velocities (),
226+ self .pbc ,
209227 )
210228 # Should catch TypeErrors, ValueErrors, and KeyErrors for missing arrays
211229 except Exception as exception :
@@ -264,6 +282,11 @@ def symbols(self) -> List[str]:
264282 """
265283 return self .base .attributes .get ('symbols' )
266284
285+ @property
286+ def pbc (self ) -> list [bool ]:
287+ """Return the list of periodic boundary conditions."""
288+ return self .base .attributes .get ('pbc' )
289+
267290 def get_positions (self ):
268291 """Return the array of positions, if it has already been set.
269292
@@ -338,7 +361,7 @@ def get_step_data(self, index):
338361 cell = cells [index , :, :]
339362 else :
340363 cell = None
341- return (self .get_stepids ()[index ], time , cell , self .symbols , self .get_positions ()[index , :, :], vel )
364+ return (self .get_stepids ()[index ], time , cell , self .symbols , self .get_positions ()[index , :, :], vel , self . pbc )
342365
343366 def get_step_structure (self , index , custom_kinds = None ):
344367 """Return an AiiDA :py:class:`aiida.orm.nodes.data.structure.StructureData` node
@@ -364,7 +387,7 @@ def get_step_structure(self, index, custom_kinds=None):
364387 from aiida .orm .nodes .data .structure import Kind , Site , StructureData
365388
366389 # ignore step, time, and velocities
367- _ , _ , cell , symbols , positions , _ = self .get_step_data (index )
390+ _ , _ , cell , symbols , positions , _ , pbc = self .get_step_data (index )
368391
369392 if custom_kinds is not None :
370393 kind_names = []
@@ -384,7 +407,7 @@ def get_step_structure(self, index, custom_kinds=None):
384407 'passed {}, but the symbols are {}' .format (sorted (kind_names ), sorted (symbols ))
385408 )
386409
387- struc = StructureData (cell = cell )
410+ struc = StructureData (cell = cell , pbc = pbc )
388411 if custom_kinds is not None :
389412 for _k in custom_kinds :
390413 struc .append_kind (_k )
0 commit comments