Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ python setup.py install
`StochMan` includes a number of modules that each defines a set of functionalities for
working with manifold data.

### `stochman.nnj`: `torch.nn` with jacobians
### `stochman.nnj`: `torch.nn` with Jacobians

Key to working with Riemannian geometry is the ability to compute jacobians. The jacobian matrix
Key to working with Riemannian geometry is the ability to compute Jacobians. The Jacobian matrix
contains the first-order partial derivatives. `stochman.nnj` provides plug-in replacements for many commonly
used `torch.nn` layers such as `Linear`, `BatchNorm1d` etc. and commonly used activation functions such as `ReLU`,
`Sigmoid` etc. that enables fast computations of jacobians between the input to the layer and the output.
`Sigmoid` etc. that enable fast computation of Jacobians between the input to the layer and the output.

``` python
import torch
Expand All @@ -47,7 +47,7 @@ model = nnj.Sequential(nnj.Linear(10, 5),
x = torch.randn(100, 10)
y, J = model(x, jacobian=True)
print(y.shape) # output from model: torch.Size([100, 5])
print(J.shape) # jacobian between input and output: torch.size([100, 5, 10])
print(J.shape) # Jacobian between input and output: torch.Size([100, 5, 10])
```

### `stochman.manifold`: Interface for working with Riemannian manifolds
Expand Down
12 changes: 6 additions & 6 deletions stochman/nnj.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ def _jacobian_wrt_weight_sandwich(


def identity(x: Tensor) -> Tensor:
"""Function that for a given input x returns the corresponding identity jacobian matrix"""
"""Function that for a given input x returns the corresponding identity Jacobian matrix"""
# Build a fresh Identity layer; it is used only for this one call.
m = Identity()
# Call the layer with jacobian=True and keep only element [1] of the result.
# NOTE(review): element [1] is presumably the Jacobian (element [0] being the layer output) — confirm against Identity.
return m(x, jacobian=True)[1]


class Sequential(nn.Sequential):
"""Subclass of sequential that also supports calculating the jacobian through an network"""
"""Subclass of sequential that also supports calculating the Jacobian through a network"""

def forward(
self, x: Tensor, jacobian: Union[Tensor, bool] = False
Expand Down Expand Up @@ -414,7 +414,7 @@ def _jacobian_wrt_weight(self, x: Tensor, val: Tensor) -> Tensor:

reversed_inputs = torch.flip(x, [-2, -1]).movedim(0, 1)

# convolve each base element and compute the jacobian
# convolve each base element and compute the Jacobian
jacobian = (
F.conv_transpose2d(
output_identity.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, c1, kernel_h, kernel_w),
Expand Down Expand Up @@ -920,21 +920,21 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens


class BatchNorm1d(AbstractActivationJacobian, nn.BatchNorm1d):
# only implements jacobian during testing
# only implements Jacobian during testing
# NOTE(review): "testing" presumably means eval mode — only the stored running_var is used
# and self.training is never checked here; confirm how callers guard against train mode.
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
# Scale factor weight / sqrt(running_var + eps); neither x nor val is read in this method.
# NOTE(review): assumes self.weight and self.running_var are set (not None) — confirm.
jac = (self.weight / (self.running_var + self.eps).sqrt()).unsqueeze(0)
# unsqueeze(0) prepends a size-1 dimension to the result before it is returned.
return jac


class BatchNorm2d(AbstractActivationJacobian, nn.BatchNorm2d):
# only implements jacobian during testing
# only implements Jacobian during testing
# NOTE(review): "testing" presumably means eval mode — only the stored running_var is used
# and self.training is never checked here; confirm how callers guard against train mode.
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
# Scale factor weight / sqrt(running_var + eps); neither x nor val is read in this method.
# NOTE(review): assumes self.weight and self.running_var are set (not None) — confirm.
jac = (self.weight / (self.running_var + self.eps).sqrt()).unsqueeze(0)
# unsqueeze(0) prepends a size-1 dimension to the result before it is returned.
return jac


class BatchNorm3d(AbstractActivationJacobian, nn.BatchNorm3d):
# only implements jacobian during testing
# only implements Jacobian during testing
# NOTE(review): "testing" presumably means eval mode — only the stored running_var is used
# and self.training is never checked here; confirm how callers guard against train mode.
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
# Scale factor weight / sqrt(running_var + eps); neither x nor val is read in this method.
# NOTE(review): assumes self.weight and self.running_var are set (not None) — confirm.
jac = (self.weight / (self.running_var + self.eps).sqrt()).unsqueeze(0)
# unsqueeze(0) prepends a size-1 dimension to the result before it is returned.
return jac
Expand Down
6 changes: 3 additions & 3 deletions tests/test_nnj.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
"""Use pytorch build-in jacobian function to compare for correctness of computations"""
"""Use PyTorch's built-in Jacobian function to check the correctness of computations"""
out = f(x)
output = torch.autograd.functional.jacobian(f, x)
m = out.ndim
Expand Down Expand Up @@ -139,7 +139,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
class TestJacobian:
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_jacobians(self, model, input_shape, device, dtype):
"""Test that the analytical jacobian of the model is consistent with finite
"""Test that the analytical Jacobian of the model is consistent with finite
order approximation
"""
if "cuda" in device and not torch.cuda.is_available():
Expand All @@ -149,7 +149,7 @@ def test_jacobians(self, model, input_shape, device, dtype):
input = torch.randn(*input_shape, device=device, dtype=dtype)
_, jac = model(input, jacobian=True)
jacnum = _compare_jacobian(model, input).to(device)
assert torch.isclose(jac, jacnum, atol=1e-3).all(), "jacobians did not match"
assert torch.isclose(jac, jacnum, atol=1e-3).all(), "Jacobians did not match"

@pytest.mark.parametrize("return_jac", [True, False])
def test_jac_return(self, model, input_shape, device, return_jac):
Expand Down