
Commit 9f93432: support KL and cross-entropy semiring (#79)
1 parent d9157fc

10 files changed: +206, -13 lines

torch_struct/__init__.py (1 addition, 0 deletions)

@@ -72,3 +72,4 @@
     CheckpointShardSemiring,
     TempMax,
 ]
+

torch_struct/cky_crf.py (1 addition, 2 deletions)

@@ -6,8 +6,7 @@
 
 class CKY_CRF(_Struct):
     def _check_potentials(self, edge, lengths=None):
-        batch, N, _, NT = edge.shape
-        edge.requires_grad_(True)
+        batch, N, _, NT = self._get_dimension(edge)
         edge = self.semiring.convert(edge)
         if lengths is None:
             lengths = torch.LongTensor([N] * batch).to(edge.device)

torch_struct/deptree.py (4 additions, 2 deletions)

@@ -43,6 +43,8 @@ class DepTree(_Struct):
     Parameters:
         arc_scores_in: Arc scores of shape (B, N, N) or (B, N, N, L) with root scores on
             diagonal.
+
+    Note: For single-root case, do not set cache=True for now.
     """
 
     def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
@@ -61,7 +63,7 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
         alpha = [
             [
                 [
-                    Chart((batch, N, N), arc_scores, semiring, cache=cache)
+                    Chart((batch, N, N), arc_scores, semiring, cache=multiroot)
                     for _ in range(2)
                 ]
                 for _ in range(2)
@@ -113,7 +115,7 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
 
     def _check_potentials(self, arc_scores, lengths=None):
         semiring = self.semiring
-        batch, N, N2 = arc_scores.shape[:3]
+        batch, N, N2, *_ = self._get_dimension(arc_scores)
         assert N == N2, "Non-square potentials"
         if lengths is None:
             lengths = torch.LongTensor([N - 1] * batch).to(arc_scores.device)
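The `multiroot` flag consulted above is exposed on the corresponding distribution (see the `distributions.py` hunk below). A minimal usage sketch, assuming the public `DependencyCRF` wrapper; the shapes here are purely illustrative:

```python
import torch
from torch_struct import DependencyCRF

# Batch of 2 sentences, N = 5 tokens; root scores sit on the diagonal,
# per the DepTree docstring above.
arc_scores = torch.randn(2, 5, 5)

# Single-root parsing; per the note above, avoid cache=True in this mode.
dist = DependencyCRF(arc_scores, multiroot=False)
print(dist.partition)  # log-partition per batch element
```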

torch_struct/distributions.py (26 additions, 1 deletion)

@@ -11,12 +11,15 @@
     LogSemiring,
     MaxSemiring,
     EntropySemiring,
+    CrossEntropySemiring,
+    KLDivergenceSemiring,
     MultiSampledSemiring,
     KMaxSemiring,
     StdSemiring,
 )
 
 
+
 class StructDistribution(Distribution):
     r"""
     Base structured distribution class.
@@ -65,6 +68,8 @@ def log_prob(self, value):
             value.type_as(self.log_potentials),
             batch_dims=batch_dims,
         )
+
+
         return v - self.partition
 
     @lazy_property
@@ -75,13 +80,32 @@ def entropy(self):
         Returns:
            entropy (*batch_shape*)
         """
+
         return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)
 
+    def cross_entropy(self, other):
+        """
+        Compute the cross-entropy between distributions p (self) and q (other), :math:`H[p, q]`.
+
+        Returns:
+            cross entropy (*batch_shape*)
+        """
+
+        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
+
+    def kl(self, other):
+        """
+        Compute the KL-divergence between distributions p (self) and q (other), :math:`KL[p || q] = H[p, q] - H[p]`.
+
+        Returns:
+            KL divergence (*batch_shape*)
+        """
+        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
+
     @lazy_property
     def max(self):
         r"""
         Compute an max for distribution :math:`\max p(z)`.
-
         Returns:
             max (*batch_shape*)
         """
@@ -355,6 +379,7 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=True):
         setattr(self.struct, "multiroot", multiroot)
 
 
+
 class TreeCRF(StructDistribution):
     r"""
     Represents a 0th-order span parser with NT nonterminals. Implemented using a
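To show the new methods end to end, here is a minimal usage sketch (mirroring the test added below), assuming two `LinearChainCRF` distributions over the same event shape:

```python
import torch
from torch_struct import LinearChainCRF

# Two linear-chain CRFs over the same event shape:
# batch of 2, N = 6 positions, C = 3 classes.
batch, N, C = 2, 6, 3
p = LinearChainCRF(torch.randn(batch, N - 1, C, C))
q = LinearChainCRF(torch.randn(batch, N - 1, C, C))

ce = p.cross_entropy(q)  # H[p, q], shape (batch,)
kl = p.kl(q)             # KL[p || q], shape (batch,)

# The identity from the kl docstring: KL[p || q] = H[p, q] - H[p].
assert torch.allclose(kl, ce - p.entropy, atol=1e-4)
```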

torch_struct/helpers.py (9 additions, 0 deletions)

@@ -79,6 +79,15 @@ def _bin_length(self, length):
         bin_N = int(math.pow(2, log_N))
         return log_N, bin_N
 
+    def _get_dimension(self, edge):
+        if isinstance(edge, list):
+            for t in edge:
+                t.requires_grad_(True)
+            return edge[0].shape
+        else:
+            edge.requires_grad_(True)
+            return edge.shape
+
     def _chart(self, size, potentials, force_grad):
         return self._make_chart(1, size, potentials, force_grad)[0]
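`_get_dimension` generalizes the old `edge.shape` access so that `_check_potentials` (see the call sites above and below) accepts either one potentials tensor or the `[p, q]` pair that the cross-entropy and KL semirings consume; `requires_grad_` is set in both branches because marginals are computed via autograd. A standalone sketch of that contract, with hypothetical shapes:

```python
import torch

# A single tensor or a pair of identically shaped log-potential
# tensors both yield one event shape, as in _get_dimension above.
single = torch.randn(2, 5, 3, 3)
pair = [torch.randn(2, 5, 3, 3), torch.randn(2, 5, 3, 3)]

def event_shape(edge):
    return edge[0].shape if isinstance(edge, list) else edge.shape

assert event_shape(single) == event_shape(pair)
```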

torch_struct/linearchain.py (1 addition, 3 deletions)

@@ -28,10 +28,8 @@ class LinearChain(_Struct):
     """
 
     def _check_potentials(self, edge, lengths=None):
-        batch, N_1, C, C2 = edge.shape
-        edge.requires_grad_(True)
+        batch, N_1, C, C2 = self._get_dimension(edge)
         edge = self.semiring.convert(edge)
-
         N = N_1 + 1
 
         if lengths is None:

torch_struct/semimarkov.py (1 addition, 1 deletion)

@@ -8,7 +8,7 @@ class SemiMarkov(_Struct):
     """
 
     def _check_potentials(self, edge, lengths=None):
-        batch, N_1, K, C, C2 = edge.shape
+        batch, N_1, K, C, C2 = self._get_dimension(edge)
         edge = self.semiring.convert(edge)
         N = N_1 + 1
         if lengths is None:

torch_struct/semirings/__init__.py (4 additions, 0 deletions)

@@ -4,6 +4,8 @@
     KMaxSemiring,
     MaxSemiring,
     EntropySemiring,
+    CrossEntropySemiring,
+    KLDivergenceSemiring,
     TempMax,
 )
 
@@ -29,6 +31,8 @@
     SparseMaxSemiring,
     KMaxSemiring,
     EntropySemiring,
+    CrossEntropySemiring,
+    KLDivergenceSemiring,
     MultiSampledSemiring,
     CheckpointSemiring,
     CheckpointShardSemiring,

torch_struct/semirings/semirings.py (145 additions, 0 deletions)

@@ -269,6 +269,150 @@ def mul(a, b):
     return KMaxSemiring
 
 
+class KLDivergenceSemiring(Semiring):
+    """
+    Implements a KL-divergence expectation semiring.
+
+    Computes both the log-values of two distributions and the running KL divergence between them.
+
+    Based on descriptions in:
+
+    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
+    * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
+    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
+    """
+
+    zero = 0
+
+    @staticmethod
+    def size():
+        return 3
+
+    @staticmethod
+    def convert(xs):
+        values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
+        values[0] = xs[0]
+        values[1] = xs[1]
+        values[2] = 0
+        return values
+
+    @staticmethod
+    def unconvert(xs):
+        return xs[-1]
+
+    @staticmethod
+    def sum(xs, dim=-1):
+        assert dim != 0
+        d = dim - 1 if dim > 0 else dim
+        part_p = torch.logsumexp(xs[0], dim=d)
+        part_q = torch.logsumexp(xs[1], dim=d)
+        log_sm_p = xs[0] - part_p.unsqueeze(d)
+        log_sm_q = xs[1] - part_q.unsqueeze(d)
+        sm_p = log_sm_p.exp()
+        return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d)))
+
+    @staticmethod
+    def mul(a, b):
+        return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))
+
+    @classmethod
+    def prod(cls, xs, dim=-1):
+        return xs.sum(dim)
+
+    @classmethod
+    def zero_mask_(cls, xs, mask):
+        "Fill *ssize x ...* tensor with additive identity."
+        xs[0].masked_fill_(mask, -1e5)
+        xs[1].masked_fill_(mask, -1e5)
+        xs[2].masked_fill_(mask, 0)
+
+    @staticmethod
+    def zero_(xs):
+        xs[0].fill_(-1e5)
+        xs[1].fill_(-1e5)
+        xs[2].fill_(0)
+        return xs
+
+    @staticmethod
+    def one_(xs):
+        xs[0].fill_(0)
+        xs[1].fill_(0)
+        xs[2].fill_(0)
+        return xs
+
+
+class CrossEntropySemiring(Semiring):
+    """
+    Implements a cross-entropy expectation semiring.
+
+    Computes both the log-values of two distributions and the running cross entropy between them.
+
+    Based on descriptions in:
+
+    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
+    * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
+    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
+    """
+
+    zero = 0
+
+    @staticmethod
+    def size():
+        return 3
+
+    @staticmethod
+    def convert(xs):
+        values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
+        values[0] = xs[0]
+        values[1] = xs[1]
+        values[2] = 0
+        return values
+
+    @staticmethod
+    def unconvert(xs):
+        return xs[-1]
+
+    @staticmethod
+    def sum(xs, dim=-1):
+        assert dim != 0
+        d = dim - 1 if dim > 0 else dim
+        part_p = torch.logsumexp(xs[0], dim=d)
+        part_q = torch.logsumexp(xs[1], dim=d)
+        log_sm_p = xs[0] - part_p.unsqueeze(d)
+        log_sm_q = xs[1] - part_q.unsqueeze(d)
+        sm_p = log_sm_p.exp()
+        return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d)))
+
+    @staticmethod
+    def mul(a, b):
+        return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))
+
+    @classmethod
+    def prod(cls, xs, dim=-1):
+        return xs.sum(dim)
+
+    @classmethod
+    def zero_mask_(cls, xs, mask):
+        "Fill *ssize x ...* tensor with additive identity."
+        xs[0].masked_fill_(mask, -1e5)
+        xs[1].masked_fill_(mask, -1e5)
+        xs[2].masked_fill_(mask, 0)
+
+    @staticmethod
+    def zero_(xs):
+        xs[0].fill_(-1e5)
+        xs[1].fill_(-1e5)
+        xs[2].fill_(0)
+        return xs
+
+    @staticmethod
+    def one_(xs):
+        xs[0].fill_(0)
+        xs[1].fill_(0)
+        xs[2].fill_(0)
+        return xs
+
+
 class EntropySemiring(Semiring):
     """
     Implements an entropy expectation semiring.
@@ -279,6 +423,7 @@ class EntropySemiring(Semiring):
 
     * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
     * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
+    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
     """
 
     zero = 0
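For reference, the two `sum` rules above implement the following expectation-semiring updates. Each semiring value is a triple (a, b, r); along the summed dimension, write A = logsumexp(a), B = logsumexp(b), p_i = exp(a_i - A), log q_i = b_i - B, matching `part_p`, `part_q`, `sm_p`, and `log_sm_q` in the code:

```latex
% Cross-entropy semiring: the third slot accumulates -E_p[log q].
r' = \sum_i p_i \left( r_i - \log q_i \right)

% KL-divergence semiring: the extra +log p_i term accumulates E_p[log p - log q].
r' = \sum_i p_i \left( r_i - \log q_i + \log p_i \right)
```

Since `unconvert` returns the third slot, summing over the whole structure yields H[p, q] = -Σ_z p(z) log q(z) and KL[p || q] = Σ_z p(z) (log p(z) - log q(z)), which is exactly what the brute-force checks in the test below verify.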

torch_struct/test_distributions.py (14 additions, 4 deletions)

@@ -21,19 +21,29 @@ def test_simple(data, seed):
     lengths = torch.tensor(
         [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
     )
-
     dist = model(vals, lengths)
     edges, enum_lengths = dist.enumerate_support()
-    print(edges.shape)
     log_probs = dist.log_prob(edges)
     for b in range(lengths.shape[0]):
         log_probs[enum_lengths[b] :, b] = -1e9
-
     assert torch.isclose(log_probs.exp().sum(0), torch.tensor(1.0)).all()
-
     entropy = dist.entropy
     assert torch.isclose(entropy, -log_probs.exp().mul(log_probs).sum(0)).all()
 
+    vals2 = torch.rand(*vals.shape)
+    dist2 = model(vals2, lengths)
+
+    cross_entropy = dist.cross_entropy(other=dist2)
+    kl = dist.kl(other=dist2)
+
+    edges2, enum_lengths2 = dist2.enumerate_support()
+    log_probs2 = dist2.log_prob(edges2)
+    for b in range(lengths.shape[0]):
+        log_probs2[enum_lengths2[b] :, b] = -1e9
+
+    assert torch.isclose(cross_entropy, -log_probs.exp().mul(log_probs2).sum(0)).all()
+    assert torch.isclose(kl, -log_probs.exp().mul(log_probs2 - log_probs).sum(0)).all()
+
     argmax = dist.argmax
     _, max_indices = log_probs.max(0)
3949
