88###########################################################################
99"""AiiDA class to deal with crystal structure trajectories."""
1010
11+ from __future__ import annotations
12+
1113import collections .abc
1214from typing import List
1315
@@ -33,7 +35,7 @@ def __init__(self, structurelist=None, **kwargs):
3335 if structurelist is not None :
3436 self .set_structurelist (structurelist )
3537
36- def _internal_validate (self , stepids , cells , symbols , positions , times , velocities ):
38+ def _internal_validate (self , stepids , cells , symbols , positions , times , velocities , pbc ):
3739 """Internal function to validate the type and shape of the arrays. See
3840 the documentation of py:meth:`.set_trajectory` for a description of the
3941 valid shape and type of the parameters.
@@ -82,8 +84,12 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti
8284 'have shape (s,n,3), '
8385 'with s=number of steps and n=number of symbols'
8486 )
87+ if not (isinstance (pbc , (list , tuple )) and len (pbc ) == 3 and all (isinstance (val , bool ) for val in pbc )):
88+ raise ValueError ('`pbc` must be a list/tuple of length three with boolean values.' )
8589
86- def set_trajectory (self , symbols , positions , stepids = None , cells = None , times = None , velocities = None ):
90+ def set_trajectory (
91+ self , symbols , positions , stepids = None , cells = None , times = None , velocities = None , pbc : None | list | tuple = None
92+ ):
8793 r"""Store the whole trajectory, after checking that types and dimensions
8894 are correct.
8995
@@ -131,14 +137,19 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non
131137 :param velocities: if specified, must be a float array with the same
132138 dimensions of the ``positions`` array.
133139 The array contains the velocities in the atoms.
140+ :param pbc: periodic boundary conditions of the structure. Should be a list of
141+ length three with booleans indicating if the structure is periodic in that
142+ direction. The same periodic boundary conditions are set for each step.
134143
135144 .. todo :: Choose suitable units for velocities
136145 """
137146 import numpy
138147
139- self ._internal_validate (stepids , cells , symbols , positions , times , velocities )
148+ self ._internal_validate (stepids , cells , symbols , positions , times , velocities , pbc )
140149 # set symbols as attribute for easier querying
141150 self .base .attributes .set ('symbols' , list (symbols ))
151+ pbc = pbc or [True , True , True ]
152+ self .base .attributes .set ('pbc' , list (pbc ))
142153 self .set_array ('positions' , positions )
143154 if stepids is not None : # use input stepids
144155 self .set_array ('steps' , stepids )
@@ -189,7 +200,12 @@ def set_structurelist(self, structurelist):
189200 raise ValueError ('Symbol lists have to be the same for all of the supplied structures' )
190201 symbols = list (symbols_first )
191202 positions = numpy .array ([[list (s .position ) for s in x .sites ] for x in structurelist ])
192- self .set_trajectory (stepids = stepids , cells = cells , symbols = symbols , positions = positions )
203+ pbc_set = {structure .pbc for structure in structurelist }
204+ if len (pbc_set ) == 1 :
205+ pbc = pbc_set .pop ()
206+ else :
207+ raise ValueError (f'All structures should have the same `pbc`, found: { pbc_set } ' )
208+ self .set_trajectory (stepids = stepids , cells = cells , symbols = symbols , positions = positions , pbc = pbc )
193209
194210 def _validate (self ):
195211 """Verify that the required arrays are present and that their type and
@@ -264,6 +280,11 @@ def symbols(self) -> List[str]:
264280 """
265281 return self .base .attributes .get ('symbols' )
266282
283+ @property
284+ def pbc (self ) -> list [bool ]:
285+ """Return the list of periodic boundary conditions."""
286+ return self .base .attributes .get ('pbc' )
287+
267288 def get_positions (self ):
268289 """Return the array of positions, if it has already been set.
269290
@@ -338,7 +359,7 @@ def get_step_data(self, index):
338359 cell = cells [index , :, :]
339360 else :
340361 cell = None
341- return (self .get_stepids ()[index ], time , cell , self .symbols , self .get_positions ()[index , :, :], vel )
362+ return (self .get_stepids ()[index ], time , cell , self .symbols , self .get_positions ()[index , :, :], vel , self . pbc )
342363
343364 def get_step_structure (self , index , custom_kinds = None ):
344365 """Return an AiiDA :py:class:`aiida.orm.nodes.data.structure.StructureData` node
@@ -364,7 +385,7 @@ def get_step_structure(self, index, custom_kinds=None):
364385 from aiida .orm .nodes .data .structure import Kind , Site , StructureData
365386
366387 # ignore step, time, and velocities
367- _ , _ , cell , symbols , positions , _ = self .get_step_data (index )
388+ _ , _ , cell , symbols , positions , _ , pbc = self .get_step_data (index )
368389
369390 if custom_kinds is not None :
370391 kind_names = []
@@ -384,7 +405,7 @@ def get_step_structure(self, index, custom_kinds=None):
384405 'passed {}, but the symbols are {}' .format (sorted (kind_names ), sorted (symbols ))
385406 )
386407
387- struc = StructureData (cell = cell )
408+ struc = StructureData (cell = cell , pbc = pbc )
388409 if custom_kinds is not None :
389410 for _k in custom_kinds :
390411 struc .append_kind (_k )
0 commit comments