Commit

Merge pull request #578 from aai-institute/fix/552-ekfac-bug-mps-osx

Fix/552 ekfac bug mps osx

schroedk authored May 7, 2024
2 parents 7fa1ab2 + 97b09df commit 8e211dc
Showing 6 changed files with 139 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,10 @@
  implementation [PR #570](https://github.com/aai-institute/pyDVL/pull/570)
- Missing move to device of `preconditioner` in `CgInfluence` implementation
  [PR #572](https://github.com/aai-institute/pyDVL/pull/572)
- Raise a more specific error when a `RuntimeError` occurs in
  `torch.linalg.eigh`, so the user can check whether it is related to a known
  issue
  [PR #578](https://github.com/aai-institute/pyDVL/pull/578)

### Changed

5 changes: 3 additions & 2 deletions src/pydvl/influence/torch/influence_function_model.py
@@ -39,6 +39,7 @@
    EkfacRepresentation,
    empirical_cross_entropy_loss_fn,
    flatten_dimensions,
+   safe_torch_linalg_eigh,
)

__all__ = [
@@ -1284,8 +1285,8 @@ def fit(self, data: DataLoader) -> EkfacInfluence:
        layers_evect_g = {}
        layers_diags = {}
        for key in self.active_layers.keys():
-           evals_a, evecs_a = torch.linalg.eigh(forward_x[key])
-           evals_g, evecs_g = torch.linalg.eigh(grad_y[key])
+           evals_a, evecs_a = safe_torch_linalg_eigh(forward_x[key])
+           evals_g, evecs_g = safe_torch_linalg_eigh(grad_y[key])
            layers_evecs_a[key] = evecs_a
            layers_evect_g[key] = evecs_g
            layers_diags[key] = torch.kron(evals_g.view(-1, 1), evals_a.view(-1, 1))
48 changes: 47 additions & 1 deletion src/pydvl/influence/torch/util.py
@@ -5,7 +5,6 @@
from typing import (
    Collection,
    Dict,
-   Generator,
    Iterable,
    Iterator,
    List,
@@ -25,6 +24,7 @@
from torch.utils.data import Dataset
from tqdm import tqdm

+from ...utils.exceptions import catch_and_raise_exception
from ..array import (
    LazyChunkSequence,
    NestedLazyChunkSequence,
@@ -552,3 +552,49 @@ def empirical_cross_entropy_loss_fn(
        torch.isfinite(log_probs_), log_probs_, torch.zeros_like(log_probs_)
    )
    return torch.sum(log_probs_ * probs_.detach() ** 0.5)


@catch_and_raise_exception(RuntimeError, lambda e: TorchLinalgEighException(e))
def safe_torch_linalg_eigh(*args, **kwargs):
    """
    A wrapper around `torch.linalg.eigh` that safely handles potential runtime
    errors by raising a custom `TorchLinalgEighException` with more context,
    especially related to the issues reported in
    [https://github.com/pytorch/pytorch/issues/92141](
    https://github.com/pytorch/pytorch/issues/92141).

    Args:
        *args: Positional arguments passed to `torch.linalg.eigh`.
        **kwargs: Keyword arguments passed to `torch.linalg.eigh`.

    Returns:
        The result of calling `torch.linalg.eigh` with the provided arguments.

    Raises:
        TorchLinalgEighException: If a `RuntimeError` occurs during the execution
            of `torch.linalg.eigh`.
    """
    return torch.linalg.eigh(*args, **kwargs)


class TorchLinalgEighException(Exception):
    """
    Exception wrapping a RuntimeError raised by torch.linalg.eigh when used
    with large matrices, see
    [https://github.com/pytorch/pytorch/issues/92141](
    https://github.com/pytorch/pytorch/issues/92141)
    """

    def __init__(self, original_exception: RuntimeError):
        func = torch.linalg.eigh
        err_msg = (
            f"A RuntimeError occurred in '{func.__module__}.{func.__qualname__}'. "
            "This might be related to known issues with "
            "[torch.linalg.eigh][torch.linalg.eigh] on certain matrix sizes.\n"
            "For more details, refer to "
            "https://github.com/pytorch/pytorch/issues/92141.\n"
            "In this case, consider using a different implementation, which does "
            "not depend on [torch.linalg.eigh][torch.linalg.eigh].\n"
            f"Inspect the original exception message:\n{str(original_exception)}"
        )
        super().__init__(err_msg)
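
For context, a minimal usage sketch of the new wrapper: it is a drop-in replacement for `torch.linalg.eigh`. The tensor shape and the fallback in the `except` branch are illustrative assumptions, not part of this commit.

```python
import torch

from pydvl.influence.torch.util import (
    TorchLinalgEighException,
    safe_torch_linalg_eigh,
)

# A symmetric matrix, as expected by eigh.
t = torch.randn(10, 10)
t = t @ t.t()

try:
    evals, evecs = safe_torch_linalg_eigh(t)
except TorchLinalgEighException as e:
    # On affected setups (e.g. large matrices on macOS ARM64, see
    # https://github.com/pytorch/pytorch/issues/92141), switch to an
    # implementation that does not rely on torch.linalg.eigh.
    print(e)
```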
59 changes: 59 additions & 0 deletions src/pydvl/utils/exceptions.py
@@ -0,0 +1,59 @@
from functools import wraps
from typing import Callable, Type, TypeVar

CatchExceptionType = TypeVar("CatchExceptionType", bound=BaseException)
RaiseExceptionType = TypeVar("RaiseExceptionType", bound=BaseException)


def catch_and_raise_exception(
    catch_exception_type: Type[CatchExceptionType],
    raise_exception_factory: Callable[[CatchExceptionType], RaiseExceptionType],
) -> Callable:
    """
    A decorator that catches exceptions of a specified type and re-raises them
    as another specified exception.

    Args:
        catch_exception_type: The type of the exception to catch.
        raise_exception_factory: A factory function that creates a new exception
            from the caught one.

    Returns:
        A decorator function that wraps the target function.

    ??? Example

        ```python
        @catch_and_raise_exception(RuntimeError, lambda e: TorchLinalgEighException(e))
        def safe_torch_linalg_eigh(*args, **kwargs):
            '''
            A wrapper around `torch.linalg.eigh` that safely handles potential
            runtime errors by raising a custom `TorchLinalgEighException` with
            more context, especially related to the issues reported in
            https://github.com/pytorch/pytorch/issues/92141.

            Args:
                *args: Positional arguments passed to `torch.linalg.eigh`.
                **kwargs: Keyword arguments passed to `torch.linalg.eigh`.

            Returns:
                The result of calling `torch.linalg.eigh` with the provided
                arguments.

            Raises:
                TorchLinalgEighException: If a `RuntimeError` occurs during the
                    execution of `torch.linalg.eigh`.
            '''
            return torch.linalg.eigh(*args, **kwargs)
        ```
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except catch_exception_type as e:
                raise raise_exception_factory(e) from e

        return wrapper

    return decorator
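
To illustrate that the decorator is generic and not tied to torch, here is a minimal sketch under assumed names; `ParseError` and `parse_int` are hypothetical and exist only for this example.

```python
from pydvl.utils.exceptions import catch_and_raise_exception


class ParseError(Exception):
    """Hypothetical domain exception wrapping a ValueError with more context."""

    def __init__(self, original: ValueError):
        super().__init__(f"Could not parse input: {original}")


@catch_and_raise_exception(ValueError, lambda e: ParseError(e))
def parse_int(s: str) -> int:
    # int() raises ValueError on malformed input; the decorator re-raises it
    # as ParseError, chaining the original via `raise ... from e`.
    return int(s)


assert parse_int("42") == 42
try:
    parse_int("not a number")
except ParseError as e:
    # The original ValueError is preserved as __cause__.
    assert isinstance(e.__cause__, ValueError)
```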
5 changes: 5 additions & 0 deletions tests/conftest.py
@@ -1,5 +1,6 @@
import logging
import os
+import platform
from dataclasses import asdict
from typing import TYPE_CHECKING, Optional, Tuple

@@ -264,3 +265,7 @@ def pytest_terminal_summary(
):
    tolerate_session = terminalreporter.config._tolerate_session
    tolerate_session.display(terminalreporter)


def is_osx_arm64():
    return platform.system() == "Darwin" and platform.machine() == "arm64"
21 changes: 21 additions & 0 deletions tests/influence/torch/test_util.py
@@ -17,11 +17,14 @@
    lanzcos_low_rank_hessian_approx,
)
from pydvl.influence.torch.util import (
+   TorchLinalgEighException,
    TorchTensorContainerType,
    align_structure,
    flatten_dimensions,
+   safe_torch_linalg_eigh,
    torch_dataset_to_dask_array,
)
+from tests.conftest import is_osx_arm64
from tests.influence.conftest import linear_hessian_analytical, linear_model


@@ -297,3 +300,21 @@ def are_active_layers_linear(model):
        if any(param_requires_grad):
            return False
    return True


@pytest.mark.torch
def test_safe_torch_linalg_eigh():
    t = torch.randn([10, 10])
    t = t @ t.t()
    safe_eigs, safe_eigvec = safe_torch_linalg_eigh(t)
    eigs, eigvec = torch.linalg.eigh(t)
    assert torch.allclose(safe_eigs, eigs)
    assert torch.allclose(safe_eigvec, eigvec)


@pytest.mark.torch
@pytest.mark.slow
@pytest.mark.skipif(not is_osx_arm64(), reason="Requires macOS ARM64.")
def test_safe_torch_linalg_eigh_exception():
    with pytest.raises(TorchLinalgEighException):
        safe_torch_linalg_eigh(torch.randn([53000, 53000]))
