-
Notifications
You must be signed in to change notification settings - Fork 28
Sparse Hr to HR forming #312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
89d7210
a73ab33
c8c3317
4fb5728
f8325d3
1f8a30d
90aa5c6
e6552c5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer, TensorBoardMonitor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dptb.plugins.train_logger import Logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dptb.utils.argcheck import normalize, collect_cutoffs, chk_avg_per_iter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dptb.utils.orbital_parser import parse_orbital_file | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dptb.plugins.saver import Saver | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Dict, List, Optional, Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dptb.utils.tools import j_loader, setup_seed, j_must_have | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -90,6 +91,35 @@ def train( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jdata = j_loader(INPUT) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jdata = normalize(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Validate and process orbital files in basis | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if jdata.get("common_options") and jdata["common_options"].get("basis"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| orbital_files_content = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for elem, value in jdata["common_options"]["basis"].items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(value, str) and os.path.isfile(value): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # strict check for e3tb method | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check if model_options exists and has prediction method | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_opts = jdata.get("model_options", {}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pred_opts = model_opts.get("prediction", {}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # normalize might handle defaults, but safely check here | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if pred_opts.get("method", "e3tb") != "e3tb": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Also checking if we are in a mixed model which might be different, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # but usually orbital files for basis imply the main basis handling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # transform logic | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parsed_basis = parse_orbital_file(value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with open(value, 'r') as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| orbital_files_content[elem] = f.read() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jdata["common_options"]["basis"][elem] = parsed_basis | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| log.info(f"Parsed orbital file for {elem}: {value} -> {parsed_basis}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if orbital_files_content: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jdata["common_options"]["orbital_files_content"] = orbital_files_content | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
106
to
122
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inconsistent indentation will cause Several lines have off-by-one (or more) indentation that will prevent this module from loading:
Proposed fix- if pred_opts.get("method", "e3tb") != "e3tb":
- raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}")
-
- # Also checking if we are in a mixed model which might be different,
- # but usually orbital files for basis imply the main basis handling.
- # transform logic
+ if pred_opts.get("method", "e3tb") != "e3tb":
+ raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}")
+
+ # Also checking if we are in a mixed model which might be different,
+ # but usually orbital files for basis imply the main basis handling.
+ # transform logic- jdata["common_options"]["orbital_files_content"] = orbital_files_content
+ jdata["common_options"]["orbital_files_content"] = orbital_files_content📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.15.0)[warning] 106-106: Avoid specifying long messages outside the exception class (TRY003) [warning] 118-118: Do not catch blind exception: (BLE001) [warning] 119-119: Within an (B904) [warning] 119-119: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # update basis if init_model or restart | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # update jdata | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # this is not necessary, because if we init model from checkpoint, the build_model will load the model_options from checkpoints if not provided | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,16 @@ | |
| from dptb.data.transforms import OrbitalMapper | ||
| from dptb.data import AtomicDataDict | ||
| import logging | ||
| try: | ||
| from dptb.utils.pardiso_wrapper import PyPardisoSolver | ||
| from dptb.utils.feast_wrapper import FeastSolver | ||
| from scipy.sparse.linalg import eigsh, LinearOperator | ||
| except ImportError: | ||
| PyPardisoSolver = None | ||
| FeastSolver = None | ||
| eigsh = None | ||
| LinearOperator = None | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
| class Eigenvalues(nn.Module): | ||
|
|
@@ -113,6 +123,178 @@ def forward(self, | |
| data[AtomicDataDict.KPOINT_KEY] = kpoints | ||
|
|
||
| return data | ||
|
|
||
| class PardisoEig: | ||
| def __init__(self, sigma: float = 0.0, neig: int = 10, mode: str = 'normal'): | ||
| """ | ||
| Solver using Pardiso for shift-invert eigenvalue problems. | ||
|
|
||
| Args: | ||
| sigma: Shift value (target energy). | ||
| neig: Number of eigenvalues to solve for. | ||
| mode: Eigsh mode ('normal', 'buckling', 'cayley'). | ||
| """ | ||
| if PyPardisoSolver is None or eigsh is None: | ||
| raise ImportError("PardisoEig requires MKL (pypardiso) and scipy.sparse.linalg") | ||
|
|
||
| self.sigma = sigma | ||
| self.neig = neig | ||
| self.mode = mode | ||
|
|
||
|
|
||
| def solve(self, h_container, s_container, kpoints: Union[list, torch.Tensor, np.ndarray], return_eigenvectors: bool = False): | ||
| """ | ||
| Solve eigenvalues for given k-points. | ||
|
|
||
| Args: | ||
| h_container: vbcsr.ImageContainer for Hamiltonian. | ||
| s_container: vbcsr.ImageContainer for Overlap (can be None). | ||
| kpoints: Array of k-points (Nk, 3). | ||
| return_eigenvectors: If True, return (eigenvalues, eigenvectors). Default False. | ||
|
|
||
| Returns: | ||
| list of eigenvalues arrays (and eigenvectors arrays if return_eigenvectors=True). | ||
| """ | ||
|
|
||
| # Ensure kpoints is numpy array | ||
| if isinstance(kpoints, torch.Tensor): | ||
| kpoints = kpoints.cpu().numpy() | ||
|
|
||
| eigvals_list = [] | ||
| eigvecs_list = [] | ||
|
|
||
| for k in kpoints: | ||
| hk = h_container.sample_k(k, symm=True) | ||
|
|
||
| if s_container is not None: | ||
| sk = s_container.sample_k(k, symm=True) | ||
| hk -= self.sigma * sk | ||
| A = hk.to_scipy(format="csr") | ||
| M = sk | ||
| else: | ||
| hk.shift(-self.sigma) | ||
| A = hk.to_scipy(format="csr") | ||
| M = None | ||
|
|
||
| A.sort_indices() | ||
| A.sum_duplicates() | ||
| N = A.shape[0] | ||
|
|
||
| # Try PARDISO first, fall back to scipy SuperLU if PARDISO fails | ||
| # (MKL PARDISO has a known bug with certain block-structured patterns) | ||
| solver = PyPardisoSolver(mtype=13) | ||
| solver.factorize(A) | ||
|
|
||
| def matvec(b): | ||
| return solver.solve(A, b) | ||
|
|
||
| Op = LinearOperator((N, N), matvec=matvec, dtype=A.dtype) | ||
|
|
||
| try: | ||
| # Use larger NCV to help convergence, especially for clustered eigenvalues | ||
| ncv = max(2*self.neig + 1, 20) | ||
| vals, vecs = eigsh(A=hk, M=M, k=self.neig, sigma=self.sigma, OPinv=Op, mode=self.mode, which="LM", ncv=ncv) | ||
| except Exception: | ||
| # Retry with larger NCV if ARPACK fails (e.g. error 3: No shifts could be applied) | ||
| # This often happens when eigenvalues are clustered near the shift | ||
| ncv = max(5*self.neig, 50) | ||
| vals, vecs = eigsh(A=hk, M=M, k=self.neig, sigma=self.sigma, OPinv=Op, mode=self.mode, which="LM", ncv=ncv) | ||
|
Comment on lines
+166
to
+201
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result: In What
|
||
|
|
||
| eigvals_list.append(vals) | ||
| if return_eigenvectors: | ||
| eigvecs_list.append(vecs) | ||
|
|
||
| if return_eigenvectors: | ||
| return eigvals_list, eigvecs_list | ||
| else: | ||
| return eigvals_list | ||
|
|
||
| class FEASTEig: | ||
| def __init__(self, emin: float = -1.0, emax: float = 1.0, m0: Optional[int] = None, | ||
| max_refinement: int = 3, uplo: str = 'U', extract_triangular: bool = True): | ||
| """ | ||
| Solver using FEAST algorithm for finding eigenvalues in a given interval. | ||
|
|
||
| Args: | ||
| emin, emax: Energy interval [emin, emax]. | ||
| m0: Initial subspace size estimate. | ||
| max_refinement: Number of refinements if subspace is too small. | ||
| uplo: 'U' (Upper) or 'L' (Lower) triangular part to use. | ||
| extract_triangular: Whether to extract triangular part automatically. | ||
| """ | ||
|
|
||
| if FeastSolver is None: | ||
| raise ImportError("FEAST solver not available") | ||
|
|
||
| self.emin = emin | ||
| self.emax = emax | ||
| self.m0 = m0 | ||
| self.max_refinement = max_refinement | ||
| self.uplo = uplo | ||
| self.extract_triangular = extract_triangular | ||
|
|
||
| # Initialize solver to check availability | ||
| try: | ||
| self.solver = FeastSolver() | ||
| except ImportError as e: | ||
| raise ImportError(f"FEAST solver not available: {e}") | ||
| except Exception as e: | ||
| raise RuntimeError(f"Failed to initialize FeastSolver: {e}") | ||
|
Comment on lines
+237
to
+242
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Chain exceptions with Per B904, re-raised exceptions inside Proposed fix try:
self.solver = FeastSolver()
except ImportError as e:
- raise ImportError(f"FEAST solver not available: {e}")
+ raise ImportError(f"FEAST solver not available: {e}") from e
except Exception as e:
- raise RuntimeError(f"Failed to initialize FeastSolver: {e}")
+ raise RuntimeError(f"Failed to initialize FeastSolver: {e}") from e🧰 Tools🪛 Ruff (0.15.0)[warning] 240-240: Within an (B904) [warning] 240-240: Avoid specifying long messages outside the exception class (TRY003) [warning] 241-241: Do not catch blind exception: (BLE001) [warning] 242-242: Within an (B904) [warning] 242-242: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||
|
|
||
| def solve(self, h_container, s_container, kpoints: Union[list, torch.Tensor, np.ndarray], return_eigenvectors: bool = False): | ||
| """ | ||
| Solve eigenvalues for given k-points using FEAST. | ||
|
|
||
| Args: | ||
| h_container: Container for Hamiltonian (must support sample_k().to_scipy()). | ||
| s_container: Container for Overlap (can be None). | ||
| kpoints: Array of k-points. | ||
| return_eigenvectors: If True, return (eigenvalues, eigenvectors). Default False. | ||
|
|
||
| Returns: | ||
| list of eigenvalues arrays (and eigenvectors arrays if return_eigenvectors=True). | ||
| """ | ||
| if isinstance(kpoints, torch.Tensor): | ||
| kpoints = kpoints.cpu().numpy() | ||
|
|
||
| eigvals_list = [] | ||
| eigvecs_list = [] | ||
|
|
||
| for k in kpoints: | ||
| # Get Hamiltonian and Overlap at k | ||
| # Assuming h_container.sample_k returns object with .to_scipy() | ||
| hk_obj = h_container.sample_k(k, symm=True) | ||
| if hasattr(hk_obj, 'to_scipy'): | ||
| hk = hk_obj.to_scipy(format="csr") | ||
| else: | ||
| # Fallback if it checks sparse type | ||
| hk = hk_obj | ||
|
|
||
| if s_container is not None: | ||
| sk_obj = s_container.sample_k(k, symm=True) | ||
| if hasattr(sk_obj, 'to_scipy'): | ||
| sk = sk_obj.to_scipy(format="csr") | ||
| else: | ||
| sk = sk_obj | ||
| else: | ||
| sk = None | ||
|
|
||
| # Solve | ||
| evals, vecs = self.solver.solve( | ||
| hk, M=sk, emin=self.emin, emax=self.emax, | ||
| m0=self.m0, max_refinement=self.max_refinement, | ||
| uplo=self.uplo, extract_triangular=self.extract_triangular | ||
| ) | ||
|
|
||
| eigvals_list.append(evals) | ||
| if return_eigenvectors: | ||
| eigvecs_list.append(vecs) | ||
|
|
||
| if return_eigenvectors: | ||
| return eigvals_list, eigvecs_list | ||
| else: | ||
| return eigvals_list | ||
|
|
||
|
|
||
| class Eigh(nn.Module): | ||
| def __init__( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chain the exception to preserve traceback context.
except Exceptionis overly broad, and the re-raisedValueErrordrops the original traceback. Useraise ... from eso the cause is visible during debugging.📝 Committable suggestion
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 118-118: Do not catch blind exception:
Exception(BLE001)
[warning] 119-119: Within an
exceptclause, raise exceptions withraise ... from errorraise ... from Noneto distinguish them from errors in exception handling(B904)
[warning] 119-119: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents