-
Notifications
You must be signed in to change notification settings - Fork 28
refactor: Refactor trajectory data loading and improve the docstrings #289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0164e2c
46c7f94
d23e630
180d4a7
c6d2b89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -167,66 +167,105 @@ def from_text_data(cls, | |
| @classmethod | ||
| def from_ase_traj(cls, | ||
| root: str, | ||
| get_Hamiltonian = False, | ||
| get_overlap = False, | ||
| get_DM = False, | ||
| get_eigenvalues = False, | ||
| info = None): | ||
|
|
||
| assert not get_Hamiltonian * get_DM, "Hamiltonian and Density Matrix can only loaded one at a time, for which will occupy the same attribute in the AtomicData." | ||
|
|
||
| get_Hamiltonian: bool = False, | ||
| get_overlap: bool = False, | ||
| get_DM: bool = False, | ||
| get_eigenvalues: bool = False, | ||
| info: Optional[Dict] = None): | ||
| ''' | ||
| Build the _TrajData instance by reading the data from the single data directory | ||
| that organized in the way compatible with the ASE. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| root: str | ||
| The folder where the data is stored, including the traj file. | ||
| get_Hamiltonian: bool | ||
| Whether to load the hamiltonian blocks. | ||
| get_overlap: bool | ||
| Whether to load the overlap blocks. | ||
| get_DM: bool | ||
| Whether to load the density matrix blocks. | ||
| get_eigenvalues: bool | ||
| Whether to load the eigenvalues. | ||
| info: dict | ||
| The description of the data, may be inconsistent with the real data, in this case, | ||
| the info will be updated. | ||
| ''' | ||
| assert not get_Hamiltonian * get_DM, \ | ||
| "Hamiltonian and Density Matrix can only loaded one at a time, " + \ | ||
| "for which will occupy the same attribute in the AtomicData." | ||
|
|
||
| # read the ase trajectory file... | ||
| traj_file = glob.glob(f"{root}/*.traj") | ||
| assert len(traj_file) == 1, print("only one ase trajectory file can be provided.") | ||
| traj = Trajectory(traj_file[0], 'r') | ||
| nframes = len(traj) | ||
| assert nframes > 0, print("trajectory file is empty.") | ||
| if nframes != info.get("nframes", None): | ||
| info['nframes'] = nframes | ||
| log.info(f"Number of frames ({nframes}) in trajectory file does not match the number of frames in info file.") | ||
|
|
||
| natoms = traj[0].positions.shape[0] | ||
| if natoms != info["natoms"]: | ||
| info["natoms"] = natoms | ||
|
|
||
| pbc = info.get("pbc",None) | ||
| if pbc is None: | ||
| pbc = traj[0].pbc.tolist() | ||
| info["pbc"] = pbc | ||
|
|
||
| if isinstance(pbc, bool): | ||
| pbc = [pbc] * 3 | ||
|
|
||
| if pbc != traj[0].pbc.tolist(): | ||
| log.warning("!! PBC setting in info file does not match the PBC setting in trajectory file, we use the one in info json. BE CAREFUL!") | ||
|
|
||
| positions = [] | ||
| cell = [] | ||
| atomic_numbers = [] | ||
|
|
||
| for atoms in traj: | ||
| positions.append(atoms.get_positions()) | ||
| assert len(traj_file) == 1, "only one ase trajectory file can be provided." | ||
|
|
||
| atomic_numbers, positions, cell = None, None, None | ||
| # use the context manager to avoid memory leak | ||
| with Trajectory(traj_file[0], 'r') as traj: | ||
| # there are some dimensions of importance: nframe, natom, ... | ||
| nframes = len(traj) | ||
| assert nframes > 0, "trajectory file is empty." | ||
| # if there is discrepancy between nframes in info and traj, then update the info. | ||
| if nframes != info.get("nframes", None): | ||
| info['nframes'] = nframes | ||
| log.info(f"Number of frames ({nframes}) in trajectory file does not match the info file.") | ||
|
|
||
| atomic_numbers.append(atoms.get_atomic_numbers()) | ||
| if (np.abs(atoms.get_cell()-np.zeros([3,3]))< 1e-6).all(): | ||
| cell = None | ||
| else: | ||
| cell.append(atoms.get_cell()) | ||
|
|
||
| positions = np.array(positions) | ||
| positions = positions.reshape(nframes,natoms, 3) | ||
|
|
||
| if cell is not None: | ||
| cell = np.array(cell) | ||
| cell = cell.reshape(nframes,3, 3) | ||
|
|
||
| atomic_numbers = np.array(atomic_numbers) | ||
| atomic_numbers = atomic_numbers.reshape(nframes, natoms) | ||
|
|
||
| data = {} | ||
| if cell is not None: | ||
| data["cell"] = cell | ||
| data["pos"] = positions | ||
| data["atomic_numbers"] = atomic_numbers | ||
| # assuming there will not be number of atoms change within the trajectory file... | ||
| # we check, because the trajectory file does support this. | ||
| natoms = np.unique([len(atoms) for atoms in traj]) | ||
| assert len(natoms) == 1, "Number of atoms in trajectory file is not consistent." | ||
| natoms = natoms[0] | ||
| # natoms = traj[0].positions.shape[0] | ||
| if natoms != info["natoms"]: | ||
| info["natoms"] = natoms | ||
| log.info(f"Number of atoms ({natoms}) in trajectory file does not match the info file.") | ||
|
|
||
| # handling the pbc flag | ||
| pbc = info.get("pbc", None) | ||
| if pbc is None: | ||
| # read from the trajectory...however, the same issue also exists here, the pbc may | ||
| # change along the trajectory, so we need to check it (only allow one pbc setting) | ||
| pbc = np.unique([atoms.pbc.tolist() for atoms in traj], axis=0) | ||
| assert len(pbc) == 1, "PBC setting in trajectory file is not consistent." | ||
| pbc = pbc[0] | ||
| assert isinstance(pbc, list) and len(pbc) == 3, \ | ||
| f"Unexpected `PBC` format: {pbc}" | ||
| info["pbc"] = pbc | ||
| # check on the value of pbc | ||
| if isinstance(pbc, bool): | ||
| pbc = [pbc] * 3 | ||
| if pbc != traj[0].pbc.tolist(): | ||
| log.warning("!! PBC setting in info file does not match the PBC setting in trajectory file, " | ||
| "we use the one in info json. BE CAREFUL!") | ||
|
|
||
| # overwrite the following three to the empty lists | ||
| atomic_numbers, positions, cell = [], [], [] | ||
| for atoms in traj: | ||
| atoms: Atoms # type annotation :) | ||
|
|
||
| atomic_numbers.append(atoms.get_atomic_numbers()) | ||
| positions.append(atoms.get_positions()) | ||
| # if there is no cell information, then set it to None. However, | ||
| # there is also the case that the invalidity of cell is reflected | ||
| # by the cell being all zeros. | ||
| cell_read = atoms.get_cell() | ||
| cell.append(None if np.allclose(cell_read, np.zeros((3, 3)), atol=1e-6) else cell_read) | ||
|
|
||
| # the trajectory reading must be successful and not empty | ||
| assert positions | ||
| assert atomic_numbers | ||
| assert cell | ||
|
|
||
| # this may raise errors about the inhomogenity of the data, or the reshape failed | ||
| data = {"pos": np.array(positions).reshape(nframes, natoms, 3), | ||
| "atomic_numbers": np.array(atomic_numbers).reshape(nframes, natoms)} | ||
| assert len(cell) == nframes | ||
| if all(c is not None for c in cell): | ||
| data["cell"] = np.array(cell).reshape(nframes, 3, 3) | ||
| else: | ||
| # otherwise, we expect that all cells are None, the hybrid case is not allowed | ||
| assert all(c is None for c in cell) | ||
|
|
||
| return cls(root=root, | ||
| data=data, | ||
|
|
@@ -318,18 +357,47 @@ def toAtomicDataList(self, idp: TypeMapper = None): | |
|
|
||
| class DefaultDataset(AtomicInMemoryDataset): | ||
|
|
||
| def __init__( | ||
| self, | ||
| root: str, | ||
| info_files: Dict[str, Dict], | ||
| url: Optional[str] = None, # seems useless but can't be remove | ||
| include_frames: Optional[List[int]] = None, # maybe support in future | ||
| type_mapper: TypeMapper = None, | ||
| get_Hamiltonian: bool = False, | ||
| get_overlap: bool = False, | ||
| get_DM: bool = False, | ||
| get_eigenvalues: bool = False, | ||
| ): | ||
| def __init__(self, | ||
| root: str, | ||
| info_files: Dict[str, Dict], | ||
| url: Optional[str] = None, # seems useless but can't be remove | ||
| include_frames: Optional[List[int]] = None, # maybe support in future | ||
| type_mapper: TypeMapper = None, | ||
| get_Hamiltonian: bool = False, | ||
| get_overlap: bool = False, | ||
| get_DM: bool = False, | ||
| get_eigenvalues: bool = False): | ||
| ''' | ||
| instantiate the default dataset. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| root : str | ||
| root directory of the dataset. | ||
| info_files : Dict[str, Dict] | ||
| the description of all the "valid" subfolders in the root directory, here the | ||
| "valid" means there are data files in the subfolder. | ||
| url : Optional[str], optional | ||
| not used in DeePTB. see its super class | ||
| include_frames : Optional[List[int]], optional | ||
| not used in DeePTB. see its super class | ||
| type_mapper: TypeMapper, optional | ||
| the mapping from orbpair index to reduced matrix element, see docstrings of class | ||
| OrbitalMapper for more information | ||
| get_Hamiltonian : bool, optional | ||
| whether to get the Hamiltonian, by default False | ||
| get_overlap : bool, optional | ||
| whether to get the overlap, by default False | ||
| get_DM : bool, optional | ||
| whether to get the density matrix, by default False | ||
| get_eigenvalues : bool, optional | ||
| whether to get the eigenvalues, by default False | ||
| ''' | ||
| def build_data(pos_typ: str, **kwargs): | ||
| builder = {'ase': _TrajData.from_ase_traj} | ||
| build_func = builder.get(pos_typ, _TrajData.from_text_data) | ||
| return build_func(**kwargs) | ||
|
|
||
| self.root = root | ||
| self.url = url | ||
| self.info_files = info_files | ||
|
|
@@ -345,24 +413,16 @@ def __init__( | |
| # get the info here | ||
| info = info_files[file] | ||
| # assert "AtomicData_options" in info | ||
| assert "r_max" in info | ||
| assert "pbc" in info | ||
| pbc = info["pbc"] | ||
| if info["pos_type"] == "ase": | ||
| subdata = _TrajData.from_ase_traj(os.path.join(self.root, file), | ||
| get_Hamiltonian, | ||
| get_overlap, | ||
| get_DM, | ||
| get_eigenvalues, | ||
| info=info) | ||
| else: | ||
| subdata = _TrajData.from_text_data(os.path.join(self.root, file), | ||
| get_Hamiltonian, | ||
| get_overlap, | ||
| get_DM, | ||
| get_eigenvalues, | ||
| info=info) | ||
| self.raw_data.append(subdata) | ||
| assert all(attr in info for attr in ["r_max", "pbc"]) | ||
| pbc = info["pbc"] # not used? | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Remove unused variable assignment. Static analysis correctly identifies that `pbc` is assigned but never used. Apply this diff:
```diff
  assert all(attr in info for attr in ["r_max", "pbc"])
- pbc = info["pbc"] # not used?
  self.raw_data.append(
```
🧰 Tools 🪛 Ruff (0.14.7) 417-417: Local variable `pbc` is assigned to but never used. Remove assignment to unused variable `pbc` (F841) 🤖 Prompt for AI Agents |
||
| self.raw_data.append( | ||
| build_data(pos_typ=info["pos_type"], | ||
| root=os.path.join(self.root, file), | ||
| get_Hamiltonian=get_Hamiltonian, | ||
| get_overlap=get_overlap, | ||
| get_DM=get_DM, | ||
| get_eigenvalues=get_eigenvalues, | ||
| info=info)) | ||
|
|
||
| # The AtomicData_options is never used here. | ||
| # Because we always return a list of AtomicData object in `get_data()`. | ||
|
|
@@ -381,6 +441,7 @@ def get_data(self): | |
| for subdata in tqdm(self.raw_data, desc="Loading data"): | ||
| # the type_mapper here is loaded in PyG `dataset` type as `transform` attritube | ||
| # so the OrbitalMapper can be accessed by self.transform here | ||
| subdata: _TrajData | ||
| subdata_list = subdata.toAtomicDataList(self.transform) | ||
| all_data += subdata_list | ||
| return all_data | ||
|
|
@@ -566,4 +627,4 @@ def _E3nodespecies_stat(self, typed_dataset): | |
| "scalar_std": typed_scalar_std, | ||
| } | ||
|
|
||
| return edge_stats | ||
| return edge_stats | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix PBC type mismatch after `np.unique`.
After `np.unique(..., axis=0)` at line 229, `pbc` is a numpy array. Line 231's `pbc[0]` extracts the first row, which is still a (1D) numpy array. However, line 232 asserts `isinstance(pbc, list)`, which will fail. Additionally, line 238's comparison `pbc != traj[0].pbc.tolist()` will raise an error: comparing a numpy array with a list using `!=` produces an array of bools, and using that array as an `if` condition raises an ambiguous-truth-value error.
🔎 Apply this diff to fix the type issues:
```diff
      pbc = np.unique([atoms.pbc.tolist() for atoms in traj], axis=0)
      assert len(pbc) == 1, "PBC setting in trajectory file is not consistent."
-     pbc = pbc[0]
+     pbc = pbc[0].tolist()
      assert isinstance(pbc, list) and len(pbc) == 3, \
          f"Unexpected `PBC` format: {pbc}"
      info["pbc"] = pbc
  # check on the value of pbc
  if isinstance(pbc, bool):
      pbc = [pbc] * 3
- if pbc != traj[0].pbc.tolist():
+ if pbc != traj[0].pbc.tolist(): # Now both are lists
      log.warning("!! PBC setting in info file does not match the PBC setting in trajectory file, "
                  "we use the one in info json. BE CAREFUL!")
```