Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 71 additions & 23 deletions activations_plus/sparsemax/sparsemax_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,7 @@ def forward(ctx: Any, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
# Translate by max for numerical stability
x = x - x.max(dim=reduce_dim, keepdim=True).values.expand_as(x)

zs = x.sort(dim=reduce_dim, descending=True).values
d = x.size(reduce_dim)
range_th = torch.arange(1, d + 1, device=x.device, dtype=x.dtype)
shape = [1] * x.dim()
shape[reduce_dim] = d
range_th = range_th.view(*shape).expand_as(x)

# Determine sparsity of projection
bound = 1 + range_th * zs
cumsum_zs = zs.cumsum(dim=reduce_dim)
is_gt = bound.gt(cumsum_zs).type(x.dtype)
k = (is_gt * range_th).max(dim=reduce_dim, keepdim=True).values

# Compute threshold
zs_sparse = is_gt * zs
taus = (zs_sparse.sum(dim=reduce_dim, keepdim=True) - 1) / k
taus = taus.expand_as(x)

output = torch.max(torch.zeros_like(x), x - taus)
output, ctx = SparsemaxFunction._threshold_and_support(ctx, x, reduce_dim)

# Save context
ctx.save_for_backward(output)
Expand Down Expand Up @@ -116,12 +98,78 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]:
else:
reduce_dim = ctx.dim

nonzeros = torch.ne(output, 0)
num_nonzeros = nonzeros.sum(dim=reduce_dim, keepdim=True)
sum_all = (grad_output * nonzeros).sum(dim=reduce_dim, keepdim=True) / num_nonzeros
grad_input = nonzeros * (grad_output - sum_all.expand_as(grad_output))
grad_input = SparsemaxFunction._compute_gradient(ctx, grad_output, output, reduce_dim)

if ctx.needs_reshaping:
ctx, grad_input = unflatten_all_but_nth_dim(ctx, grad_input)

return grad_input, None

@staticmethod
def _threshold_and_support(ctx: Any, x: torch.Tensor, reduce_dim: int) -> tuple[torch.Tensor, Any]:
"""Compute the threshold and support for the input tensor.

Parameters
----------
ctx : Any
Context object for autograd.
x : torch.Tensor
Input tensor.
reduce_dim : int
Dimension along which to compute threshold/support.

Returns
-------
tuple[torch.Tensor, Any]
The output tensor after applying Sparsemax and the updated context.

"""
zs = x.sort(dim=reduce_dim, descending=True).values
d = x.size(reduce_dim)
range_th = torch.arange(1, d + 1, device=x.device, dtype=x.dtype)
shape = [1] * x.dim()
shape[reduce_dim] = d
range_th = range_th.view(*shape).expand_as(x)

# Determine sparsity of projection
bound = 1 + range_th * zs
cumsum_zs = zs.cumsum(dim=reduce_dim)
is_gt = bound.gt(cumsum_zs).type(x.dtype)
k = (is_gt * range_th).max(dim=reduce_dim, keepdim=True).values

# Compute threshold
zs_sparse = is_gt * zs
taus = (zs_sparse.sum(dim=reduce_dim, keepdim=True) - 1) / k
taus = taus.expand_as(x)

output = torch.max(torch.zeros_like(x), x - taus)

return output, ctx

@staticmethod
def _compute_gradient(ctx: Any, grad_output: torch.Tensor, output: torch.Tensor, reduce_dim: int) -> torch.Tensor:
"""Compute the gradient for the backward pass.

Parameters
----------
ctx : Any
Context object for autograd.
grad_output : torch.Tensor
Gradient of the loss with respect to the output.
output : torch.Tensor
Output tensor from the forward pass.
reduce_dim : int
Dimension along which to compute the gradient.

Returns
-------
torch.Tensor
The gradient with respect to the input.

"""
nonzeros = torch.ne(output, 0)
num_nonzeros = nonzeros.sum(dim=reduce_dim, keepdim=True)
sum_all = (grad_output * nonzeros).sum(dim=reduce_dim, keepdim=True) / num_nonzeros
grad_input = nonzeros * (grad_output - sum_all.expand_as(grad_output))

return grad_input
40 changes: 40 additions & 0 deletions tests/sparsemax/test_sparsemax_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch

from activations_plus.sparsemax import SparsemaxFunction
from activations_plus.sparsemax.utils import flatten_all_but_nth_dim, unflatten_all_but_nth_dim


def test_sparsemax_forward_valid_input():
Expand Down Expand Up @@ -115,3 +116,42 @@ def test_sparsemax_backward_parametrized(x):
assert torch.allclose(grad_sum, torch.zeros_like(grad_sum), atol=1e-5), (
"Gradients should sum to zero along the projection dimension"
)


def test_flatten_all_but_nth_dim():
    """Flattening a (2, 3, 4, 5) tensor around dim 1 yields (3, 40) and records the size."""
    tensor = torch.randn(2, 3, 4, 5)
    context = type('', (), {})()  # Minimal stand-in for an autograd context
    context.dim = 1
    context, flat = flatten_all_but_nth_dim(context, tensor)
    assert context.original_size == tensor.size(), "Original size not saved correctly in context"
    assert flat.shape == (3, 40), "Flattened shape is incorrect"


def test_unflatten_all_but_nth_dim():
    """Unflattening a (3, 40) tensor with a recorded size restores (2, 3, 4, 5)."""
    flat = torch.randn(3, 40)
    context = type('', (), {})()  # Minimal stand-in for an autograd context
    context.dim = 1
    context.original_size = (2, 3, 4, 5)
    context, restored = unflatten_all_but_nth_dim(context, flat)
    assert restored.shape == (2, 3, 4, 5), "Unflattened shape is incorrect"


def test_threshold_and_support():
    """``_threshold_and_support`` should produce a valid sparsemax projection row-wise."""
    x = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], dtype=torch.float32)
    ctx = type('', (), {})()  # Create an empty context object
    ctx.dim = 1
    output, ctx = SparsemaxFunction._threshold_and_support(ctx, x, ctx.dim)
    assert output is not None, "Output should not be None"
    assert output.shape == x.shape, "Output shape should match input shape"
    assert torch.all(output >= 0), "Output should have non-negative values"
    # Sparsemax outputs lie on the probability simplex: each row sums to one.
    assert torch.allclose(output.sum(dim=1), torch.ones(2), atol=1e-6), "Rows should sum to one"
    # Known projections: [1, 2, 3] -> one-hot on the max; a constant row -> uniform.
    expected = torch.tensor([[0.0, 0.0, 1.0], [1 / 3, 1 / 3, 1 / 3]])
    assert torch.allclose(output, expected, atol=1e-6), "Projection values are incorrect"


def test_compute_gradient():
    """``_compute_gradient`` should centre gradients over the support of the output."""
    grad_output = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], dtype=torch.float32)
    output = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.1, 0.3]], dtype=torch.float32)
    ctx = type('', (), {})()  # Create an empty context object
    ctx.dim = 1
    grad_input = SparsemaxFunction._compute_gradient(ctx, grad_output, output, ctx.dim)
    assert grad_input is not None, "Gradient input should not be None"
    assert grad_input.shape == grad_output.shape, "Gradient input shape should match grad output shape"
    assert torch.all(torch.isfinite(grad_input)), "Gradient input should not contain NaN or Inf"
    # Row 1: all entries supported, mean grad = 2 -> [-1, 0, 1].
    # Row 2: all entries supported, uniform grads -> zero gradient.
    expected = torch.tensor([[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
    assert torch.allclose(grad_input, expected, atol=1e-6), "Gradient values are incorrect"
    # The sparsemax Jacobian maps onto the simplex tangent: rows sum to zero.
    assert torch.allclose(grad_input.sum(dim=1), torch.zeros(2), atol=1e-6), (
        "Gradients should sum to zero along the projection dimension"
    )
2 changes: 0 additions & 2 deletions tests/sparsemax/test_sparsemax_pb.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def test_sparsemax_backward_pb(random_data, dim):
),
dim=st.integers(min_value=-1, max_value=0),
)
@pytest.mark.skip
def test_sparsemax_v2_threshold_and_support(random_data, dim):
x = torch.tensor(random_data, dtype=torch.double)
tau, support_size = SparsemaxFunction._threshold_and_support(x, dim=dim)
Expand All @@ -112,7 +111,6 @@ def test_sparsemax_v2_threshold_and_support(random_data, dim):
elements=st.floats(min_value=-1000, max_value=1000, allow_nan=False, allow_infinity=False),
),
)
@pytest.mark.skip
def test_compare_with_original(random_data):
x = torch.tensor(random_data, dtype=torch.double, requires_grad=True)
for dim in range(-1, x.dim()):
Expand Down
Loading