Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions src/aiida/orm/nodes/data/array/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
###########################################################################
"""AiiDA class to deal with crystal structure trajectories."""

from __future__ import annotations

import collections.abc
from typing import List
from warnings import warn

from aiida.common.pydantic import MetadataField
from aiida.common.warnings import AiidaDeprecationWarning

from .array import ArrayData

Expand All @@ -33,7 +37,7 @@ def __init__(self, structurelist=None, **kwargs):
if structurelist is not None:
self.set_structurelist(structurelist)

def _internal_validate(self, stepids, cells, symbols, positions, times, velocities):
def _internal_validate(self, stepids, cells, symbols, positions, times, velocities, pbc):
"""Internal function to validate the type and shape of the arrays. See
the documentation of py:meth:`.set_trajectory` for a description of the
valid shape and type of the parameters.
Expand Down Expand Up @@ -82,8 +86,14 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti
'have shape (s,n,3), '
'with s=number of steps and n=number of symbols'
)

def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=None, velocities=None):
if not (isinstance(pbc, (list, tuple)) and len(pbc) == 3 and all(isinstance(val, bool) for val in pbc)):
raise ValueError('`pbc` must be a list/tuple of length three with boolean values.')
if cells is None and list(pbc) != [False, False, False]:
raise ValueError('Periodic boundary conditions are only possible when a cell is defined.')

def set_trajectory(
self, symbols, positions, stepids=None, cells=None, times=None, velocities=None, pbc: None | list | tuple = None
):
r"""Store the whole trajectory, after checking that types and dimensions
are correct.

Expand Down Expand Up @@ -131,14 +141,28 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non
:param velocities: if specified, must be a float array with the same
dimensions of the ``positions`` array.
The array contains the velocities in the atoms.
:param pbc: periodic boundary conditions of the structure. Should be a list of
length three with booleans indicating if the structure is periodic in that
direction. The same periodic boundary conditions are set for each step.

.. todo :: Choose suitable units for velocities
"""
import numpy

self._internal_validate(stepids, cells, symbols, positions, times, velocities)
# set symbols as attribute for easier querying
if cells is None:
pbc = pbc or [False, False, False]
elif pbc is None:
warn(
"When 'cells' is not None, the periodic boundary conditions should be explicitly specified via the "
"'pbc' keyword argument. Defaulting to '[True, True, True]', but this will raise in v3.0.0.",
AiidaDeprecationWarning,
)
pbc = [True, True, True]

self._internal_validate(stepids, cells, symbols, positions, times, velocities, pbc)
# set symbols/pbc as attributes for easier querying
self.base.attributes.set('symbols', list(symbols))
self.base.attributes.set('pbc', list(pbc))
self.set_array('positions', positions)
if stepids is not None: # use input stepids
self.set_array('steps', stepids)
Expand Down Expand Up @@ -189,7 +213,12 @@ def set_structurelist(self, structurelist):
raise ValueError('Symbol lists have to be the same for all of the supplied structures')
symbols = list(symbols_first)
positions = numpy.array([[list(s.position) for s in x.sites] for x in structurelist])
self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions)
pbc_set = {structure.pbc for structure in structurelist}
if len(pbc_set) == 1:
pbc = pbc_set.pop()
else:
raise ValueError(f'All structures should have the same `pbc`, found: {pbc_set}')
self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, pbc=pbc)

def _validate(self):
"""Verify that the required arrays are present and that their type and
Expand All @@ -206,6 +235,7 @@ def _validate(self):
self.get_positions(),
self.get_times(),
self.get_velocities(),
self.pbc,
)
# Should catch TypeErrors, ValueErrors, and KeyErrors for missing arrays
except Exception as exception:
Expand Down Expand Up @@ -264,6 +294,11 @@ def symbols(self) -> List[str]:
"""
return self.base.attributes.get('symbols')

@property
def pbc(self) -> list[bool]:
"""Return the list of periodic boundary conditions."""
return self.base.attributes.get('pbc')

def get_positions(self):
"""Return the array of positions, if it has already been set.

Expand Down Expand Up @@ -384,7 +419,7 @@ def get_step_structure(self, index, custom_kinds=None):
'passed {}, but the symbols are {}'.format(sorted(kind_names), sorted(symbols))
)

struc = StructureData(cell=cell)
struc = StructureData(cell=cell, pbc=self.pbc)
if custom_kinds is not None:
for _k in custom_kinds:
struc.append_kind(_k)
Expand Down
101 changes: 97 additions & 4 deletions tests/orm/nodes/data/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import numpy as np
import pytest

from aiida.common.warnings import AiidaDeprecationWarning
from aiida.orm import StructureData, TrajectoryData, load_node


@pytest.fixture
def trajectory_data():
"""Return a dictionary of data to create a ``TrajectoryData``."""
symbols = ['H'] * 5 + ['Cl'] * 5
pbc = [True, True, True]
stepids = np.arange(1000, 3000, 10)
times = stepids * 0.01
positions = np.arange(6000, dtype=float).reshape((200, 10, 3))
Expand All @@ -23,6 +25,7 @@ def trajectory_data():
'cells': cells,
'times': times,
'velocities': velocities,
'pbc': pbc,
}


Expand Down Expand Up @@ -107,10 +110,11 @@ def test_trajectory_get_step_data(self, trajectory_data):
stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(-2)
assert stepid == trajectory_data['stepids'][-2]
assert time == trajectory_data['times'][-2]
np.array_equal(cell, trajectory_data['cells'][-2, :, :])
np.array_equal(symbols, trajectory_data['symbols'])
np.array_equal(positions, trajectory_data['positions'][-2, :, :])
np.array_equal(velocities, trajectory_data['velocities'][-2, :, :])
assert np.array_equal(cell, trajectory_data['cells'][-2, :, :])
assert np.array_equal(symbols, trajectory_data['symbols'])
assert np.array_equal(trajectory.pbc, trajectory_data['pbc'])
assert np.array_equal(positions, trajectory_data['positions'][-2, :, :])
assert np.array_equal(velocities, trajectory_data['velocities'][-2, :, :])

def test_trajectory_get_step_data_empty(self, trajectory_data):
"""Test the `get_step_data` method when some arrays are not defined."""
Expand All @@ -123,6 +127,8 @@ def test_trajectory_get_step_data_empty(self, trajectory_data):
assert np.array_equal(symbols, trajectory_data['symbols'])
assert np.array_equal(positions, trajectory_data['positions'][3, :, :])
assert velocities is None
# In case the cell is not defined, there should be no periodic boundary conditions
assert np.array_equal(trajectory.pbc, [False, False, False])

def test_trajectory_get_step_structure(self, trajectory_data):
"""Test the `get_step_structure` method."""
Expand All @@ -141,3 +147,90 @@ def test_trajectory_get_step_structure(self, trajectory_data):

with pytest.raises(IndexError):
trajectory.get_step_structure(500)

def test_trajectory_pbc_structures(self, trajectory_data):
"""Test the `pbc` for the `TrajectoryData` using structure inputs."""
# Test non-pbc structure with no cell
structure = StructureData(cell=None, pbc=[False, False, False])
structure.append_atom(position=[0.0, 0.0, 0.0], symbols='H')

trajectory = TrajectoryData(structurelist=(structure,))

trajectory.get_step_structure(0).store() # Verify that the `StructureData` can be stored
assert trajectory.get_step_structure(0).pbc == structure.pbc

# Test failure for incorrect pbc
trajectory_data_incorrect = trajectory_data.copy()
trajectory_data_incorrect['pbc'] = [0, 0, 0]
with pytest.raises(ValueError, match='`pbc` must be a list/tuple of length three with boolean values'):
trajectory = TrajectoryData()
trajectory.set_trajectory(**trajectory_data_incorrect)

# Test failure when structures have different pbc
cell = [[3.0, 0.1, 0.3], [-0.05, 3.0, -0.2], [0.02, -0.08, 3.0]]
structure_periodic = StructureData(cell=cell)
structure_periodic.append_atom(position=[0.0, 0.0, 0.0], symbols='H')
structure_non_periodic = StructureData(cell=cell, pbc=[False, False, False])
structure_non_periodic.append_atom(position=[0.0, 0.0, 0.0], symbols='H')

with pytest.raises(ValueError, match='All structures should have the same `pbc`'):
TrajectoryData(structurelist=(structure_periodic, structure_non_periodic))

def test_trajectory_pbc_set_trajectory(self):
"""Test the `pbc` for the `TrajectoryData` using `set_trajectory`."""
data = {
'symbols': ['H'],
'positions': np.array(
[
[
[0.0, 0.0, 0.0],
]
]
),
}
trajectory = TrajectoryData()

data.update(
{
'cells': None,
'pbc': None,
}
)
trajectory.set_trajectory(**data)
assert trajectory.get_step_structure(0).pbc == (False, False, False)

data.update(
{
'cells': None,
'pbc': [False, False, False],
}
)
trajectory.set_trajectory(**data)
assert trajectory.get_step_structure(0).pbc == (False, False, False)

data.update(
{
'cells': None,
'pbc': [True, False, False],
}
)
with pytest.raises(ValueError, match='Periodic boundary conditions are only possible when a cell is defined'):
trajectory.set_trajectory(**data)

data.update(
{
'cells': np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]),
'pbc': None,
}
)
with pytest.warns(AiidaDeprecationWarning, match="When 'cells' is not None, the periodic"):
trajectory.set_trajectory(**data)

data.update(
{
'cells': np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]),
'pbc': (True, False, False),
}
)
trajectory.set_trajectory(**data)
assert trajectory.get_step_structure(0).pbc == (True, False, False)