60 changes: 60 additions & 0 deletions sparsecoding/priors/laplace.py
@@ -0,0 +1,60 @@
import torch
from torch.distributions.laplace import Laplace

from sparsecoding.priors.common import Prior


class LaplacePrior(Prior):
    """Prior corresponding to a Laplacian distribution.

    Parameters
    ----------
    dim : int
        Number of weights per sample.
    scale : float
        The "scale" of the Laplacian distribution (larger is wider).
    positive_only : bool
        Ensure that the weights are positive by taking the absolute value
        of weights sampled from the Laplacian.
    """

    def __init__(
        self,
        dim: int,
        scale: float,
        positive_only: bool = True,
    ):
        if dim < 0:
            raise ValueError(f"`dim` should be nonnegative, got {dim}.")
        if scale <= 0:
            raise ValueError(f"`scale` must be positive, got {scale}.")

        self.dim = dim
        self.scale = scale
        self.positive_only = positive_only

        self.distr = Laplace(loc=torch.tensor(0.), scale=torch.tensor(self.scale))

    @property
    def D(self):
        return self.dim

    def sample(self, num_samples: int):
        weights = self.distr.rsample((num_samples, self.D))
        if self.positive_only:
            weights = torch.abs(weights)
        return weights

    def log_prob(
        self,
        sample: torch.Tensor,
    ):
        super().check_sample_input(sample)

        log_prob = self.distr.log_prob(sample)
        if self.positive_only:
            log_prob += torch.log(torch.tensor(2.))
            log_prob[sample < 0.] = -torch.inf
        log_prob = torch.sum(log_prob, dim=1)  # [N]

        return log_prob
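
A note on the `positive_only` branch in `log_prob`: taking the absolute value folds the Laplacian onto the nonnegative half-line, doubling the density there (hence the `+ log(2)`) and assigning zero probability to negatives (hence the `-inf`). A minimal standalone check of that identity, with an arbitrary `scale` and test points (not part of the PR):

```python
import torch
from torch.distributions.laplace import Laplace

scale = 1.5
distr = Laplace(loc=torch.tensor(0.), scale=torch.tensor(scale))

x = torch.tensor([0.3, 2.0])  # positive test points
# Folded density: p_folded(x) = 2 * p_laplace(x) for x >= 0.
folded_log_prob = distr.log_prob(x) + torch.log(torch.tensor(2.))
# Equivalent closed form: log p(x) = -x / b - log(b) for x >= 0,
# i.e. an Exponential distribution with rate 1 / b.
closed_form = -x / scale - torch.log(torch.tensor(scale))
assert torch.allclose(folded_log_prob, closed_form)
```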
43 changes: 15 additions & 28 deletions sparsecoding/priors/spike_slab.py
@@ -1,14 +1,11 @@
import torch
-from torch.distributions.laplace import Laplace

from sparsecoding.priors.common import Prior


class SpikeSlabPrior(Prior):
    """Prior where weights are drawn i.i.d. from a "spike-and-slab" distribution.

-    The "spike" is at 0 and the "slab" is Laplacian.
-
    See:
    https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf
    for a good review of the spike-and-slab model.
@@ -19,31 +16,31 @@ class SpikeSlabPrior(Prior):
        Number of weights per sample.
    p_spike : float
        The probability of the weight being 0.
-    scale : float
-        The "scale" of the Laplacian distribution (larger is wider).
-    positive_only : bool
-        Ensure that the weights are positive by taking the absolute value
-        of weights sampled from the Laplacian.
+    slab : Prior
+        The distribution of the "slab".
+        Since weights drawn from this distribution must be i.i.d.,
+        we require `slab.D` to be 1.
    """

    def __init__(
        self,
        dim: int,
        p_spike: float,
-        scale: float,
-        positive_only: bool = True,
+        slab: Prior,
    ):
        if dim < 0:
            raise ValueError(f"`dim` should be nonnegative, got {dim}.")
        if p_spike < 0 or p_spike > 1:
            raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.")
-        if scale <= 0:
-            raise ValueError(f"`scale` must be positive, got {scale}.")
+        if slab.D != 1:
+            raise ValueError(
+                f"`slab.D` must be 1 (got {slab.D}). "
+                f"This ensures that we can sample i.i.d. weights."
+            )

        self.dim = dim
        self.p_spike = p_spike
-        self.scale = scale
-        self.positive_only = positive_only
+        self.slab = slab

    @property
    def D(self):
@@ -53,13 +50,8 @@ def sample(self, num_samples: int):
        N = num_samples

        zero_weights = torch.zeros((N, self.D), dtype=torch.float32)
-        slab_weights = Laplace(
-            loc=zero_weights,
-            scale=torch.full((N, self.D), self.scale, dtype=torch.float32),
-        ).sample()  # [N, D]
-
-        if self.positive_only:
-            slab_weights = torch.abs(slab_weights)
+        slab_weights = self.slab.sample(num_samples * self.D)
+        slab_weights = slab_weights.reshape((num_samples, self.D))

        spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike

@@ -89,14 +81,9 @@ def log_prob(

        # Add log-probability for slab.
        log_prob[slab_mask] = (
-            torch.log(torch.tensor(1. - self.p_spike))
-            - torch.log(torch.tensor(self.scale))
-            - torch.abs(sample[slab_mask]) / self.scale
+            self.slab.log_prob(sample[slab_mask].reshape(-1, 1)).reshape(-1)
+            + torch.log(torch.tensor(1. - self.p_spike))
        )
-        if self.positive_only:
-            log_prob[sample < 0.] = -torch.inf
-        else:
-            log_prob[slab_mask] -= torch.log(torch.tensor(2.))

        log_prob = torch.sum(log_prob, dim=1)  # [N]

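With the slab now injected, the mixture semantics of `log_prob` are worth spelling out: exact zeros get `log(p_spike)`, and every other entry gets the slab log-density plus `log(1 - p_spike)`, replacing the previously inlined Laplace formula. A quick Monte Carlo sanity check of the sampler, using only the constructor and methods shown in this diff (the seed and tolerance are arbitrary choices):

```python
import torch

from sparsecoding.priors.laplace import LaplacePrior
from sparsecoding.priors.spike_slab import SpikeSlabPrior

torch.manual_seed(0)

prior = SpikeSlabPrior(
    dim=8,
    p_spike=0.3,
    slab=LaplacePrior(dim=1, scale=1.0, positive_only=True),
)

weights = prior.sample(5000)  # [5000, 8]
# Exact zeros come only from the spike (the slab is continuous),
# so their frequency should approach p_spike.
empirical_p_spike = torch.mean((weights == 0.).float())
assert abs(empirical_p_spike - 0.3) < 2e-2
```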
8 changes: 6 additions & 2 deletions tests/inference/common.py
@@ -1,6 +1,7 @@
import torch

from sparsecoding.priors.l0 import L0Prior
+from sparsecoding.priors.laplace import LaplacePrior
from sparsecoding.priors.lsm import LSMPrior
from sparsecoding.priors.spike_slab import SpikeSlabPrior
from sparsecoding.data.datasets.bars import BarsDataset
@@ -14,8 +15,11 @@
    SpikeSlabPrior(
        dim=2 * PATCH_SIZE,
        p_spike=0.8,
-        scale=1.0,
-        positive_only=True,
+        slab=LaplacePrior(
+            dim=1,
+            scale=1.0,
+            positive_only=True,
+        ),
    ),
    L0Prior(
        prob_distr=(
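One consequence of the new `slab.D == 1` check is that a multi-dimensional slab is rejected at construction time rather than producing correlated weights. A hypothetical usage sketch, assuming the imports resolve as in this PR:

```python
from sparsecoding.priors.laplace import LaplacePrior
from sparsecoding.priors.spike_slab import SpikeSlabPrior

# A slab with D != 1 is rejected, since slab weights must be drawn i.i.d.
try:
    SpikeSlabPrior(dim=4, p_spike=0.5, slab=LaplacePrior(dim=2, scale=1.0))
except ValueError as err:
    print(err)  # `slab.D` must be 1 (got 2). ...
```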
76 changes: 76 additions & 0 deletions tests/priors/test_laplace.py
@@ -0,0 +1,76 @@
import torch
import unittest

from sparsecoding.priors.laplace import LaplacePrior


class TestLaplacePrior(unittest.TestCase):
    def test_sample(self):
        N = 10000
        D = 4
        scale = 1.

        torch.manual_seed(1997)

        for positive_only in [True, False]:
            laplace_prior = LaplacePrior(
                D,
                scale,
                positive_only,
            )
            weights = laplace_prior.sample(N)

            assert weights.shape == (N, D)

            # Check Laplacian distribution.
            if positive_only:
                assert torch.sum(weights < 0.) == 0
            else:
                assert torch.allclose(
                    torch.sum(weights < 0.) / (N * D),
                    torch.sum(weights > 0.) / (N * D),
                    atol=2e-2,
                )
                weights = torch.abs(weights)

            laplace_weights = weights[weights > 0.]
            for quantile in torch.arange(5) / 5.:
                cutoff = -torch.log(1. - quantile)
                assert torch.allclose(
                    torch.sum(laplace_weights < cutoff) / (N * D),
                    quantile,
                    atol=1e-2,
                )

    def test_log_prob(self):
        D = 3
        scale = 1.

        samples = torch.Tensor([[-1., 0., 1.]])

        pos_only_log_prob = torch.tensor(-2.)

        for positive_only in [True, False]:
            laplace_prior = LaplacePrior(
                D,
                scale,
                positive_only,
            )

            if positive_only:
                assert laplace_prior.log_prob(samples)[0] == -torch.inf

                samples = torch.abs(samples)
                assert torch.allclose(
                    laplace_prior.log_prob(samples)[0],
                    pos_only_log_prob,
                )
            else:
                assert torch.allclose(
                    laplace_prior.log_prob(samples)[0],
                    pos_only_log_prob - D * torch.log(torch.tensor(2.)),
                )


if __name__ == "__main__":
    unittest.main()
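The `cutoff = -torch.log(1. - quantile)` line follows from the fact that a positive-only Laplace with scale b is an Exponential with rate 1/b, so its q-quantile is -b * log(1 - q). A standalone check of that identity (not part of the PR):

```python
import torch

# For positive-only Laplace(0, b), |X| ~ Exponential(rate=1/b):
#   P(|X| < c) = 1 - exp(-c / b),  so the q-quantile is c = -b * log(1 - q).
b = 1.0
for q in torch.arange(5) / 5.:
    cutoff = -b * torch.log(1. - q)
    # The exponential's inverse CDF agrees with the closed form above.
    expected = torch.distributions.Exponential(rate=1. / b).icdf(q)
    assert torch.allclose(cutoff, expected)
```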
107 changes: 48 additions & 59 deletions tests/priors/test_spike_slab.py
@@ -1,6 +1,7 @@
import torch
import unittest

+from sparsecoding.priors.laplace import LaplacePrior
from sparsecoding.priors.spike_slab import SpikeSlabPrior


@@ -9,86 +10,74 @@ def test_sample(self):
        N = 10000
        D = 4
        p_spike = 0.5
-        scale = 1.
+        slab = LaplacePrior(
+            dim=1,
+            scale=1.0,
+            positive_only=True,
+        )

        torch.manual_seed(1997)

        p_slab = 1. - p_spike

-        for positive_only in [True, False]:
-            spike_slab_prior = SpikeSlabPrior(
-                D,
-                p_spike,
-                scale,
-                positive_only,
-            )
-            weights = spike_slab_prior.sample(N)
-
-            assert weights.shape == (N, D)
-
-            # Check spike probability.
-            assert torch.allclose(
-                torch.sum(weights == 0.) / (N * D),
-                torch.tensor(p_spike),
-                atol=1e-2,
-            )
-
-            # Check Laplacian distribution.
-            N_slab = p_slab * N * D
-            if positive_only:
-                assert torch.sum(weights < 0.) == 0
-            else:
-                assert torch.allclose(
-                    torch.sum(weights < 0.) / N_slab,
-                    torch.sum(weights > 0.) / N_slab,
-                    atol=2e-2,
-                )
-                weights = torch.abs(weights)
-
-            laplace_weights = weights[weights > 0.]
-            for quantile in torch.arange(5) / 5.:
-                cutoff = -torch.log(1. - quantile)
-                assert torch.allclose(
-                    torch.sum(laplace_weights < cutoff) / N_slab,
-                    quantile,
-                    atol=1e-2,
-                )
+        spike_slab_prior = SpikeSlabPrior(
+            D,
+            p_spike,
+            slab,
+        )
+        weights = spike_slab_prior.sample(N)
+
+        assert weights.shape == (N, D)
+
+        # Check spike probability.
+        assert torch.allclose(
+            torch.sum(weights == 0.) / (N * D),
+            torch.tensor(p_spike),
+            atol=1e-2,
+        )
+
+        # Check Laplacian distribution.
+        N_slab = p_slab * N * D
+        assert torch.sum(weights < 0.) == 0
+
+        laplace_weights = weights[weights > 0.]
+        for quantile in torch.arange(5) / 5.:
+            cutoff = -torch.log(1. - quantile)
+            assert torch.allclose(
+                torch.sum(laplace_weights < cutoff) / N_slab,
+                quantile,
+                atol=1e-2,
+            )

    def test_log_prob(self):
        D = 3
        p_spike = 0.5
-        scale = 1.
+        slab = LaplacePrior(
+            dim=1,
+            scale=1.0,
+            positive_only=True,
+        )

        samples = torch.Tensor([[-1., 0., 1.]])

-        pos_only_log_prob = (
-            torch.log(torch.tensor(p_spike))
-            + 2 * (
-                -1. + torch.log(torch.tensor(1. - p_spike))
-            )
-        )
-
-        for positive_only in [True, False]:
-            spike_slab_prior = SpikeSlabPrior(
-                D,
-                p_spike,
-                scale,
-                positive_only,
-            )
-
-            if positive_only:
-                assert spike_slab_prior.log_prob(samples)[0] == -torch.inf
-
-                samples = torch.abs(samples)
-                assert torch.allclose(
-                    spike_slab_prior.log_prob(samples)[0],
-                    pos_only_log_prob,
-                )
-            else:
-                assert torch.allclose(
-                    spike_slab_prior.log_prob(samples)[0],
-                    pos_only_log_prob - (D - 1) * torch.log(torch.tensor(2.)),
-                )
+        spike_slab_prior = SpikeSlabPrior(
+            D,
+            p_spike,
+            slab,
+        )
+
+        assert spike_slab_prior.log_prob(samples)[0] == -torch.inf
+
+        samples = torch.abs(samples)
+        assert torch.allclose(
+            spike_slab_prior.log_prob(samples)[0],
+            (
+                torch.log(torch.tensor(p_spike))
+                + 2 * (
+                    -1. + torch.log(torch.tensor(1. - p_spike))
+                )
+            ),
+        )


if __name__ == "__main__":
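For reference, the constant the final `allclose` compares against works out to log(0.5) + 2 * (log(0.5) - 1): the zero entry contributes the spike mass, and each entry equal to 1 contributes the slab probability times the folded-Laplace density e^{-1}. A standalone arithmetic check (not part of the diff):

```python
import torch

p_spike = 0.5
# Folded Laplace(0, 1) log-density at x >= 0 is -x (see laplace.py above).
log_prob_zero = torch.log(torch.tensor(p_spike))           # spike at 0.
log_prob_one = torch.log(torch.tensor(1. - p_spike)) - 1.  # slab at 1.
total = log_prob_zero + 2 * log_prob_one
print(total)  # log(0.5) + 2 * (log(0.5) - 1) ~= -4.0794
```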