@@ -34,7 +34,7 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         )
 
         # Init.
-        mask = torch.zeros(*init.shape).bool()
+        mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
         mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
         init = semiring.fill(init, mask, semiring.one)
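The first hunk only moves the mask allocation onto the same device as the potentials, so the init chart can be filled on GPU without a device-mismatch error. A minimal standalone sketch of the idea (tensor shapes are illustrative, not the real chart shapes):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
log_potentials = torch.randn(2, 8, 4, 5, 5, device=device)  # hypothetical potentials
init = torch.zeros(1, 2, 8, 4, 5, 5, device=device)         # hypothetical chart

# Allocating the boolean mask on the potentials' device keeps every tensor
# that the later fill combines on one device; torch.zeros(...) without the
# device argument would default to CPU and fail against a CUDA chart.
mask = torch.zeros(*init.shape, device=log_potentials.device).bool()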
@@ -61,10 +61,13 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         c[:, :, : K - 1, 0] = semiring.sum(
             torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
         )
-        end = torch.min(lengths) - 1
-        mask = torch.zeros(*init.shape).bool()
+        mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
+        mask_length = torch.arange(bin_N).view(1, bin_N, 1).expand(batch, bin_N, C)
+        mask_length = mask_length.to(log_potentials.device)
         for k in range(1, K - 1):
-            mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
+            mask_length_k = mask_length < (lengths - 1 - (k - 1)).view(batch, 1, 1)
+            mask_length_k = semiring.convert(mask_length_k)
+            mask[:, :, :, k - 1, k].diagonal(0, -2, -1).masked_fill_(mask_length_k, True)
         init = semiring.fill(init, mask, semiring.one)
 
         K_1 = K - 1
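The second hunk replaces the single cutoff end = torch.min(lengths) - 1, which treated every sequence as if it were as short as the shortest one in the batch, with a per-example mask built by broadcasting a position index against lengths. A small sketch of that comparison pattern, with made-up sizes (the names batch, bin_N, C and the length offset mirror the diff; the values are illustrative):

import torch

batch, bin_N, C = 3, 8, 5
lengths = torch.tensor([8, 6, 4])
k = 2

# Position index broadcast against per-example lengths: True exactly at the
# positions that are still valid for that example at span width k.
mask_length = torch.arange(bin_N).view(1, bin_N, 1).expand(batch, bin_N, C)
mask_length_k = mask_length < (lengths - 1 - (k - 1)).view(batch, 1, 1)
print(mask_length_k[:, :, 0])  # each row's True prefix shrinks with its length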
@@ -83,37 +86,37 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         v = semiring.sum(semiring.sum(final[:, :, 0, :, 0, :].contiguous()))
         return v, [log_potentials]
 
-    # def _dp_standard(self, edge, lengths=None, force_grad=False):
-    #     semiring = self.semiring
-    #     ssize = semiring.size()
-    #     edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
-    #     edge.requires_grad_(True)
-
-    #     # Init
-    #     # All paths starting at N of len K
-    #     alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]
-
-    #     # All paths finishing at N with label C
-    #     beta = self._make_chart(N, (batch, C), edge, force_grad)
-    #     semiring.one_(beta[0].data)
-
-    #     # Main.
-    #     for n in range(1, N):
-    #         alpha[:, :, n - 1] = semiring.dot(
-    #             beta[n - 1].view(ssize, batch, 1, 1, C),
-    #             edge[:, :, n - 1].view(ssize, batch, K, C, C),
-    #         )
-
-    #         t = max(n - K, -1)
-    #         f1 = torch.arange(n - 1, t, -1)
-    #         f2 = torch.arange(1, len(f1) + 1)
-    #         beta[n][:] = semiring.sum(
-    #             torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
-    #         )
-    #     v = semiring.sum(
-    #         torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
-    #     )
-    #     return v, [edge], beta
+    def _dp_standard(self, edge, lengths=None, force_grad=False):
+        semiring = self.semiring
+        ssize = semiring.size()
+        edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
+        edge.requires_grad_(True)
+
+        # Init
+        # All paths starting at N of len K
+        alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]
+
+        # All paths finishing at N with label C
+        beta = self._make_chart(N, (batch, C), edge, force_grad)
+        beta[0] = semiring.fill(beta[0], torch.tensor(True).to(edge.device), semiring.one)
+
+        # Main.
+        for n in range(1, N):
+            alpha[:, :, n - 1] = semiring.dot(
+                beta[n - 1].view(ssize, batch, 1, 1, C),
+                edge[:, :, n - 1].view(ssize, batch, K, C, C),
+            )
+
+            t = max(n - K, -1)
+            f1 = torch.arange(n - 1, t, -1)
+            f2 = torch.arange(1, len(f1) + 1)
+            beta[n][:] = semiring.sum(
+                torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
+            )
+        v = semiring.sum(
+            torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
+        )
+        return v, [edge], beta
 
     @staticmethod
     def to_parts(sequence, extra, lengths=None):
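The last hunk turns the commented-out _dp_standard back into live code, giving the module a straightforward sequential recursion alongside the scan-based logpartition. A sketch of how one might use it as a sanity check, assuming the torch_struct layout (SemiMarkov taking a semiring in its constructor, edge potentials shaped roughly batch x N-1 x K x C x C); the import paths and shapes are assumptions, not something this diff shows:

import torch
from torch_struct import LogSemiring
from torch_struct.semimarkov import SemiMarkov

batch, N, K, C = 2, 6, 3, 4                 # illustrative sizes
edge = torch.randn(batch, N - 1, K, C, C)   # assumed potential layout
lengths = torch.tensor([N, N - 1])

struct = SemiMarkov(LogSemiring)
v_scan, _ = struct.logpartition(edge, lengths=lengths)
v_std, _, _ = struct._dp_standard(edge, lengths=lengths)

# Both routines compute the same log-partition, so up to float error the
# scan version should agree with the standard recursion.
print(torch.allclose(v_scan, v_std, atol=1e-4))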