Skip to content

Commit

Permalink
- add warnings if neg eigvals
Browse files Browse the repository at this point in the history
- modify test case
  • Loading branch information
tingwl0122 committed Sep 12, 2024
1 parent 03a9dfd commit d1d3ab2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion dattri/algorithm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def attribute(

tda_output[row_st:row_ed, col_st:col_ed] += (
train_batch_rep @ test_batch_rep.T
).to(torch.float)
)

tda_output /= checkpoint_idx + 1

Expand Down
15 changes: 14 additions & 1 deletion dattri/func/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from collections.abc import Callable
from typing import Dict, List, Union

import warnings

import numpy as np
import torch
from torch import Tensor
Expand Down Expand Up @@ -753,6 +755,18 @@ def get_eigenspace(self) -> None:
# proj dim is the value of k
self.eigvals, self.eigvecs = self._distill(appr_mat, proj, self.proj_dim)

# prevent from negative eigvals
if self.proj_dim > torch.sum(self.eigvals > 0):
# adjust proj_dim
self.proj_dim = torch.sum(self.eigvals > 0).item()
warnings.warn(
"Encountered many negative eigenvalues and `proj_dim` is greater"
" than the number of positive eigenvalues. Automatically adjusting"
" `proj_dim` to the number of positive eigenvalues. Please consider"
" increasing `regularization` to reduce negative eigenvalues.",
stacklevel=1,
)

def project(self, features: Union[dict, Tensor]) -> Tensor:
"""Performs the random projection on the feature matrix.
Expand All @@ -770,7 +784,6 @@ def project(self, features: Union[dict, Tensor]) -> Tensor:
if self.eigvals is None or self.eigvecs is None:
self.get_eigenspace()

self.eigvals = (self.eigvals).to(torch.complex64)
return features @ self.eigvecs.T * (1.0 / torch.sqrt(self.eigvals.unsqueeze(0)))

def free_memory(self) -> None:
Expand Down
8 changes: 5 additions & 3 deletions test/dattri/func/test_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,14 @@ def target(x):
return torch.sin(x).sum()

x = torch.randn(self.feature_dim)

# set reg large enough to have some positive eigvals
reg = 1.0
self.projector = ArnoldiProjector(
self.feature_dim,
self.proj_dim,
target,
x,
regularization=reg,
)

vec1 = torch.randn(self.vec_dim, self.feature_dim)
Expand All @@ -167,8 +169,8 @@ def target(x):

# test the closeness of inner product only
assert torch.allclose(
(projected_grads1 @ projected_grads2.T).to(torch.float),
(vec1 @ torch.diag(-1 / x.sin())) @ vec2.T,
(projected_grads1 @ projected_grads2.T),
(vec1 @ (torch.diag(1 / (reg - x.sin())))) @ vec2.T,
rtol=1e-01,
atol=1e-04,
)
Expand Down

0 comments on commit d1d3ab2

Please sign in to comment.