Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions dptb/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer, TensorBoardMonitor
from dptb.plugins.train_logger import Logger
from dptb.utils.argcheck import normalize, collect_cutoffs, chk_avg_per_iter
from dptb.utils.orbital_parser import parse_orbital_file
from dptb.plugins.saver import Saver
from typing import Dict, List, Optional, Any
from dptb.utils.tools import j_loader, setup_seed, j_must_have
Expand Down Expand Up @@ -90,6 +91,35 @@ def train(

jdata = j_loader(INPUT)
jdata = normalize(jdata)

# Validate and process orbital files in basis
if jdata.get("common_options") and jdata["common_options"].get("basis"):
orbital_files_content = {}
for elem, value in jdata["common_options"]["basis"].items():
if isinstance(value, str) and os.path.isfile(value):
# strict check for e3tb method
# Check if model_options exists and has prediction method
model_opts = jdata.get("model_options", {})
pred_opts = model_opts.get("prediction", {})
# normalize might handle defaults, but safely check here
if pred_opts.get("method", "e3tb") != "e3tb":
raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}")

# Also checking if we are in a mixed model which might be different,
# but usually orbital files for basis imply the main basis handling.
# transform logic
try:
parsed_basis = parse_orbital_file(value)
with open(value, 'r') as f:
orbital_files_content[elem] = f.read()

jdata["common_options"]["basis"][elem] = parsed_basis
log.info(f"Parsed orbital file for {elem}: {value} -> {parsed_basis}")
except Exception as e:
raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}")
Comment on lines +118 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Chain the exception to preserve traceback context.

except Exception is overly broad, and the re-raised ValueError drops the original traceback. Use raise ... from e so the cause is visible during debugging.

-                except Exception as e:
-                    raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}")
+                except (ValueError, FileNotFoundError, OSError) as e:
+                    raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}") from e
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
except Exception as e:
raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}")
except (ValueError, FileNotFoundError, OSError) as e:
raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}") from e
🧰 Tools
🪛 Ruff (0.15.0)

[warning] 118-118: Do not catch blind exception: Exception

(BLE001)


[warning] 119-119: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


[warning] 119-119: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@dptb/entrypoints/train.py` around lines 118 - 119, The current except block
in dptb/entrypoints/train.py that reads "except Exception as e: raise
ValueError(f'Failed to parse orbital file {value} for element {elem}: {e}')"
drops the original traceback; change it to chain the original exception by using
"raise ValueError(f'Failed to parse orbital file {value} for element {elem}:
{e}') from e" (and, if feasible, narrow the catch to specific exceptions instead
of bare Exception) so the original error context for value and elem is
preserved.


if orbital_files_content:
jdata["common_options"]["orbital_files_content"] = orbital_files_content
Comment on lines 106 to 122
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Inconsistent indentation will cause IndentationError at runtime.

Several lines have off-by-one (or more) indentation that will prevent this module from loading:

  • Line 106: 21 spaces (expected 20 to be inside the if on Line 105).
  • Lines 108–110: 17 spaces (expected 16, same level as the if/try).
  • Line 122: 13 spaces (expected 12, inside the outer if but outside the for).
Proposed fix
-                if pred_opts.get("method", "e3tb") != "e3tb":
-                     raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}")
-                
-                 # Also checking if we are in a mixed model which might be different, 
-                 # but usually orbital files for basis imply the main basis handling.
-                 # transform logic
+                if pred_opts.get("method", "e3tb") != "e3tb":
+                    raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}")
+
+                # Also checking if we are in a mixed model which might be different,
+                # but usually orbital files for basis imply the main basis handling.
+                # transform logic
-             jdata["common_options"]["orbital_files_content"] = orbital_files_content
+            jdata["common_options"]["orbital_files_content"] = orbital_files_content
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}")
# Also checking if we are in a mixed model which might be different,
# but usually orbital files for basis imply the main basis handling.
# transform logic
try:
parsed_basis = parse_orbital_file(value)
with open(value, 'r') as f:
orbital_files_content[value] = f.read()
jdata["common_options"]["basis"][elem] = parsed_basis
log.info(f"Parsed orbital file for {elem}: {value} -> {parsed_basis}")
except Exception as e:
raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}")
if orbital_files_content:
jdata["common_options"]["orbital_files_content"] = orbital_files_content
raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}")
# Also checking if we are in a mixed model which might be different,
# but usually orbital files for basis imply the main basis handling.
# transform logic
try:
parsed_basis = parse_orbital_file(value)
with open(value, 'r') as f:
orbital_files_content[value] = f.read()
jdata["common_options"]["basis"][elem] = parsed_basis
log.info(f"Parsed orbital file for {elem}: {value} -> {parsed_basis}")
except Exception as e:
raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}")
jdata["common_options"]["orbital_files_content"] = orbital_files_content
🧰 Tools
🪛 Ruff (0.15.0)

[warning] 106-106: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 118-118: Do not catch blind exception: Exception

(BLE001)


[warning] 119-119: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


[warning] 119-119: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@dptb/entrypoints/train.py` around lines 106 - 122, The snippet has
inconsistent indentation causing IndentationError; align the try/except and
inner statements to the surrounding if/for blocks: ensure the "try:" block, the
calls to parse_orbital_file(value), the with open(...) block populating
orbital_files_content[value], the jdata["common_options"]["basis"][elem] =
parsed_basis assignment, and the log.info(...) line are indented uniformly at
the same level inside the for/if scope, and ensure the except: and its raise
ValueError(...) match the try's indentation; finally make sure the final "if
orbital_files_content:" and its body
jdata["common_options"]["orbital_files_content"] = orbital_files_content are
indented one level inside the outer function/block. This will fix the off-by-one
indentation issues around parse_orbital_file, orbital_files_content,
jdata["common_options"]["basis"], and log.info.

# update basis if init_model or restart
# update jdata
# this is not necessary, because if we init model from checkpoint, the build_model will load the model_options from checkpoints if not provided
Expand Down
2 changes: 2 additions & 0 deletions dptb/nn/deeptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ def __init__(
device=self.device,
)

if kwargs.get("orbital_files_content"):
self.orbital_files_content = kwargs["orbital_files_content"]

def forward(self, data: AtomicDataDict.Type):
if data.get(AtomicDataDict.EDGE_TYPE_KEY, None) is None:
Expand Down
182 changes: 182 additions & 0 deletions dptb/nn/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@
from dptb.data.transforms import OrbitalMapper
from dptb.data import AtomicDataDict
import logging
try:
from dptb.utils.pardiso_wrapper import PyPardisoSolver
from dptb.utils.feast_wrapper import FeastSolver
from scipy.sparse.linalg import eigsh, LinearOperator
except ImportError:
PyPardisoSolver = None
FeastSolver = None
eigsh = None
LinearOperator = None

log = logging.getLogger(__name__)

class Eigenvalues(nn.Module):
Expand Down Expand Up @@ -113,6 +123,178 @@ def forward(self,
data[AtomicDataDict.KPOINT_KEY] = kpoints

return data

class PardisoEig:
    """Shift-invert sparse eigenvalue solver backed by MKL PARDISO.

    Factorizes (H - sigma*S) once per k-point with PARDISO and feeds the
    resulting inverse operator to ARPACK (scipy ``eigsh``) to find the
    ``neig`` eigenvalues closest to ``sigma``.
    """

    # eigsh shift-invert modes accepted by scipy
    _VALID_MODES = ('normal', 'buckling', 'cayley')

    def __init__(self, sigma: float = 0.0, neig: int = 10, mode: str = 'normal'):
        """
        Solver using Pardiso for shift-invert eigenvalue problems.

        Args:
            sigma: Shift value (target energy).
            neig: Number of eigenvalues to solve for.
            mode: Eigsh mode ('normal', 'buckling', 'cayley').

        Raises:
            ImportError: if pypardiso or scipy.sparse.linalg is unavailable.
            ValueError: if ``mode`` is not a valid eigsh shift-invert mode.
        """
        if PyPardisoSolver is None or eigsh is None:
            raise ImportError("PardisoEig requires MKL (pypardiso) and scipy.sparse.linalg")

        # Fail fast on a bad mode instead of erroring obscurely inside eigsh.
        if mode not in self._VALID_MODES:
            raise ValueError(f"Invalid eigsh mode {mode!r}; expected one of {self._VALID_MODES}")

        self.sigma = sigma
        self.neig = neig
        self.mode = mode
def solve(self, h_container, s_container, kpoints: Union[list, torch.Tensor, np.ndarray], return_eigenvectors: bool = False):
    """
    Solve eigenvalues for given k-points via shift-invert ARPACK.

    Args:
        h_container: vbcsr.ImageContainer for Hamiltonian.
        s_container: vbcsr.ImageContainer for Overlap (can be None).
        kpoints: Array of k-points (Nk, 3).
        return_eigenvectors: If True, return (eigenvalues, eigenvectors). Default False.

    Returns:
        list of eigenvalues arrays (and eigenvectors arrays if return_eigenvectors=True).
    """
    # Function-local imports: scipy availability was already verified in
    # __init__, and these names are only needed on the solve path.
    from scipy.sparse import eye as sparse_eye
    from scipy.sparse.linalg import splu, ArpackError

    # Ensure kpoints is a numpy array
    if isinstance(kpoints, torch.Tensor):
        kpoints = kpoints.cpu().numpy()

    eigvals_list = []
    eigvecs_list = []

    for k in kpoints:
        hk = h_container.sample_k(k, symm=True)
        # Keep the UNSHIFTED H(k) for eigsh's `A` argument: eigsh applies the
        # sigma shift internally, so passing a pre-shifted matrix would shift
        # twice ((H - 2*sigma*S) instead of (H - sigma*S)).
        A = hk.to_scipy(format="csr")

        if s_container is not None:
            sk = s_container.sample_k(k, symm=True)
            # eigsh expects a scipy matrix for M, not the vbcsr wrapper.
            M = sk.to_scipy(format="csr")
            A_shifted = (A - self.sigma * M).tocsr()
        else:
            M = None
            A_shifted = (A - self.sigma * sparse_eye(A.shape[0], format="csr")).tocsr()

        # Canonical CSR layout is required by PARDISO.
        A_shifted.sum_duplicates()
        A_shifted.sort_indices()
        N = A_shifted.shape[0]

        # Try PARDISO first, fall back to scipy SuperLU if PARDISO fails
        # (MKL PARDISO has a known bug with certain block-structured patterns).
        try:
            pardiso = PyPardisoSolver(mtype=13)
            pardiso.factorize(A_shifted)

            def matvec(b, _solver=pardiso, _mat=A_shifted):
                # Default-arg binding pins the per-iteration objects and
                # avoids the late-binding loop-closure pitfall (B023).
                return _solver.solve(_mat, b)
        except Exception as err:
            logging.getLogger(__name__).warning(
                "PARDISO factorization failed (%s); falling back to SuperLU.", err)
            lu = splu(A_shifted.tocsc())

            def matvec(b, _lu=lu):
                return _lu.solve(b)

        Op = LinearOperator((N, N), matvec=matvec, dtype=A_shifted.dtype)

        try:
            # Use larger NCV to help convergence, especially for clustered eigenvalues
            ncv = max(2 * self.neig + 1, 20)
            vals, vecs = eigsh(A=A, M=M, k=self.neig, sigma=self.sigma, OPinv=Op, mode=self.mode, which="LM", ncv=ncv)
        except ArpackError as err:
            # Retry with larger NCV only on ARPACK failures (e.g. error 3:
            # "No shifts could be applied"), which often happen when
            # eigenvalues are clustered near the shift. Other exception
            # types propagate with their original traceback.
            logging.getLogger(__name__).warning(
                "ARPACK failed at k=%s (%s); retrying with larger ncv.", k, err)
            ncv = max(5 * self.neig, 50)
            vals, vecs = eigsh(A=A, M=M, k=self.neig, sigma=self.sigma, OPinv=Op, mode=self.mode, which="LM", ncv=ncv)

        eigvals_list.append(vals)
        if return_eigenvectors:
            eigvecs_list.append(vecs)

    if return_eigenvectors:
        return eigvals_list, eigvecs_list
    return eigvals_list

class FEASTEig:
    def __init__(self, emin: float = -1.0, emax: float = 1.0, m0: Optional[int] = None,
                 max_refinement: int = 3, uplo: str = 'U', extract_triangular: bool = True):
        """
        Solver using FEAST algorithm for finding eigenvalues in a given interval.

        Args:
            emin, emax: Energy interval [emin, emax].
            m0: Initial subspace size estimate.
            max_refinement: Number of refinements if subspace is too small.
            uplo: 'U' (Upper) or 'L' (Lower) triangular part to use.
            extract_triangular: Whether to extract triangular part automatically.

        Raises:
            ImportError: if the FEAST wrapper could not be imported.
            RuntimeError: if the FEAST library fails to initialize.
        """

        if FeastSolver is None:
            raise ImportError("FEAST solver not available")

        self.emin = emin
        self.emax = emax
        self.m0 = m0
        self.max_refinement = max_refinement
        self.uplo = uplo
        self.extract_triangular = extract_triangular

        # Instantiate eagerly so configuration errors surface at construction
        # time rather than on the first solve() call.
        try:
            self.solver = FeastSolver()
        except ImportError as e:
            # Chain with `from e` so the original traceback is preserved (B904).
            raise ImportError(f"FEAST solver not available: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to initialize FeastSolver: {e}") from e
Comment on lines +237 to +242
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Chain exceptions with from e to preserve the original traceback.

Per B904, re-raised exceptions inside except blocks should use raise ... from e so the caller sees the full chain.

Proposed fix
         try:
             self.solver = FeastSolver()
         except ImportError as e:
-            raise ImportError(f"FEAST solver not available: {e}")
+            raise ImportError(f"FEAST solver not available: {e}") from e
         except Exception as e:
-             raise RuntimeError(f"Failed to initialize FeastSolver: {e}")
+             raise RuntimeError(f"Failed to initialize FeastSolver: {e}") from e
🧰 Tools
🪛 Ruff (0.15.0)

[warning] 240-240: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


[warning] 240-240: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 241-241: Do not catch blind exception: Exception

(BLE001)


[warning] 242-242: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


[warning] 242-242: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@dptb/nn/energy.py` around lines 237 - 242, The except blocks around
FeastSolver initialization should chain the original exceptions so tracebacks
are preserved; modify the two re-raises in the try/except that set self.solver =
FeastSolver() to use "raise ImportError(... ) from e" for the ImportError branch
and "raise RuntimeError(... ) from e" for the generic Exception branch, keeping
the same messages but adding "from e" to both raise statements referencing
FeastSolver and self.solver initialization.


def solve(self, h_container, s_container, kpoints: Union[list, torch.Tensor, np.ndarray], return_eigenvectors: bool = False):
    """
    Solve eigenvalues for given k-points using FEAST.

    Args:
        h_container: Container for Hamiltonian (must support sample_k().to_scipy()).
        s_container: Container for Overlap (can be None).
        kpoints: Array of k-points.
        return_eigenvectors: If True, return (eigenvalues, eigenvectors). Default False.

    Returns:
        list of eigenvalues arrays (and eigenvectors arrays if return_eigenvectors=True).
    """
    if isinstance(kpoints, torch.Tensor):
        kpoints = kpoints.cpu().numpy()

    def _as_scipy(obj):
        # Containers normally hand back a wrapper exposing .to_scipy();
        # anything else is assumed to already be a scipy sparse matrix.
        return obj.to_scipy(format="csr") if hasattr(obj, 'to_scipy') else obj

    all_vals = []
    all_vecs = []

    for kpt in kpoints:
        # Sample Hamiltonian (and Overlap, if present) at this k-point.
        hamiltonian = _as_scipy(h_container.sample_k(kpt, symm=True))
        overlap = None
        if s_container is not None:
            overlap = _as_scipy(s_container.sample_k(kpt, symm=True))

        # Delegate the interval eigenproblem to the FEAST backend.
        values, vectors = self.solver.solve(
            hamiltonian, M=overlap, emin=self.emin, emax=self.emax,
            m0=self.m0, max_refinement=self.max_refinement,
            uplo=self.uplo, extract_triangular=self.extract_triangular
        )

        all_vals.append(values)
        if return_eigenvectors:
            all_vecs.append(vectors)

    return (all_vals, all_vecs) if return_eigenvectors else all_vals


class Eigh(nn.Module):
def __init__(
Expand Down
Loading
Loading