@@ -53,6 +53,106 @@ def sum(xs, dim=-1):
         return _SampledLogSumExp.apply(xs, dim)
 
 
+def GumbelMaxSemiring(temp):
+    # Perturb-and-MAP semiring: forward is a standard logsumexp, while the
+    # backward pass returns one-hot Gumbel-max samples over `dim`.
+    class _GumbelMaxLogSumExp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, dim):
+            ctx.save_for_backward(input, torch.tensor(dim))
+            return torch.logsumexp(input, dim=dim)
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, dim = ctx.saved_tensors
+            grad_input = None
+            if ctx.needs_input_grad[0]:
+
+                def sample(ls):
+                    # Add Gumbel(0, 1) noise along the last dimension and take
+                    # the argmax, returning a one-hot sample.
+                    pre_shape = ls.shape
+                    update = (
+                        ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
+                    ) / temp
+                    out = torch.nn.functional.one_hot(update.max(-1)[1], pre_shape[-1])
+                    return out
+
+                if dim == -1:
+                    s = sample(logits)
+                else:
+                    # Move `dim` to the last position, sample, then undo the permutation.
+                    dim = dim if dim >= 0 else logits.dim() + dim
+                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
+                    rev_perm = [
+                        a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
+                    ]
+                    s = sample(logits.permute(perm)).permute(rev_perm)
+
+                grad_input = grad_output.unsqueeze(dim).mul(s)
+            return grad_input, None
+
+    class _GumbelMaxSemiring(_BaseLog):
+        @staticmethod
+        def sum(xs, dim=-1):
+            return _GumbelMaxLogSumExp.apply(xs, dim)
+
+    return _GumbelMaxSemiring
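+
+# Usage sketch (illustrative only; the variable names below are hypothetical):
+# backpropagating through GumbelMaxSemiring(temp).sum places a one-hot
+# Gumbel-max sample over `dim` into the gradient of the logits.
+#
+# >>> logits = torch.randn(3, 5, requires_grad=True)
+# >>> GumbelMaxSemiring(temp=1.0).sum(logits, dim=-1).sum().backward()
+# >>> logits.grad  # each row is a one-hot Gumbel-max sample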
+
+
+def GumbelCRFSemiring(temp):
+    # Relaxed (Gumbel-CRF) variant: samples are produced with a straight-through
+    # estimator, so they remain differentiable w.r.t. the logits when the
+    # backward graph is retained.
+    class ST(torch.autograd.Function):
+        # Straight-through one-hot: forward emits a hard argmax over the last
+        # dimension; backward propagates gradients through the softmax of the
+        # logits rather than the non-differentiable argmax.
+        @staticmethod
+        def forward(ctx, logits, dim):
+            # Here `dim` is the number of classes for one_hot, i.e. logits.shape[-1].
+            out = torch.nn.functional.one_hot(logits.max(-1)[1], dim)
+            out = out.type_as(logits)
+            ctx.save_for_backward(logits, out)
+            return out
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, out = ctx.saved_tensors
+            with torch.enable_grad():
+                ret = torch.autograd.grad(
+                    logits.softmax(-1), logits, out * grad_output
+                )[0]
+            return ret, None
+
+    class _GumbelCRFLogSumExp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, dim):
+            ctx.save_for_backward(input, torch.tensor(dim))
+            return torch.logsumexp(input, dim=dim)
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, dim = ctx.saved_tensors
+            grad_input = None
+            if ctx.needs_input_grad[0]:
+
+                def sample(ls):
+                    # Gumbel-perturbed logits, passed through the straight-through
+                    # one-hot above.
+                    update = (
+                        ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
+                    ) / temp
+                    out = ST.apply(update, ls.shape[-1])
+                    return out
+
+                if dim == -1:
+                    s = sample(logits)
+                else:
+                    # Move `dim` to the last position, sample, then undo the permutation.
+                    dim = dim if dim >= 0 else logits.dim() + dim
+                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
+                    rev_perm = [
+                        a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
+                    ]
+                    s = sample(logits.permute(perm)).permute(rev_perm)
+
+                grad_input = grad_output.unsqueeze(dim).mul(s)
+            return grad_input, None
+
+    class _GumbelCRFSemiring(_BaseLog):
+        @staticmethod
+        def sum(xs, dim=-1):
+            return _GumbelCRFLogSumExp.apply(xs, dim)
+
+    return _GumbelCRFSemiring
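+
+# Usage sketch (illustrative only; the variable names below are hypothetical):
+# as above, the sample appears as the gradient of logsumexp, and with
+# create_graph=True it can itself be differentiated, since ST routes gradients
+# through the softmax of the perturbed logits.
+#
+# >>> logits = torch.randn(3, 5, requires_grad=True)
+# >>> lse = GumbelCRFSemiring(temp=0.5).sum(logits, dim=-1)
+# >>> (sample,) = torch.autograd.grad(lse.sum(), logits, create_graph=True)
+# >>> sample.shape  # (3, 5): one-hot samples that stay connected to the graph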
+
+
 bits = torch.tensor([pow(2, i) for i in range(1, 18)])
 
 