97 changes: 97 additions & 0 deletions sparsecoding/priors/lsm.py
@@ -0,0 +1,97 @@
import torch
from torch.distributions.laplace import Laplace
from torch.distributions.gamma import Gamma

from sparsecoding.priors.common import Prior


class LSMPrior(Prior):
    """Prior where weights are drawn i.i.d. from a Laplacian scale mixture.

    The Laplacian scale mixture is defined in:
        Garrigues & Olshausen (2010),
        https://papers.nips.cc/paper/2010/hash/2d6cc4b2d139a53512fb8cbb3086ae2e-Abstract.html

    Conceptually, a Laplacian scale mixture is a weighted mixture of
    Laplacian distributions with different scales.

    Following the paper, a Gamma distribution is placed over the inverse of
    the Laplacian's scale parameter, since the Gamma distribution is the
    conjugate prior for that parameter.

    Parameters
    ----------
    dim : int
        Number of weights per sample.
    alpha : float
        Shape (concentration) parameter of the Gamma distribution
        over the Laplacian's inverse scale.
    beta : float
        Rate (inverse scale) parameter of the Gamma distribution
        over the Laplacian's inverse scale.
    positive_only : bool
        Ensure that the weights are positive by taking the absolute value
        of weights sampled from the Laplacian.
    """

    def __init__(
        self,
        dim: int,
        alpha: float,
        beta: float,
        positive_only: bool = True,
    ):
        if dim < 0:
            raise ValueError(f"`dim` should be nonnegative, got {dim}.")
        if alpha <= 0:
            raise ValueError(f"Must have alpha > 0, got `alpha`={alpha}.")
        if beta <= 0:
            raise ValueError(f"Must have beta > 0, got `beta`={beta}.")

        self.dim = dim
        self.alpha = alpha
        self.beta = beta
        self.positive_only = positive_only

        self.gamma_distr = Gamma(self.alpha, self.beta)

    @property
    def D(self):
        return self.dim

    def sample(self, num_samples: int):
        N = num_samples

        # Inverse scales for each weight, drawn from the Gamma prior.
        inverse_lambdas = self.gamma_distr.sample((N, self.D))  # [N, D]

        weights = Laplace(
            loc=torch.zeros((N, self.D), dtype=torch.float32),
            scale=1. / inverse_lambdas,
        ).sample()  # [N, D]

        if self.positive_only:
            weights = torch.abs(weights)

        return weights

    def log_prob(
        self,
        sample: torch.Tensor,
    ):
        super().check_sample_input(sample)

        log_prob = (
            torch.log(torch.tensor(self.alpha))
            + self.alpha * torch.log(torch.tensor(self.beta))
            - (self.alpha + 1) * torch.log(self.beta + torch.abs(sample))
        )  # [N, D]
        if self.positive_only:
            log_prob[sample < 0.] = -torch.inf
        else:
            log_prob -= torch.log(torch.tensor(2.))

        log_prob = torch.sum(log_prob, dim=1)  # [N]

        return log_prob
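For reference (not part of the diff): the `log_prob` above is the log of the closed-form LSM marginal obtained by integrating the Laplacian likelihood against the Gamma prior on the inverse scale, p(x) = alpha * beta^alpha / (2 * (beta + |x|)^(alpha + 1)). A minimal sanity-check sketch, assuming `numpy` and `scipy` are available:

import numpy as np
from scipy.integrate import quad
from scipy.stats import gamma

alpha, beta = 2.0, 2.0  # Gamma shape and rate (scipy parameterizes scale = 1 / rate)

def marginal_numeric(x):
    # p(x) = integral over lam of Gamma(lam; alpha, beta) * Laplace(x; 0, 1 / lam).
    def integrand(lam):
        prior = gamma.pdf(lam, a=alpha, scale=1.0 / beta)  # prior on the inverse scale
        likelihood = 0.5 * lam * np.exp(-lam * abs(x))     # Laplacian with rate lam
        return prior * likelihood
    value, _ = quad(integrand, 0.0, np.inf)
    return value

def marginal_closed_form(x):
    # alpha * beta^alpha / (2 * (beta + |x|)^(alpha + 1)); its log matches LSMPrior.log_prob.
    return alpha * beta ** alpha / (2.0 * (beta + abs(x)) ** (alpha + 1))

for x in [-3.0, -0.5, 0.0, 0.5, 3.0]:
    assert np.isclose(marginal_numeric(x), marginal_closed_form(x))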
18 changes: 7 additions & 11 deletions sparsecoding/priors/spike_slab.py
@@ -5,7 +5,7 @@


 class SpikeSlabPrior(Prior):
-    """Prior where weights are drawn from a "spike-and-slab" distribution.
+    """Prior where weights are drawn i.i.d. from a "spike-and-slab" distribution.

     The "spike" is at 0 and the "slab" is Laplacian.

@@ -88,19 +88,15 @@ def log_prob(
         log_prob[spike_mask] = torch.log(torch.tensor(self.p_spike))

         # Add log-probability for slab.
+        log_prob[slab_mask] = (
+            torch.log(torch.tensor(1. - self.p_spike))
+            - torch.log(torch.tensor(self.scale))
+            - torch.abs(sample[slab_mask]) / self.scale
+        )
         if self.positive_only:
-            log_prob[slab_mask] = (
-                torch.log(torch.tensor(1. - self.p_spike))
-                - torch.log(torch.tensor(self.scale))
-                - sample[slab_mask] / self.scale
-            )
             log_prob[sample < 0.] = -torch.inf
         else:
-            log_prob[slab_mask] = (
-                torch.log(torch.tensor(1. - self.p_spike))
-                - torch.log(torch.tensor(2. * self.scale))
-                - torch.abs(sample[slab_mask]) / self.scale
-            )
+            log_prob[slab_mask] -= torch.log(torch.tensor(2.))

         log_prob = torch.sum(log_prob, dim=1)  # [N]

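The refactor above factors the shared slab term out of the two branches: log(2 * scale) = log(scale) + log(2), and taking `abs` in the positive-only branch is harmless because negative entries are overwritten with `-inf` anyway. A standalone sketch of the two-sided equivalence (the `slab_mask` here is a toy stand-in; the real class derives it from the spike):

import torch

torch.manual_seed(0)
p_spike, scale = 0.5, 1.0
sample = torch.randn(4, 6)
slab_mask = sample != 0.  # toy mask for illustration only

# Old two-sided branch: the full expression spelled out inside the branch.
old = (
    torch.log(torch.tensor(1. - p_spike))
    - torch.log(torch.tensor(2. * scale))
    - torch.abs(sample[slab_mask]) / scale
)

# New formulation: shared slab term, then subtract log(2) for the two-sided case.
new = (
    torch.log(torch.tensor(1. - p_spike))
    - torch.log(torch.tensor(scale))
    - torch.abs(sample[slab_mask]) / scale
) - torch.log(torch.tensor(2.))

assert torch.allclose(old, new)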
7 changes: 7 additions & 0 deletions tests/inference/common.py
@@ -1,6 +1,7 @@
 import torch

 from sparsecoding.priors.l0 import L0Prior
+from sparsecoding.priors.lsm import LSMPrior
 from sparsecoding.priors.spike_slab import SpikeSlabPrior
 from sparsecoding.data.datasets.bars import BarsDataset

@@ -24,6 +25,12 @@
             ).type(torch.float32)
         ),
     ),
+    LSMPrior(
+        dim=2 * PATCH_SIZE,
+        alpha=80.0,
+        beta=0.02,
+        positive_only=False,
+    ),
 ]

 DATASET = [
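For a sense of scale (not part of the diff): the new `LSMPrior` entry uses a tightly concentrated Gamma (shape 80, rate 0.02), so the sampled inverse scales cluster around alpha / beta = 4000 and the weights are correspondingly small. A quick standalone look, with `dim=16` as an arbitrary stand-in for `2 * PATCH_SIZE`:

import torch
from sparsecoding.priors.lsm import LSMPrior

torch.manual_seed(0)
prior = LSMPrior(dim=16, alpha=80.0, beta=0.02, positive_only=False)
weights = prior.sample(10_000)

# E[|weight|] = E[1 / lambda] = beta / (alpha - 1) ~= 2.5e-4 for this config.
print(weights.abs().mean())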
2 changes: 1 addition & 1 deletion tests/inference/test_LSM.py
@@ -28,7 +28,7 @@ def test_inference(self):

         a = inference_method.infer(data, DICTIONARY)

-        self.assertAllClose(a, dataset.weights, atol=5e-2)
+        self.assertAllClose(a, dataset.weights, atol=7.5e-2)


 if __name__ == "__main__":
100 changes: 100 additions & 0 deletions tests/priors/test_lsm.py
@@ -0,0 +1,100 @@
import numpy as np
import torch
import unittest

from sparsecoding.priors.lsm import LSMPrior


class TestLSMPrior(unittest.TestCase):
    def test_sample(self):
        N = 10000
        D = 4
        alpha = 2
        beta = 2

        torch.manual_seed(1997)

        for positive_only in [True, False]:
            lsm_prior = LSMPrior(
                D,
                alpha,
                beta,
                positive_only,
            )
            weights = lsm_prior.sample(N)

            assert weights.shape == (N, D)

            # Check distribution.
            if positive_only:
                assert torch.sum(weights < 0.) == 0
            else:
                assert torch.allclose(
                    torch.sum(weights < 0.) / (N * D),
                    torch.sum(weights > 0.) / (N * D),
                    atol=2e-2,
                )
                weights = torch.abs(weights)

            # Note:
            # The antiderivative of the positive-only pdf is:
            #     -Beta^alpha * (Beta + x)^(-alpha),
            # so the cdf is:
            #     1. - Beta^alpha * (Beta + x)^(-alpha),
            # and the quantile fn is:
            #     -Beta + exp((log(1-y) - alpha*log(Beta)) / -alpha).

            for quantile in torch.arange(5) / 5.:
                cutoff = (
                    -beta
                    + np.exp(
                        (np.log(1. - quantile) - alpha * np.log(beta))
                        / (-alpha)
                    )
                )
                assert torch.allclose(
                    torch.sum(weights < cutoff) / (N * D),
                    quantile,
                    atol=1e-2,
                )

    def test_log_prob(self):
        D = 3
        alpha = 2
        beta = 2

        samples = torch.Tensor([[-1., 0., 1.]])

        pos_only_log_prob = (
            torch.log(torch.tensor(alpha)) - torch.log(torch.tensor(beta))
            + 2 * (
                torch.log(torch.tensor(alpha)) + alpha * torch.log(torch.tensor(beta))
                - (alpha + 1) * torch.log(torch.tensor(1 + beta))
            )
        )

        for positive_only in [True, False]:
            lsm_prior = LSMPrior(
                D,
                alpha,
                beta,
                positive_only,
            )

            if positive_only:
                assert lsm_prior.log_prob(samples)[0] == -torch.inf

                samples = torch.abs(samples)
                assert torch.allclose(
                    lsm_prior.log_prob(samples)[0],
                    pos_only_log_prob,
                )
            else:
                assert torch.allclose(
                    lsm_prior.log_prob(samples)[0],
                    pos_only_log_prob - D * torch.log(torch.tensor(2.)),
                )


if __name__ == "__main__":
    unittest.main()
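The quantile formula in `test_sample`'s comment is the inversion of the positive-only CDF, F(x) = 1 - (Beta / (Beta + x))^alpha. A standalone check of that algebra (assumes only `numpy`):

import numpy as np

alpha, beta = 2.0, 2.0

def cdf(x):
    # Positive-only CDF: 1 - Beta^alpha * (Beta + x)^(-alpha).
    return 1.0 - (beta / (beta + x)) ** alpha

def quantile(y):
    # Inversion used to pick the cutoffs in test_sample.
    return -beta + np.exp((np.log(1.0 - y) - alpha * np.log(beta)) / -alpha)

for y in np.arange(5) / 5.0:
    assert np.isclose(cdf(quantile(y)), y)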
23 changes: 11 additions & 12 deletions tests/priors/test_spike_slab.py
@@ -59,6 +59,15 @@ def test_log_prob(self):
         p_spike = 0.5
         scale = 1.

+        samples = torch.Tensor([[-1., 0., 1.]])
+
+        pos_only_log_prob = (
+            torch.log(torch.tensor(p_spike))
+            + 2 * (
+                -1. + torch.log(torch.tensor(1. - p_spike))
+            )
+        )
+
         for positive_only in [True, False]:
             spike_slab_prior = SpikeSlabPrior(
                 D,
@@ -67,28 +76,18 @@
                 positive_only,
             )

-            samples = torch.Tensor([[-1., 0., 1.]])
-
             if positive_only:
                 assert spike_slab_prior.log_prob(samples)[0] == -torch.inf

                 samples = torch.abs(samples)
                 assert torch.allclose(
                     spike_slab_prior.log_prob(samples)[0],
-                    (
-                        -1. + torch.log(torch.tensor(1. - p_spike))
-                        + torch.log(torch.tensor(p_spike))
-                        - 1. + torch.log(torch.tensor(1. - p_spike))
-                    )
+                    pos_only_log_prob,
                 )
             else:
                 assert torch.allclose(
                     spike_slab_prior.log_prob(samples)[0],
-                    (
-                        -1. + torch.log(torch.tensor(1. - p_spike)) - torch.log(torch.tensor(2.))
-                        + torch.log(torch.tensor(p_spike))
-                        - 1. + torch.log(torch.tensor(1. - p_spike)) - torch.log(torch.tensor(2.))
-                    )
+                    pos_only_log_prob - (D - 1) * torch.log(torch.tensor(2.)),
                 )
