@@ -3,8 +3,8 @@
 """
 import math
 import autograd.numpy as np
-from autograd import grad
-from util.util import rando, sigmoid, softmax, softplus, unwrap, sigmoid_prime, tanh_prime, compare_deltas, dKdu, softmax_grads
+from autograd import grad, jacobian
+from util.util import rando, sigmoid, softmax, softplus, unwrap, sigmoid_prime, tanh_prime, compare_deltas, dKdu, softmax_grads, beta_grads, K_focus
 import memory
 import addressing
 from addressing import cosine_sim
@@ -107,10 +107,11 @@ def l():
        return {}

    rs = l()
-   zk_rs = l()
+   zk_rs = l()  # TODO: why only z for ks?
    k_rs, beta_rs, g_rs, s_rs, gamma_rs = l(), l(), l(), l(), l()
    k_ws, beta_ws, g_ws, s_ws, gamma_ws = l(), l(), l(), l(), l()
    adds, erases = l(), l()
+   zbeta_rs, zbeta_ws = l(), l()
    w_ws, w_rs = l(), l()  # read weights and write weights
    wc_ws, wc_rs = l(), l()  # read and write content weights
    rs[-1] = self.W['rsInit']  # stores values read from memory
@@ -135,17 +136,17 @@ def l():
         # parameters to the read head
         zk_rs[t] = np.dot(W['ok_r'], os[t]) + W['bk_r']
         k_rs[t] = np.tanh(zk_rs[t])
-        beta_rs[t] = softplus(np.dot(W['obeta_r'], os[t])
-                              + W['bbeta_r'])
+        zbeta_rs[t] = np.dot(W['obeta_r'], os[t]) + W['bbeta_r']
+        beta_rs[t] = softplus(zbeta_rs[t])
         g_rs[t] = sigmoid(np.dot(W['og_r'], os[t]) + W['bg_r'])
         s_rs[t] = softmax(np.dot(W['os_r'], os[t]) + W['bs_r'])
         gamma_rs[t] = 1 + sigmoid(np.dot(W['ogamma_r'], os[t])
                                   + W['bgamma_r'])

         # parameters to the write head
         k_ws[t] = np.tanh(np.dot(W['ok_w'], os[t]) + W['bk_w'])
-        beta_ws[t] = softplus(np.dot(W['obeta_w'], os[t])
-                              + W['bbeta_w'])
+        zbeta_ws[t] = np.dot(W['obeta_w'], os[t]) + W['bbeta_w']
+        beta_ws[t] = softplus(zbeta_ws[t])
         g_ws[t] = sigmoid(np.dot(W['og_w'], os[t]) + W['bg_w'])
         s_ws[t] = softmax(np.dot(W['os_w'], os[t]) + W['bs_w'])
         gamma_ws[t] = 1 + sigmoid(np.dot(W['ogamma_w'], os[t])
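
Both heads now compute beta in two steps so the pre-activation zbeta is on hand for manual_grads below; beta itself is the NTM key-strength scalar that sharpens the content weighting w_c = softmax(beta * K) over cosine similarities K. A minimal sketch of that focusing step, assuming 1-D vectors (illustrative only; the repo's own helpers are cosine_sim and the newly imported K_focus, whose exact signatures this diff does not show):

    import autograd.numpy as np

    def content_focus(M, k, beta):
        # K[i]: cosine similarity between the key k and memory row M[i]
        K = np.array([np.dot(M[i], k) / (np.linalg.norm(M[i]) * np.linalg.norm(k))
                      for i in range(M.shape[0])])
        # beta = softplus(zbeta) > 0 sharpens (large beta) or flattens the softmax
        e = np.exp(beta * K)
        return e / np.sum(e)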
@@ -191,7 +192,8 @@ def l():
         mems[t] = memory.write(mems[t-1], w_ws[t], erases[t], adds[t])

     self.stats = [loss, mems, ps, ys, os, zos, hs, zhs, xs, rs, w_rs,
-                  w_ws, adds, erases, k_rs, k_ws, g_rs, g_ws, wc_rs, wc_ws]
+                  w_ws, adds, erases, k_rs, k_ws, g_rs, g_ws, wc_rs, wc_ws,
+                  zbeta_rs, zbeta_ws]
     return np.sum(loss)

 def manual_grads(params):
@@ -204,7 +206,8 @@ def manual_grads(params):
         deltas[key] = np.zeros_like(val)

     [loss, mems, ps, ys, os, zos, hs, zhs, xs, rs, w_rs,
-     w_ws, adds, erases, k_rs, k_ws, g_rs, g_ws, wc_rs, wc_ws] = self.stats
+     w_ws, adds, erases, k_rs, k_ws, g_rs, g_ws, wc_rs, wc_ws,
+     zbeta_rs, zbeta_ws] = self.stats
     dd = {}
     drs = {}
     dzh = {}
@@ -305,8 +308,8 @@ def manual_grads(params):
         for i in range(self.N):
             # for every element in the weighting
             for j in range(self.N):
-                dwdK_r[i,j] += softmax_grads(K_rs, i, j)
-                dwdK_w[i,j] += softmax_grads(K_ws, i, j)
+                dwdK_r[i,j] += softmax_grads(K_rs, softplus(zbeta_rs[t]), i, j)
+                dwdK_w[i,j] += softmax_grads(K_ws, softplus(zbeta_ws[t]), i, j)

         # compute dK for all i in N
         # K is the evaluated cosine similarity for the i-th row of mem matrix
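
Given how dwdK_r is consumed below (dK_r[i] += dwc_r[j] * dwdK_r[i,j]), softmax_grads(K, beta, i, j) evidently returns d w[j] / d K[i] for w = softmax(beta * K); the new softplus(zbeta) argument is the beta that now scales the Jacobian. A sketch of the standard closed form, cross-checked with the jacobian import added above (the helper name and formula are inferred from usage, not read from util.util):

    import autograd.numpy as np
    from autograd import jacobian

    def softmax_grads_sketch(K, beta, i, j):
        # d w[j] / d K[i] for w = softmax(beta * K)
        w = np.exp(beta * K) / np.sum(np.exp(beta * K))
        return beta * w[j] * (float(i == j) - w[i])

    K, beta = np.array([0.1, 0.5, 0.2]), 2.0
    J = jacobian(lambda K: np.exp(beta * K) / np.sum(np.exp(beta * K)))(K)
    assert np.allclose(J[1, 0], softmax_grads_sketch(K, beta, 0, 1))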
@@ -317,6 +320,7 @@ def manual_grads(params):
         for i in range(self.N):
             # for every j in N (for every elt of the weighting)
             for j in range(self.N):
+                # specifically, dwdK_r will change, and for write as well
                 dK_r[i] += dwc_r[j] * dwdK_r[i,j]
                 dK_w[i] += dwc_w[j] * dwdK_w[i,j]

@@ -397,6 +401,32 @@ def manual_grads(params):
         deltas['bg_r'] += dzg_r
         deltas['bg_w'] += dzg_w

+        # compute dbeta, which affects w_content through interaction with Ks
+
+        dwcdbeta_r = np.zeros_like(w_rs[0])
+        dwcdbeta_w = np.zeros_like(w_ws[0])
+        for i in range(self.N):
+            dwcdbeta_r[i] = beta_grads(K_rs, softplus(zbeta_rs[t]), i)
+            dwcdbeta_w[i] = beta_grads(K_ws, softplus(zbeta_ws[t]), i)
+
+        # import pdb; pdb.set_trace()
+        dbeta_r = np.zeros_like(zbeta_rs[0])
+        dbeta_w = np.zeros_like(zbeta_ws[0])
+        for i in range(self.N):
+            dbeta_r[0] += dwc_r[i] * dwcdbeta_r[i]
+            dbeta_w[0] += dwc_w[i] * dwcdbeta_w[i]
+
+
+        # beta is activated from zbeta by softplus, grad of which is sigmoid
+        dzbeta_r = dbeta_r * sigmoid(zbeta_rs[t])
+        dzbeta_w = dbeta_w * sigmoid(zbeta_ws[t])
+
+        deltas['obeta_r'] += np.dot(dzbeta_r, os[t].T)
+        deltas['obeta_w'] += np.dot(dzbeta_w, os[t].T)
+
+        deltas['bbeta_r'] += dzbeta_r
+        deltas['bbeta_w'] += dzbeta_w
+
     else:
         drs[t] = np.zeros_like(rs[0])
         dmemtilde[t] = np.zeros_like(mems[0])
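
This hunk leans on two identities. First, beta_grads(K, beta, i) evidently returns d w[i] / d beta for w = softmax(beta * K), whose closed form is w[i] * (K[i] - sum_j w[j] * K[j]). Second, dzbeta = dbeta * sigmoid(zbeta) is the chain rule through the activation, since softplus'(z) = sigmoid(z). A sketch of both (the beta_grads formula is inferred from how dwcdbeta is used, not read from util.util):

    import autograd.numpy as np
    from autograd import grad

    def beta_grads_sketch(K, beta, i):
        # d w[i] / d beta for w = softmax(beta * K)
        w = np.exp(beta * K) / np.sum(np.exp(beta * K))
        return w[i] * (K[i] - np.dot(w, K))

    # softplus'(z) == sigmoid(z), which justifies dzbeta = dbeta * sigmoid(zbeta)
    softplus = lambda z: np.log(1.0 + np.exp(z))
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    assert np.isclose(grad(softplus)(0.3), sigmoid(0.3))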
@@ -417,6 +447,9 @@ def manual_grads(params):
         # and also through the interpolators
         do += np.dot(params['og_r'].T, dzg_r)
         do += np.dot(params['og_w'].T, dzg_w)
+        # and also through beta
+        do += np.dot(params['obeta_r'].T, dzbeta_r)
+        do += np.dot(params['obeta_w'].T, dzbeta_w)


         # compute deriv w.r.t. pre-activation of o
@@ -479,6 +512,7 @@ def bprop(params, manual_grad):

     deltas = bprop(self.W, manual_grad)
     [loss, mems, ps, ys, os, zos, hs, zhs, xs, rs, w_rs,
-     w_ws, adds, erases, k_rs, k_ws, g_rs, g_ws, wc_rs, wc_ws] = map(unwrap, self.stats)
+     w_ws, adds, erases, k_rs, k_ws, g_rs, g_ws, wc_rs, wc_ws,
+     zbeta_rs, zbeta_ws] = map(unwrap, self.stats)

     return loss, deltas, ps, w_rs, w_ws, adds, erases
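
The point of threading zbeta_rs/zbeta_ws through self.stats is that the hand-derived beta deltas can now be checked against autograd's. A hypothetical driver, assuming bprop's manual_grad flag selects between the autograd and manual paths and that compare_deltas (imported from util.util above) takes two arrays:

    auto = bprop(self.W, manual_grad=False)   # autograd path
    hand = bprop(self.W, manual_grad=True)    # manual_grads path above
    for key in ['obeta_r', 'bbeta_r', 'obeta_w', 'bbeta_w']:
        compare_deltas(auto[key], hand[key])  # signature assumed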