-
Notifications
You must be signed in to change notification settings - Fork 61
Enable MD workflows for any ASE calculator #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
500681a
7457f3f
67e852a
348055e
d65510f
a69e992
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this idea, and actually I have implemented this kind of thing for my |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| from __future__ import annotations | ||
| from copy import deepcopy | ||
| from typing import Optional, Callable | ||
|
|
||
| from ase.calculators.calculator import Calculator, all_changes | ||
| from ase.symbols import Symbols | ||
| import numpy as np | ||
| import torch | ||
|
|
||
| from nff.io.ase_calcs import AtomsBatch | ||
|
|
||
|
|
||
|
|
||
| class Potential(torch.nn.Module): | ||
| pass | ||
|
|
||
|
|
||
| class AsePotential(Potential): | ||
|
|
||
| def __init__(self, calculator: Calculator, embedding_fun: Optional[Callable[[AtomsBatch], torch.Tensor]] = None) \ | ||
| -> None: | ||
| super().__init__() | ||
| self.calculator = calculator | ||
| self.embedding_fun = embedding_fun | ||
|
|
||
| def __call__(self, batch: dict, **kwargs): | ||
HojeChun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| properties = ["energy"] | ||
| if kwargs.get("requires_stress", False): | ||
| properties.append("stress") | ||
| if kwargs.get("requires_forces", False): | ||
| properties.append("forces") | ||
| if kwargs.get("requires_dipole", False): | ||
| properties.append("dipole") | ||
| if kwargs.get("requires_charges", False): | ||
| properties.append("charges") | ||
| if kwargs.get("requires_embedding", False): | ||
| if self.embedding_fun is None: | ||
| raise RuntimeError("Required embedding but no embedding function provided.") | ||
| embedding = self.embedding_fun(batch) | ||
| else: | ||
| embedding = None | ||
|
|
||
| nxyz = batch.get("nxyz") | ||
| if nxyz is None: | ||
| raise RuntimeError("Batch is missing 'nxyz' key.") | ||
| pbc = batch.get("pbc") | ||
| if pbc is not None: | ||
| pbc = np.array(pbc, dtype=bool) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes exactly. It's basically the |
||
| cell = np.array(batch.get("cell")).reshape(3, 3) | ||
| else: | ||
| cell = None | ||
| atoms_batch = AtomsBatch( | ||
| symbols=Symbols(nxyz[:,0].detach().cpu().numpy()), | ||
| positions=nxyz[:, 1:4].detach().cpu().numpy(), | ||
| pbc=pbc, | ||
| cell=cell, | ||
| device=batch.get("device", "cpu") | ||
| ) | ||
|
|
||
| self.calculator.calculate(atoms_batch, properties=properties, system_changes=all_changes) | ||
| results = deepcopy(self.calculator.results) | ||
| for key, value in results.items(): | ||
| if isinstance(value, str): | ||
| continue | ||
| if not hasattr(value, "__iter__"): | ||
| results[key] = torch.tensor([value], device=atoms_batch.device) | ||
| else: | ||
| results[key] = torch.tensor(value, device=atoms_batch.device) | ||
| if embedding is not None: | ||
| results["embedding"] = embedding | ||
| return results | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure why you made changes for this. if you need load_foundations_path function as class methods, I guess you keep the original function and make
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't quite understand what you mean. As far as I understand your comment, this is what I did
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i mean Sorry for a wild example above |
Uh oh!
There was an error while loading. Please reload this page.