Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix N shape bugs #157

Merged
merged 14 commits into from
Feb 22, 2022
21 changes: 16 additions & 5 deletions nequip/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,13 +537,18 @@ def _per_atom_statistics(
"""
# using unique_consecutive handles the non-contiguous selected batch index
_, N = torch.unique_consecutive(batch, return_counts=True)
N = N.unsqueeze(-1)
assert N.ndim == 2
assert N.shape == (len(arr), 1)
assert arr.ndim == 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we make N the same shape as arr from the 2nd axis?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you mean?

data_dim = arr.shape[1]
arr = arr / N
assert arr.shape == (len(N), data_dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
data_dim = arr.shape[1]
arr = arr / N
assert arr.shape == (len(N), data_dim)
data_dim = arr.shape[1:]
arr = arr / N
assert arr.shape == (len(N),) + data_dim

if ana_mode == "mean_std":
arr = arr / N
mean = torch.mean(arr)
std = torch.std(arr, unbiased=unbiased)
mean = torch.mean(arr, dim=0)
std = torch.std(arr, unbiased=unbiased, dim=0)
return mean, std
elif ana_mode == "rms":
arr = arr / N
return (torch.sqrt(torch.mean(arr.square())),)
else:
raise NotImplementedError(
Expand All @@ -567,8 +572,9 @@ def _per_species_statistics(
For a per-node quantity, computes the expected statistic but for each type instead of over all nodes.
"""
N = bincount(atom_types.squeeze(-1), batch)
assert N.ndim == 2 # [batch, n_type]
N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes

assert arr.ndim >= 2
if arr_is_per == "graph":

if ana_mode != "mean_std":
Expand All @@ -585,10 +591,15 @@ def _per_species_statistics(

if ana_mode == "mean_std":
mean = scatter_mean(arr, atom_types, dim=0)
assert mean.shape[1:] == arr.shape[1:] # [N, dims] -> [type, dims]
assert len(mean) == N.shape[1]
std = scatter_std(arr, atom_types, dim=0, unbiased=unbiased)
assert std.shape == mean.shape
return mean, std
elif ana_mode == "rms":
square = scatter_mean(arr.square(), atom_types, dim=0)
assert square.shape[1:] == arr.shape[1:] # [N, dims] -> [type, dims]
assert len(square) == N.shape[1]
dims = len(square.shape) - 1
for i in range(dims):
square = square.mean(axis=-1)
Expand Down
4 changes: 3 additions & 1 deletion nequip/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,6 @@ def format(
+ "]"
).format(*zip(type_names, data))
else:
raise ValueError
raise ValueError(
f"Don't know how to format data=`{data}` for types {type_names} with element_formatter=`{element_formatter}`"
)
4 changes: 2 additions & 2 deletions nequip/model/_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,14 @@ def PerSpeciesRescale(

if isinstance(scales, str):
s = scales
scales = computed_stats[str_names.index(scales)]
scales = computed_stats[str_names.index(scales)].squeeze(-1) # energy is 1D
logging.info(f"Replace string {s} to {scales}")
elif isinstance(scales, (list, float)):
scales = torch.as_tensor(scales)

if isinstance(shifts, str):
s = shifts
shifts = computed_stats[str_names.index(shifts)]
shifts = computed_stats[str_names.index(shifts)].squeeze(-1) # energy is 1D
logging.info(f"Replace string {s} to {shifts}")
elif isinstance(shifts, (list, float)):
shifts = torch.as_tensor(shifts)
Expand Down
14 changes: 12 additions & 2 deletions nequip/train/_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __call__(
)
if self.func_name == "MSELoss":
loss = loss / N
assert loss.shape == pred[key].shape # [atom, dim]
if mean:
return loss.sum() / not_nan.sum()
else:
Expand All @@ -89,6 +90,7 @@ def __call__(
loss = loss / N
if self.func_name == "MSELoss":
loss = loss / N
assert loss.shape == pred[key].shape # [atom, dim]
if mean:
return loss.mean()
else:
Expand Down Expand Up @@ -128,25 +130,33 @@ def __call__(
if has_nan:
if len(reduce_dims) > 0:
per_atom_loss = per_atom_loss.sum(dim=reduce_dims)
assert per_atom_loss.ndim == 1

per_species_loss = scatter(per_atom_loss, spe_idx, dim=0)

assert per_species_loss.ndim == 1 # [type]

N = scatter(not_nan, spe_idx, dim=0)
N = N.sum(reduce_dims)
N = 1.0 / N
N = N.reciprocal()
N_species = ((N == N).int()).sum()
assert N.ndim == 1 # [type]

per_species_loss = (per_species_loss * N).sum() / N_species

return (per_species_loss * N).sum() / N_species
return per_species_loss

else:

if len(reduce_dims) > 0:
per_atom_loss = per_atom_loss.mean(dim=reduce_dims)
assert per_atom_loss.ndim == 1

# offset species index by 1 to use 0 for nan
_, inverse_species_index = torch.unique(spe_idx, return_inverse=True)

per_species_loss = scatter_mean(per_atom_loss, inverse_species_index, dim=0)
assert per_species_loss.ndim == 1 # [type]

return per_species_loss.mean()

Expand Down
63 changes: 63 additions & 0 deletions tests/unit/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
NpzDataset,
ASEDataset,
dataset_from_config,
register_fields,
deregister_fields,
)
from nequip.data.transforms import TypeMapper
from nequip.utils import Config
Expand Down Expand Up @@ -188,6 +190,67 @@ def test_edgewise_stats(self, npz_dataset):
# TODO: check correct


class TestPerAtomStatistics:
@pytest.mark.parametrize("mode", ["mean_std", "rms"])
def test_per_node_field(self, npz_dataset, mode):
# set up the transformer
npz_dataset = set_up_transformer(npz_dataset, True, False, False)

with pytest.raises(ValueError) as excinfo:
npz_dataset.statistics(
[AtomicDataDict.BATCH_KEY],
modes=[f"per_atom_{mode}"],
)
assert (
f"It doesn't make sense to ask for `{mode}` since `{AtomicDataDict.BATCH_KEY}` is not per-graph"
in str(excinfo.value)
)

@pytest.mark.parametrize("fixed_field", [True, False])
@pytest.mark.parametrize("full_rank", [True, False])
@pytest.mark.parametrize("subset", [True, False])
@pytest.mark.parametrize(
"key,dim", [(AtomicDataDict.TOTAL_ENERGY_KEY, (1,)), ("somekey", (3,))]
)
def test_per_graph_field(
self, npz_dataset, fixed_field, full_rank, subset, key, dim
):
if key == "somekey":
register_fields(graph_fields=[key])

npz_dataset = set_up_transformer(npz_dataset, full_rank, fixed_field, subset)
if npz_dataset is None:
return

E = torch.rand((npz_dataset.len(),) + dim)
ref_mean = torch.mean(E / NATOMS, dim=0)
ref_std = torch.std(E / NATOMS, dim=0)

if subset:
E_orig_order = torch.zeros(
(npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY].shape[0],) + dim
)
E_orig_order[npz_dataset._indices] = E
npz_dataset.data[key] = E_orig_order
else:
npz_dataset.data[key] = E

((mean, std),) = npz_dataset.statistics(
[key],
modes=["per_atom_mean_std"],
)

print("mean", mean, ref_mean)
print("diff in mean", mean - ref_mean)
print("std", std, ref_std)

assert torch.allclose(mean, ref_mean, rtol=1e-1)
assert torch.allclose(std, ref_std, rtol=1e-2)

if key == "somekey":
deregister_fields(key)


class TestPerSpeciesStatistics:
@pytest.mark.parametrize("fixed_field", [True, False])
@pytest.mark.parametrize("mode", ["mean_std", "rms"])
Expand Down