@@ -269,6 +269,150 @@ def mul(a, b):
     return KMaxSemiring


+class KLDivergenceSemiring(Semiring):
+    """
+    Implements a KL-divergence semiring.
+
+    Computes the log-values of two distributions along with the running KL divergence between them.
+
+    Based on descriptions in:
+
+    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
+    * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
+    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
+    """
+    zero = 0
+
+    @staticmethod
+    def size():
+        return 3
+
+    @staticmethod
+    def convert(xs):
+        values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
+        values[0] = xs[0]
+        values[1] = xs[1]
+        values[2] = 0
+        return values
+
+    @staticmethod
+    def unconvert(xs):
+        return xs[-1]
+
+    @staticmethod
+    def sum(xs, dim=-1):
+        assert dim != 0
+        d = dim - 1 if dim > 0 else dim
+        part_p = torch.logsumexp(xs[0], dim=d)
+        part_q = torch.logsumexp(xs[1], dim=d)
+        log_sm_p = xs[0] - part_p.unsqueeze(d)
+        log_sm_q = xs[1] - part_q.unsqueeze(d)
+        sm_p = log_sm_p.exp()
+        return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d)))
+
+    @staticmethod
+    def mul(a, b):
+        return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))
+
+    @classmethod
+    def prod(cls, xs, dim=-1):
+        return xs.sum(dim)
+
+    @classmethod
+    def zero_mask_(cls, xs, mask):
+        "Fill *ssize x ...* tensor with additive identity."
+        xs[0].masked_fill_(mask, -1e5)
+        xs[1].masked_fill_(mask, -1e5)
+        xs[2].masked_fill_(mask, 0)
+
+    @staticmethod
+    def zero_(xs):
+        xs[0].fill_(-1e5)
+        xs[1].fill_(-1e5)
+        xs[2].fill_(0)
+        return xs
+
+    @staticmethod
+    def one_(xs):
+        xs[0].fill_(0)
+        xs[1].fill_(0)
+        xs[2].fill_(0)
+        return xs
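
# Illustrative sketch (not part of the commit above): a quick numerical check of
# the running KL term, assuming the KLDivergenceSemiring class defined above is
# in scope. convert() stacks (log_p, log_q, zero accumulator), sum() normalizes
# each distribution over the reduced dimension and accumulates
# sum_i p_i (log p_i - log q_i), and unconvert() extracts that accumulated value.
import torch

log_p = torch.randn(4, 5)  # unnormalized log-weights of distribution p
log_q = torch.randn(4, 5)  # unnormalized log-weights of distribution q

xs = KLDivergenceSemiring.convert((log_p, log_q))                          # (3, 4, 5)
kl = KLDivergenceSemiring.unconvert(KLDivergenceSemiring.sum(xs, dim=-1))  # (4,)

# Reference value from torch.distributions for comparison.
ref = torch.distributions.kl_divergence(
    torch.distributions.Categorical(logits=log_p),
    torch.distributions.Categorical(logits=log_q),
)
assert torch.allclose(kl, ref, atol=1e-4)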
+
+class CrossEntropySemiring(Semiring):
+    """
+    Implements a cross-entropy expectation semiring.
+
+    Computes the log-values of two distributions along with the running cross-entropy between them.
+
+    Based on descriptions in:
+
+    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
+    * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
+    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
+    """
+
+    zero = 0
+
+    @staticmethod
+    def size():
+        return 3
+
+    @staticmethod
+    def convert(xs):
+        values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
+        values[0] = xs[0]
+        values[1] = xs[1]
+        values[2] = 0
+        return values
+
+    @staticmethod
+    def unconvert(xs):
+        return xs[-1]
+
+    @staticmethod
+    def sum(xs, dim=-1):
+        assert dim != 0
+        d = dim - 1 if dim > 0 else dim
+        part_p = torch.logsumexp(xs[0], dim=d)
+        part_q = torch.logsumexp(xs[1], dim=d)
+        log_sm_p = xs[0] - part_p.unsqueeze(d)
+        log_sm_q = xs[1] - part_q.unsqueeze(d)
+        sm_p = log_sm_p.exp()
+        return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d)))
+
+    @staticmethod
+    def mul(a, b):
+        return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))
+
+    @classmethod
+    def prod(cls, xs, dim=-1):
+        return xs.sum(dim)
+
+    @classmethod
+    def zero_mask_(cls, xs, mask):
+        "Fill *ssize x ...* tensor with additive identity."
+        xs[0].masked_fill_(mask, -1e5)
+        xs[1].masked_fill_(mask, -1e5)
+        xs[2].masked_fill_(mask, 0)
+
+    @staticmethod
+    def zero_(xs):
+        xs[0].fill_(-1e5)
+        xs[1].fill_(-1e5)
+        xs[2].fill_(0)
+        return xs
+
+    @staticmethod
+    def one_(xs):
+        xs[0].fill_(0)
+        xs[1].fill_(0)
+        xs[2].fill_(0)
+        return xs
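
# Illustrative sketch (not part of the commit above): the same style of check for
# CrossEntropySemiring, assuming the class defined above is in scope. With a zero
# accumulator, sum() reduces to H(p, q) = -sum_i p_i log q_i over the reduced
# dimension, which is compared against a direct computation.
import torch

log_p = torch.randn(4, 5)  # unnormalized log-weights of distribution p
log_q = torch.randn(4, 5)  # unnormalized log-weights of distribution q

xs = CrossEntropySemiring.convert((log_p, log_q))
ce = CrossEntropySemiring.unconvert(CrossEntropySemiring.sum(xs, dim=-1))

# Direct cross-entropy between the normalized distributions.
p = log_p.log_softmax(-1).exp()
ref = -(p * log_q.log_softmax(-1)).sum(-1)
assert torch.allclose(ce, ref, atol=1e-4)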
+
+
+
+
+
 class EntropySemiring(Semiring):
     """
     Implements an entropy expectation semiring.
@@ -279,6 +423,7 @@ class EntropySemiring(Semiring):

     * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
     * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
+    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
     """

     zero = 0