1 change: 1 addition & 0 deletions torch_struct/__init__.py
@@ -72,3 +72,4 @@
CheckpointShardSemiring,
TempMax,
]

3 changes: 1 addition & 2 deletions torch_struct/cky_crf.py
@@ -6,8 +6,7 @@

class CKY_CRF(_Struct):
def _check_potentials(self, edge, lengths=None):
batch, N, _, NT = edge.shape
edge.requires_grad_(True)
batch, N, _, NT = self._get_dimension(edge)
edge = self.semiring.convert(edge)
if lengths is None:
lengths = torch.LongTensor([N] * batch).to(edge.device)
6 changes: 4 additions & 2 deletions torch_struct/deptree.py
@@ -43,6 +43,8 @@ class DepTree(_Struct):
Parameters:
arc_scores_in: Arc scores of shape (B, N, N) or (B, N, N, L) with root scores on
diagonal.

Note: for the single-root case, do not set cache=True for now.
"""

def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
@@ -61,7 +63,7 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
alpha = [
[
[
Chart((batch, N, N), arc_scores, semiring, cache=cache)
Chart((batch, N, N), arc_scores, semiring, cache=multiroot)
for _ in range(2)
]
for _ in range(2)
@@ -113,7 +115,7 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):

def _check_potentials(self, arc_scores, lengths=None):
semiring = self.semiring
batch, N, N2 = arc_scores.shape[:3]
batch, N, N2, *_ = self._get_dimension(arc_scores)
assert N == N2, "Non-square potentials"
if lengths is None:
lengths = torch.LongTensor([N - 1] * batch).to(arc_scores.device)
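The change above ties Chart caching to the multiroot flag rather than the cache argument, so single-root runs skip caching. A minimal sketch of how this surfaces at the API level, assuming the DependencyCRF wrapper shown later in this diff (scores and shapes are illustrative):

import torch
from torch_struct import DependencyCRF

arc_scores = torch.randn(2, 6, 6)  # (B, N, N), root scores on the diagonal
dist = DependencyCRF(arc_scores, multiroot=False)  # single-root: charts stay uncached
print(dist.partition)  # log-partition over single-root projective trees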
27 changes: 26 additions & 1 deletion torch_struct/distributions.py
@@ -11,12 +11,15 @@
LogSemiring,
MaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
MultiSampledSemiring,
KMaxSemiring,
StdSemiring,
)



class StructDistribution(Distribution):
r"""
Base structured distribution class.
@@ -65,6 +68,8 @@ def log_prob(self, value):
value.type_as(self.log_potentials),
batch_dims=batch_dims,
)


return v - self.partition

@lazy_property
@@ -75,13 +80,32 @@ def entropy(self):
Returns:
entropy (*batch_shape*)
"""

return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)

def cross_entropy(self, other):
"""
Compute the cross-entropy between distributions p (self) and q (other), :math:`H[p, q]`.

Returns:
cross entropy (*batch_shape*)
"""

return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

def kl(self, other):
"""
Compute the KL-divergence between distributions p (self) and q (other), :math:`KL[p || q] = H[p, q] - H[p]`.

Returns:
kl divergence (*batch_shape*)
"""
return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

@lazy_property
def max(self):
r"""
Compute the max of the distribution, :math:`\max p(z)`.

Returns:
max (*batch_shape*)
"""
@@ -355,6 +379,7 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=True):
setattr(self.struct, "multiroot", multiroot)



class TreeCRF(StructDistribution):
r"""
Represents a 0th-order span parser with NT nonterminals. Implemented using a
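Taken together, the additions above give every StructDistribution exact pairwise information-theoretic quantities via dynamic programming. A minimal usage sketch, assuming two linear-chain CRFs over the same lengths (names and shapes here are illustrative, not part of this diff):

import torch
from torch_struct import LinearChainCRF

batch, N, C = 2, 5, 3
p = LinearChainCRF(torch.randn(batch, N - 1, C, C))
q = LinearChainCRF(torch.randn(batch, N - 1, C, C))

ce = p.cross_entropy(q)  # H[p, q], shape (batch,)
kl = p.kl(q)             # KL[p || q], shape (batch,)

# The identity KL[p || q] = H[p, q] - H[p] should hold up to numerics.
assert torch.allclose(kl, ce - p.entropy, atol=1e-4)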
9 changes: 9 additions & 0 deletions torch_struct/helpers.py
@@ -79,6 +79,15 @@ def _bin_length(self, length):
bin_N = int(math.pow(2, log_N))
return log_N, bin_N

def _get_dimension(self, edge):
if isinstance(edge, list):
for t in edge:
t.requires_grad_(True)
return edge[0].shape
else:
edge.requires_grad_(True)
return edge.shape

def _chart(self, size, potentials, force_grad):
return self._make_chart(1, size, potentials, force_grad)[0]

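_get_dimension centralizes what each struct's _check_potentials previously did inline: mark the potentials as requiring grad and read off their shape, now also accepting a list so the pairwise semirings can pass [p, q]. A small illustration, assuming the exported LinearChain struct (tensors are hypothetical):

import torch
from torch_struct import LinearChain

struct = LinearChain()
edge = torch.randn(2, 5, 3, 3)          # (batch, N-1, C, C)
print(struct._get_dimension(edge))      # torch.Size([2, 5, 3, 3])

pair = [edge, torch.randn(2, 5, 3, 3)]  # as passed by cross_entropy/kl
print(struct._get_dimension(pair))      # shape of the first element;
                                        # both tensors now require grad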
4 changes: 1 addition & 3 deletions torch_struct/linearchain.py
@@ -28,10 +28,8 @@ class LinearChain(_Struct):
"""

def _check_potentials(self, edge, lengths=None):
batch, N_1, C, C2 = edge.shape
edge.requires_grad_(True)
batch, N_1, C, C2 = self._get_dimension(edge)
edge = self.semiring.convert(edge)

N = N_1 + 1

if lengths is None:
2 changes: 1 addition & 1 deletion torch_struct/semimarkov.py
@@ -8,7 +8,7 @@ class SemiMarkov(_Struct):
"""

def _check_potentials(self, edge, lengths=None):
batch, N_1, K, C, C2 = edge.shape
batch, N_1, K, C, C2 = self._get_dimension(edge)
edge = self.semiring.convert(edge)
N = N_1 + 1
if lengths is None:
4 changes: 4 additions & 0 deletions torch_struct/semirings/__init__.py
@@ -4,6 +4,8 @@
KMaxSemiring,
MaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
TempMax,
)

@@ -29,6 +31,8 @@
SparseMaxSemiring,
KMaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
MultiSampledSemiring,
CheckpointSemiring,
CheckpointShardSemiring,
145 changes: 145 additions & 0 deletions torch_struct/semirings/semirings.py
@@ -269,6 +269,150 @@ def mul(a, b):
return KMaxSemiring


class KLDivergenceSemiring(Semiring):
"""
Implements a KL-divergence semiring.

Computes both the log-values of two distributions and the running KL divergence between them.

Based on descriptions in:

* Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
* First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
* Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
"""
zero = 0

@staticmethod
def size():
return 3

@staticmethod
def convert(xs):
values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
values[0] = xs[0]
values[1] = xs[1]
values[2] = 0
return values

@staticmethod
def unconvert(xs):
return xs[-1]

@staticmethod
def sum(xs, dim=-1):
assert dim != 0
d = dim - 1 if dim > 0 else dim
part_p = torch.logsumexp(xs[0], dim=d)
part_q = torch.logsumexp(xs[1], dim=d)
log_sm_p = xs[0] - part_p.unsqueeze(d)
log_sm_q = xs[1] - part_q.unsqueeze(d)
sm_p = log_sm_p.exp()
return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d)))

@staticmethod
def mul(a, b):
return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))

@classmethod
def prod(cls, xs, dim=-1):
return xs.sum(dim)

@classmethod
def zero_mask_(cls, xs, mask):
"Fill *ssize x ...* tensor with additive identity."
xs[0].masked_fill_(mask, -1e5)
xs[1].masked_fill_(mask, -1e5)
xs[2].masked_fill_(mask, 0)

@staticmethod
def zero_(xs):
xs[0].fill_(-1e5)
xs[1].fill_(-1e5)
xs[2].fill_(0)
return xs

@staticmethod
def one_(xs):
xs[0].fill_(0)
xs[1].fill_(0)
xs[2].fill_(0)
return xs

class CrossEntropySemiring(Semiring):
"""
Implements a cross-entropy expectation semiring.

Computes both the log-values of two distributions and the running cross-entropy between them.

Based on descriptions in:

* Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
* First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
* Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
"""

zero = 0

@staticmethod
def size():
return 3

@staticmethod
def convert(xs):
values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
values[0] = xs[0]
values[1] = xs[1]
values[2] = 0
return values

@staticmethod
def unconvert(xs):
return xs[-1]

@staticmethod
def sum(xs, dim=-1):
assert dim != 0
d = dim - 1 if dim > 0 else dim
part_p = torch.logsumexp(xs[0], dim=d)
part_q = torch.logsumexp(xs[1], dim=d)
log_sm_p = xs[0] - part_p.unsqueeze(d)
log_sm_q = xs[1] - part_q.unsqueeze(d)
sm_p = log_sm_p.exp()
return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d)))

@staticmethod
def mul(a, b):
return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))

@classmethod
def prod(cls, xs, dim=-1):
return xs.sum(dim)

@classmethod
def zero_mask_(cls, xs, mask):
"Fill *ssize x ...* tensor with additive identity."
xs[0].masked_fill_(mask, -1e5)
xs[1].masked_fill_(mask, -1e5)
xs[2].masked_fill_(mask, 0)

@staticmethod
def zero_(xs):
xs[0].fill_(-1e5)
xs[1].fill_(-1e5)
xs[2].fill_(0)
return xs

@staticmethod
def one_(xs):
xs[0].fill_(0)
xs[1].fill_(0)
xs[2].fill_(0)
return xs


class EntropySemiring(Semiring):
"""
Implements an entropy expectation semiring.
@@ -279,6 +423,7 @@ class EntropySemiring(Semiring):

* Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
* First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
* Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
"""

zero = 0
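Both new semirings carry a triple (log-score under p, log-score under q, running statistic) through the dynamic program; sum renormalizes locally and accumulates -sum p log q for cross-entropy, plus an extra sum p log p term for KL. A toy check on a single categorical, not taken from this diff, shows the bookkeeping:

import torch
from torch_struct.semirings import CrossEntropySemiring, KLDivergenceSemiring

lp, lq = torch.randn(4), torch.randn(4)      # unnormalized log-scores
xs = KLDivergenceSemiring.convert([lp, lq])  # (3, 4): [lp, lq, zeros]
part_p, part_q, kl = KLDivergenceSemiring.sum(xs)

p, q = torch.softmax(lp, 0), torch.softmax(lq, 0)
assert torch.allclose(kl, (p * (p.log() - q.log())).sum(), atol=1e-5)

ce = CrossEntropySemiring.sum(CrossEntropySemiring.convert([lp, lq]))[2]
assert torch.allclose(ce, -(p * q.log()).sum(), atol=1e-5)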
18 changes: 14 additions & 4 deletions torch_struct/test_distributions.py
@@ -21,19 +21,29 @@ def test_simple(data, seed):
lengths = torch.tensor(
[data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
)

dist = model(vals, lengths)
edges, enum_lengths = dist.enumerate_support()
print(edges.shape)
log_probs = dist.log_prob(edges)
for b in range(lengths.shape[0]):
log_probs[enum_lengths[b] :, b] = -1e9

assert torch.isclose(log_probs.exp().sum(0), torch.tensor(1.0)).all()

entropy = dist.entropy
assert torch.isclose(entropy, -log_probs.exp().mul(log_probs).sum(0)).all()

vals2 = torch.rand(*vals.shape)
dist2 = model(vals2, lengths)

cross_entropy = dist.cross_entropy(other=dist2)
kl = dist.kl(other=dist2)

edges2, enum_lengths2 = dist2.enumerate_support()
log_probs2 = dist2.log_prob(edges2)
for b in range(lengths.shape[0]):
log_probs2[enum_lengths2[b] :, b] = -1e9

assert torch.isclose(cross_entropy, -log_probs.exp().mul(log_probs2).sum(0)).all()
assert torch.isclose(kl, -log_probs.exp().mul(log_probs2 - log_probs).sum(0)).all()

argmax = dist.argmax
_, max_indices = log_probs.max(0)
