Skip to content

Commit ac70fb7

Browse files
committed
[WIP] 🐛 TrajectoryData: Add pbc
1 parent d50cf17 commit ac70fb7

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

src/aiida/orm/nodes/data/array/trajectory.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
###########################################################################
99
"""AiiDA class to deal with crystal structure trajectories."""
1010

11+
from __future__ import annotations
12+
1113
import collections.abc
1214
from typing import List
1315

@@ -33,7 +35,7 @@ def __init__(self, structurelist=None, **kwargs):
3335
if structurelist is not None:
3436
self.set_structurelist(structurelist)
3537

36-
def _internal_validate(self, stepids, cells, symbols, positions, times, velocities):
38+
def _internal_validate(self, stepids, cells, symbols, positions, times, velocities, pbc):
3739
"""Internal function to validate the type and shape of the arrays. See
3840
the documentation of py:meth:`.set_trajectory` for a description of the
3941
valid shape and type of the parameters.
@@ -82,8 +84,12 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti
8284
'have shape (s,n,3), '
8385
'with s=number of steps and n=number of symbols'
8486
)
87+
if not (isinstance(pbc, (list, tuple)) and len(pbc) == 3 and all(isinstance(val, bool) for val in pbc)):
88+
raise ValueError('`pbc` must be a list/tuple of length three with boolean values.')
8589

86-
def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=None, velocities=None):
90+
def set_trajectory(
91+
self, symbols, positions, stepids=None, cells=None, times=None, velocities=None, pbc: None | list | tuple = None
92+
):
8793
r"""Store the whole trajectory, after checking that types and dimensions
8894
are correct.
8995
@@ -131,14 +137,20 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non
131137
:param velocities: if specified, must be a float array with the same
132138
dimensions of the ``positions`` array.
133139
The array contains the velocities in the atoms.
140+
:param pbc: periodic boundary conditions of the structure. Should be a list of
141+
length three with booleans indicating if the structure is periodic in that
142+
direction. The same periodic boundary conditions are set for each step.
134143
135144
.. todo :: Choose suitable units for velocities
136145
"""
137146
import numpy
138147

139-
self._internal_validate(stepids, cells, symbols, positions, times, velocities)
140-
# set symbols as attribute for easier querying
148+
pbc = pbc or [True, True, True]
149+
150+
self._internal_validate(stepids, cells, symbols, positions, times, velocities, pbc)
151+
# set symbols/pbc as attributes for easier querying
141152
self.base.attributes.set('symbols', list(symbols))
153+
self.base.attributes.set('pbc', list(pbc))
142154
self.set_array('positions', positions)
143155
if stepids is not None: # use input stepids
144156
self.set_array('steps', stepids)
@@ -189,7 +201,12 @@ def set_structurelist(self, structurelist):
189201
raise ValueError('Symbol lists have to be the same for all of the supplied structures')
190202
symbols = list(symbols_first)
191203
positions = numpy.array([[list(s.position) for s in x.sites] for x in structurelist])
192-
self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions)
204+
pbc_set = {structure.pbc for structure in structurelist}
205+
if len(pbc_set) == 1:
206+
pbc = pbc_set.pop()
207+
else:
208+
raise ValueError(f'All structures should have the same `pbc`, found: {pbc_set}')
209+
self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, pbc=pbc)
193210

194211
def _validate(self):
195212
"""Verify that the required arrays are present and that their type and
@@ -206,6 +223,7 @@ def _validate(self):
206223
self.get_positions(),
207224
self.get_times(),
208225
self.get_velocities(),
226+
self.pbc,
209227
)
210228
# Should catch TypeErrors, ValueErrors, and KeyErrors for missing arrays
211229
except Exception as exception:
@@ -264,6 +282,11 @@ def symbols(self) -> List[str]:
264282
"""
265283
return self.base.attributes.get('symbols')
266284

285+
@property
286+
def pbc(self) -> list[bool]:
287+
"""Return the list of periodic boundary conditions."""
288+
return self.base.attributes.get('pbc')
289+
267290
def get_positions(self):
268291
"""Return the array of positions, if it has already been set.
269292
@@ -338,7 +361,7 @@ def get_step_data(self, index):
338361
cell = cells[index, :, :]
339362
else:
340363
cell = None
341-
return (self.get_stepids()[index], time, cell, self.symbols, self.get_positions()[index, :, :], vel)
364+
return (self.get_stepids()[index], time, cell, self.symbols, self.get_positions()[index, :, :], vel, self.pbc)
342365

343366
def get_step_structure(self, index, custom_kinds=None):
344367
"""Return an AiiDA :py:class:`aiida.orm.nodes.data.structure.StructureData` node
@@ -364,7 +387,7 @@ def get_step_structure(self, index, custom_kinds=None):
364387
from aiida.orm.nodes.data.structure import Kind, Site, StructureData
365388

366389
# ignore step, time, and velocities
367-
_, _, cell, symbols, positions, _ = self.get_step_data(index)
390+
_, _, cell, symbols, positions, _, pbc = self.get_step_data(index)
368391

369392
if custom_kinds is not None:
370393
kind_names = []
@@ -384,7 +407,7 @@ def get_step_structure(self, index, custom_kinds=None):
384407
'passed {}, but the symbols are {}'.format(sorted(kind_names), sorted(symbols))
385408
)
386409

387-
struc = StructureData(cell=cell)
410+
struc = StructureData(cell=cell, pbc=pbc)
388411
if custom_kinds is not None:
389412
for _k in custom_kinds:
390413
struc.append_kind(_k)

tests/orm/nodes/data/test_trajectory.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
def trajectory_data():
1111
"""Return a dictionary of data to create a ``TrajectoryData``."""
1212
symbols = ['H'] * 5 + ['Cl'] * 5
13+
pbc = [True, True, True]
1314
stepids = np.arange(1000, 3000, 10)
1415
times = stepids * 0.01
1516
positions = np.arange(6000, dtype=float).reshape((200, 10, 3))
@@ -23,6 +24,7 @@ def trajectory_data():
2324
'cells': cells,
2425
'times': times,
2526
'velocities': velocities,
27+
'pbc': pbc,
2628
}
2729

2830

@@ -104,23 +106,25 @@ def test_trajectory_get_step_data(self, trajectory_data):
104106
"""Test the ``get_step_data`` method."""
105107
trajectory = TrajectoryData()
106108
trajectory.set_trajectory(**trajectory_data)
107-
stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(-2)
109+
stepid, time, cell, symbols, positions, velocities, pbc = trajectory.get_step_data(-2)
108110
assert stepid == trajectory_data['stepids'][-2]
109111
assert time == trajectory_data['times'][-2]
110-
np.array_equal(cell, trajectory_data['cells'][-2, :, :])
111-
np.array_equal(symbols, trajectory_data['symbols'])
112-
np.array_equal(positions, trajectory_data['positions'][-2, :, :])
113-
np.array_equal(velocities, trajectory_data['velocities'][-2, :, :])
112+
assert np.array_equal(cell, trajectory_data['cells'][-2, :, :])
113+
assert np.array_equal(symbols, trajectory_data['symbols'])
114+
assert np.array_equal(pbc, trajectory_data['pbc'])
115+
assert np.array_equal(positions, trajectory_data['positions'][-2, :, :])
116+
assert np.array_equal(velocities, trajectory_data['velocities'][-2, :, :])
114117

115118
def test_trajectory_get_step_data_empty(self, trajectory_data):
116119
"""Test the `get_step_data` method when some arrays are not defined."""
117120
trajectory = TrajectoryData()
118121
trajectory.set_trajectory(symbols=trajectory_data['symbols'], positions=trajectory_data['positions'])
119-
stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(3)
122+
stepid, time, cell, symbols, positions, velocities, pbc = trajectory.get_step_data(3)
120123
assert stepid == 3
121124
assert time is None
122125
assert cell is None
123126
assert np.array_equal(symbols, trajectory_data['symbols'])
127+
assert np.array_equal(pbc, trajectory_data['pbc'])
124128
assert np.array_equal(positions, trajectory_data['positions'][3, :, :])
125129
assert velocities is None
126130

@@ -141,3 +145,14 @@ def test_trajectory_get_step_structure(self, trajectory_data):
141145

142146
with pytest.raises(IndexError):
143147
trajectory.get_step_structure(500)
148+
149+
def test_trajectory_pbc(self):
150+
"""Test the `pbc` for the `TrajectoryData` ."""
151+
trajectory = TrajectoryData()
152+
structure = StructureData(cell=None, pbc=[False, False, False])
153+
structure.append_atom(position=[0.0, 0.0, 0.0], symbols='H')
154+
155+
trajectory = TrajectoryData(structurelist=(structure,))
156+
157+
trajectory.get_step_structure(0).store() # Verify that the `StructureData` can be stored
158+
assert trajectory.get_step_structure(0).pbc == structure.pbc

0 commit comments

Comments
 (0)