@@ -438,15 +438,25 @@ def manual_grads(params):
438
438
deltas ['bbeta_r' ] += dzbeta_r
439
439
deltas ['bbeta_w' ] += dzbeta_w
440
440
441
- shift_grad_s = jacobian (shift , argnum = 1 )
442
- shift_grad_s_r = np .reshape (shift_grad_s (wg_rs [t ], softmax (zs_rs [t ])),
443
- (self .N , 3 ))
444
- shift_grad_s_w = np .reshape (shift_grad_s (wg_ws [t ], softmax (zs_ws [t ])),
445
- (self .N , 3 ))
446
-
447
- # import pdb; pdb.set_trace()
448
- ds_r = np .dot (shift_grad_s_r .T , dws_r )
449
- ds_w = np .dot (shift_grad_s_w .T , dws_w )
441
+ # shift_grad_s = jacobian(shift, argnum=1)
442
+ # shift_grad_s_r = np.reshape(shift_grad_s(wg_rs[t], softmax(zs_rs[t])),
443
+ # (self.N, 3))
444
+ # shift_grad_s_w = np.reshape(shift_grad_s(wg_ws[t], softmax(zs_ws[t])),
445
+ # (self.N, 3))
446
+
447
+ sgsr = np .zeros ((self .N , 3 ))
448
+ sgsw = np .zeros ((self .N , 3 ))
449
+ for i in range (self .N ):
450
+ sgsr [i ,1 ] = wg_rs [t ][(i - 1 ) % self .N ]
451
+ sgsr [i ,0 ] = wg_rs [t ][i ]
452
+ sgsr [i ,2 ] = wg_rs [t ][(i + 1 ) % self .N ]
453
+
454
+ sgsw [i ,1 ] = wg_ws [t ][(i - 1 ) % self .N ]
455
+ sgsw [i ,0 ] = wg_ws [t ][i ]
456
+ sgsw [i ,2 ] = wg_ws [t ][(i + 1 ) % self .N ]
457
+
458
+ ds_r = np .dot (sgsr .T , dws_r )
459
+ ds_w = np .dot (sgsw .T , dws_w )
450
460
451
461
shift_act_jac_r = np .zeros ((3 ,3 ))
452
462
shift_act_jac_w = np .zeros ((3 ,3 ))
0 commit comments