Skip to content

Commit e7d7cae

Browse files
committed
replace shift jacobian.
1 parent dd34205 commit e7d7cae

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

ntm/ntm.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -438,15 +438,25 @@ def manual_grads(params):
438438
deltas['bbeta_r'] += dzbeta_r
439439
deltas['bbeta_w'] += dzbeta_w
440440

441-
shift_grad_s = jacobian(shift, argnum=1)
442-
shift_grad_s_r = np.reshape(shift_grad_s(wg_rs[t], softmax(zs_rs[t])),
443-
(self.N, 3))
444-
shift_grad_s_w = np.reshape(shift_grad_s(wg_ws[t], softmax(zs_ws[t])),
445-
(self.N, 3))
446-
447-
# import pdb; pdb.set_trace()
448-
ds_r = np.dot(shift_grad_s_r.T, dws_r)
449-
ds_w = np.dot(shift_grad_s_w.T, dws_w)
441+
# shift_grad_s = jacobian(shift, argnum=1)
442+
# shift_grad_s_r = np.reshape(shift_grad_s(wg_rs[t], softmax(zs_rs[t])),
443+
# (self.N, 3))
444+
# shift_grad_s_w = np.reshape(shift_grad_s(wg_ws[t], softmax(zs_ws[t])),
445+
# (self.N, 3))
446+
447+
sgsr = np.zeros((self.N, 3))
448+
sgsw = np.zeros((self.N, 3))
449+
for i in range(self.N):
450+
sgsr[i,1] = wg_rs[t][(i - 1) % self.N]
451+
sgsr[i,0] = wg_rs[t][i]
452+
sgsr[i,2] = wg_rs[t][(i + 1) % self.N]
453+
454+
sgsw[i,1] = wg_ws[t][(i - 1) % self.N]
455+
sgsw[i,0] = wg_ws[t][i]
456+
sgsw[i,2] = wg_ws[t][(i + 1) % self.N]
457+
458+
ds_r = np.dot(sgsr.T, dws_r)
459+
ds_w = np.dot(sgsw.T, dws_w)
450460

451461
shift_act_jac_r = np.zeros((3,3))
452462
shift_act_jac_w = np.zeros((3,3))

0 commit comments

Comments
 (0)