Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/models/test_orb.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import traceback

import numpy as np
import pytest
from ase.geometry.cell import cell_to_cellpar as ase_c2p

from tests.conftest import DEVICE
from tests.models.conftest import (
consistency_test_simstate_fixtures,
make_model_calculator_consistency_test,
make_validate_model_outputs_test,
)
from torch_sim import SimState
from torch_sim.models.orb import cell_to_cellpar


try:
Expand Down Expand Up @@ -74,3 +78,16 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator:
test_validate_direct_model_outputs = make_validate_model_outputs_test(
model_fixture_name="orbv3_direct_20_omat_model",
)


def test_cell_to_cellpar(ti_sim_state: SimState) -> None:
assert np.allclose(
ase_c2p(ti_sim_state.row_vector_cell.squeeze()),
cell_to_cellpar(ti_sim_state.row_vector_cell.squeeze(0)).cpu().numpy(),
)
assert np.allclose(
ase_c2p(ti_sim_state.row_vector_cell.squeeze(), radians=True),
cell_to_cellpar(ti_sim_state.row_vector_cell.squeeze(0), radians=True)
.cpu()
.numpy(),
)
38 changes: 34 additions & 4 deletions torch_sim/models/orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@


try:
from ase.geometry import cell_to_cellpar
from orb_models.forcefield import featurization_utilities as feat_util
from orb_models.forcefield.atomic_system import SystemConfig
from orb_models.forcefield.base import AtomGraphs, _map_concat
Expand Down Expand Up @@ -59,6 +58,39 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
from torch_sim.typing import StateDict


def cell_to_cellpar(
cell: torch.Tensor,
radians: bool = False, # noqa: FBT001, FBT002
) -> torch.Tensor:
"""Returns the cell parameters [a, b, c, alpha, beta, gamma].
torch version of ase's cell_to_cellpar.

Args:
cell: lattice vector in row vector convention, same as ase
radians: If True, return angles in radians. Otherwise, return degrees (default).

Returns:
Tensor with [a, b, c, alpha, beta, gamma].
"""
lengths = torch.linalg.norm(cell, dim=1).squeeze()
angles = []
for i in range(3):
j = i - 1
k = i - 2
ll = lengths[j] * lengths[k]
if ll.item() > 1e-16:
cell_j = cell[j].squeeze()
cell_k = cell[k].squeeze()
x = torch.dot(cell_j, cell_k) / ll
angle = 180.0 / torch.pi * torch.arccos(x)
else:
angle = 90.0
angles.append(angle)
if radians:
angles = [angle * torch.pi / 180 for angle in angles]
return torch.concat((torch.tensor(lengths), torch.tensor(angles)))


def state_to_atom_graphs( # noqa: PLR0915
state: ts.SimState,
*,
Expand Down Expand Up @@ -181,9 +213,7 @@ def state_to_atom_graphs( # noqa: PLR0915
num_edges.append(len(edges[0]))

# Calculate lattice parameters
lattice_per_system = torch.from_numpy(
cell_to_cellpar(cell_per_system.squeeze(0).cpu().numpy())
)
lattice_per_system = cell_to_cellpar(cell_per_system.squeeze(0))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I didn't see that the cell was extracted explicitly with the convention row_vector. When I say that the convention is not the same is that if A=Ase.atoms.cell[:] and B=SimState.cell, then A=B.T (transpose form). So in here, either you assume that the input is row and then you should have the same behaviour as ASE.cell_to_cellpar or you change that and then you have to change the input.
Given the current state of code, I suggest to use the same cell_to_cellpar convention as ASE, which is to assume that the input cell is a row_vector. Then modify the test accordingly

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I have changed the function to use the row convention and also updated the test.


# Create features dictionaries
node_feats = {
Expand Down