Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions ipsuite/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class ProcessAtoms(IPSNode):
"""

data: list[ase.Atoms] = zntrack.deps()
data_file: str = zntrack.dvc.deps(None)
data_file: str = zntrack.deps_path(None)
data_id: typing.Optional[int] = None # zntrack.params(None)
atoms: list[ase.Atoms] = fields.Atoms()

def _post_init_(self):
Expand All @@ -50,17 +51,22 @@ def update_data(self):
def get_data(self) -> list[ase.Atoms]:
"""Get the atoms data to process."""
if self.data is not None:
return self.data
data = self.data
elif self.data_file is not None:
try:
with self.state.fs.open(pathlib.Path(self.data_file).as_posix()) as f:
return list(ase.io.iread(f))
data = list(ase.io.iread(f))
except FileNotFoundError:
# File can not be opened with DVCFileSystem, try normal open
return list(ase.io.iread(self.data_file))
data = list(ase.io.iread(self.data_file))
else:
raise ValueError("No data given.")

if self.data_id is not None:
return [data[self.data_id]]
else:
return data


class ProcessSingleAtom(IPSNode):
"""Protocol for objects that process a single atom.
Expand Down Expand Up @@ -88,8 +94,8 @@ class ProcessSingleAtom(IPSNode):
"""

data: typing.Union[ase.Atoms, typing.List[ase.Atoms]] = zntrack.deps()
data_file: str = zntrack.dvc.deps(None)
data_id: typing.Optional[int] = zntrack.zn.params(0)
data_file: str = zntrack.deps_path(None)
data_id: typing.Optional[int] = zntrack.params(0)

atoms: typing.List[ase.Atoms] = fields.Atoms()

Expand Down Expand Up @@ -134,10 +140,16 @@ class AnalyseProcessAtoms(IPSNode):
"""Analyse the output of a ProcessAtoms Node."""

data: ProcessAtoms = zntrack.deps()
reference: ProcessAtoms = zntrack.deps(None)

def get_data(self) -> typing.Tuple[list[ase.Atoms], list[ase.Atoms]]:
self.data.update_data() # otherwise, data might not be available
return self.data.data, self.data.atoms

if self.reference is None:
self.data.update_data() # otherwise, data might not be available
return self.data.data, self.data.atoms
else:
# TODO: support both, Nodes and Connections
return self.reference, self.data


class Mapping(ProcessAtoms):
Expand All @@ -160,8 +172,8 @@ class Mapping(ProcessAtoms):
The indices of the molecules will be frozen for all configurations.
"""

molecules: list[ase.Atoms] = zntrack.zn.outs()
frozen: bool = zntrack.zn.params(False)
molecules: list[ase.Atoms] = zntrack.outs()
frozen: bool = zntrack.params(False)

# TODO, should we allow to transfer the frozen mapping to another node?
# mapping = Mapping(frozen=True, reference=mapping)
Expand Down