@@ -170,17 +170,38 @@ def from_parts(arcs):
170170 return labels , None
171171
172172
def deptree_part(arc_scores, multi_root=False, lengths=None, eps=1e-5):
    """Compute the log-partition of a non-projective dependency tree CRF
    via the matrix-tree theorem.

    Parameters:
        arc_scores : b x N x N arc scores; ``arc_scores[b, h, m]`` scores the
            arc ``h -> m`` and the diagonal ``arc_scores[b, i, i]`` scores
            word ``i`` attaching to the root.
        multi_root : if True, allow several root attachments per sentence
            (scores added on the Laplacian diagonal); if False, enforce
            exactly one root (Koo et al. 2007 first-row replacement).
            Defaults to False for backward compatibility.
        lengths : optional length-b sequence (list or tensor) of true
            sentence lengths; padded positions are masked out. None means
            all N positions are real.
        eps : small constant added to the exponentiated scores so the
            Laplacian determinant stays well conditioned.

    Returns:
        b-vector of log-partition values (log Z).
    """
    if lengths is not None:
        batch, N, _ = arc_scores.shape
        if not torch.is_tensor(lengths):
            lengths = torch.tensor(lengths, device=arc_scores.device)
        # valid[b, i] is True for real tokens, False for padding.
        valid = (
            torch.arange(N, device=arc_scores.device).expand(batch, N)
            < lengths.unsqueeze(1)
        )
        # Identity entries on padded diagonal cells keep the Laplacian
        # non-singular after the padded rows/columns are zeroed out.
        det_offset = torch.diag_embed((~valid).float())
        # An arc survives only when both endpoints are real tokens.
        pair_valid = valid.unsqueeze(2) & valid.unsqueeze(1)
        arc_scores = arc_scores.masked_fill(~pair_valid, float("-inf"))

    scores = arc_scores
    eye = torch.eye(scores.shape[1], device=scores.device)
    weights = scores.exp() + eps
    off_diag = weights.masked_fill(eye != 0, 0)
    # Graph Laplacian: column sums on the diagonal, negated weights elsewhere.
    lap = -off_diag + torch.diag_embed(off_diag.sum(1), offset=0, dim1=-2, dim2=-1)
    if lengths is not None:
        lap = lap + det_offset

    root_scores = torch.diagonal(scores, 0, -2, -1).exp()
    if multi_root:
        # Multi-root MTT variant: root-selection scores join the diagonal.
        lap = lap + torch.diag_embed(root_scores, offset=0, dim1=-2, dim2=-1)
    else:
        # Single-root variant: replace the first row with root scores.
        lap[:, 0] = root_scores
    return lap.logdet()
def deptree_nonproj(arc_scores, multi_root=False, lengths=None, eps=1e-5):
    """
    Compute the marginals of a non-projective dependency tree using the
    matrix-tree theorem.

    Parameters:
        arc_scores : b x N x N arc scores; ``arc_scores[b, h, m]`` scores the
            arc ``h -> m`` and the diagonal ``arc_scores[b, i, i]`` scores
            word ``i`` attaching to the root.
        multi_root : if True, allow several root attachments per sentence;
            if False, enforce exactly one root. Defaults to False for
            backward compatibility.
        lengths : optional length-b sequence (list or tensor) of true
            sentence lengths; padded positions are masked out. None means
            all N positions are real.
        eps : small constant added to the exponentiated scores so the
            Laplacian stays invertible.

    Returns:
        arc_marginals : b x N x N; marginal of arc ``h -> m`` at
        ``[b, h, m]``, with root marginals on the diagonal. Each column
        (incoming mass of a word plus its root probability) sums to one.
    """
    if lengths is not None:
        batch, N, _ = arc_scores.shape
        if not torch.is_tensor(lengths):
            lengths = torch.tensor(lengths, device=arc_scores.device)
        # valid[b, i] is True for real tokens, False for padding.
        valid = (
            torch.arange(N, device=arc_scores.device).expand(batch, N)
            < lengths.unsqueeze(1)
        )
        # Identity entries on padded diagonal cells keep the Laplacian
        # non-singular after the padded rows/columns are zeroed out.
        det_offset = torch.diag_embed((~valid).float())
        # An arc survives only when both endpoints are real tokens.
        pair_valid = valid.unsqueeze(2) & valid.unsqueeze(1)
        arc_scores = arc_scores.masked_fill(~pair_valid, float("-inf"))

    scores = arc_scores
    eye = torch.eye(scores.shape[1], device=scores.device)
    weights = scores.exp() + eps
    off_diag = weights.masked_fill(eye != 0, 0)
    # Graph Laplacian: column sums on the diagonal, negated weights elsewhere.
    lap = -off_diag + torch.diag_embed(off_diag.sum(1), offset=0, dim1=-2, dim2=-1)
    if lengths is not None:
        lap = lap + det_offset

    root_scores = torch.diagonal(scores, 0, -2, -1).exp()
    if multi_root:
        # Multi-root MTT variant: root-selection scores join the diagonal.
        lap = lap + torch.diag_embed(root_scores, offset=0, dim1=-2, dim2=-1)
    else:
        # Single-root variant: replace the first row with root scores.
        lap[:, 0] = root_scores

    # Marginals via the inverse Laplacian (shared by both variants).
    inv_laplacian = lap.inverse()
    factor = (
        torch.diagonal(inv_laplacian, 0, -2, -1)
        .unsqueeze(2)
        .expand_as(scores)
        .transpose(1, 2)
    )
    term1 = scores.exp().mul(factor).clone()
    term2 = scores.exp().mul(inv_laplacian.transpose(1, 2)).clone()
    if multi_root:
        # diag(inv^T) == diag(inv): root marginal of each word.
        roots_output = root_scores.mul(torch.diagonal(inv_laplacian, 0, -2, -1))
    else:
        # Column/row 0 carried the root trick, not real arcs — zero them out.
        term1[:, :, 0] = 0
        term2[:, 0] = 0
        roots_output = root_scores.mul(inv_laplacian.transpose(1, 2)[:, 0])
    output = term1 - term2
    output = output + torch.diag_embed(roots_output, 0, -2, -1)
    return output
0 commit comments