@@ -219,7 +219,9 @@ class AttentiveProbe(Probe):
    """
    def __init__(
        self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
-        target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_LN_input=False, use_softmax=True, **kwargs
+        target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32,
+        use_LN=True, use_LN_input=False, use_softmax=True, dropout=0.5, eta=0.0002,
+        eta_decay=0.0, min_eta=1e-5, **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
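With the new keyword arguments, the dropout rate and the learning-rate schedule can be set when the probe is built. A minimal instantiation sketch, assuming illustrative shapes and schedule values that are not taken from this commit:

from jax import random

dkey = random.PRNGKey(0)
# Hypothetical dimensions, chosen only to illustrate the new signature
probe = AttentiveProbe(
    dkey, source_seq_length=16, input_dim=512, out_dim=10,
    num_heads=8, attn_dim=64, dropout=0.3,       # dropout is now configurable
    eta=2e-4, eta_decay=1e-3, min_eta=1e-5,      # learning-rate schedule
)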
@@ -232,9 +234,9 @@ def __init__(
        self.use_softmax = use_softmax
        self.use_LN = use_LN
        self.use_LN_input = use_LN_input
-        self.dropout = 0.5
+        self.dropout = dropout

-        sigma = 0.05
+        sigma = 0.02
        ## cross-attention parameters
        Wq = random.normal(subkeys[0], (learnable_query_dim, attn_dim)) * sigma
        bq = random.normal(subkeys[1], (1, attn_dim)) * sigma
@@ -287,7 +289,10 @@ def __init__(
        self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=1, has_aux=True) #, allow_int=True)
        ## set up update rule/optimizer
        self.optim_params = adam.adam_init(self.probe_params)
-        self.eta = 0.0002 #0.001
+        # Learning rate scheduling
+        self.eta = eta #0.001
+        self.eta_decay = eta_decay
+        self.min_eta = min_eta

        # Finally, the dkey for the noise_key
        self.noise_key = subkeys[24]
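For reference, jax.value_and_grad(fn, argnums=1, has_aux=True) builds a function that returns ((loss, aux), grads), differentiating with respect to the second positional argument. A small self-contained sketch of that calling convention, using a toy loss that is only an assumption and not the probe's eval_attention_probe:

import jax
import jax.numpy as jnp

def toy_loss(x, params):
    # returns (scalar loss, auxiliary predictions), matching has_aux=True
    preds = x @ params["W"]
    return jnp.mean((preds - 1.0) ** 2), preds

grad_fx = jax.value_and_grad(toy_loss, argnums=1, has_aux=True)
(loss, preds), grads = grad_fx(jnp.ones((4, 3)), {"W": jnp.zeros((3, 2))})
# grads mirrors the structure of the second argument: {"W": ...}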
@@ -319,5 +324,7 @@ def update(self, embeddings, labels, dkey=None):
        self.optim_params, self.probe_params = adam.adam_step(
            self.optim_params, self.probe_params, grads, eta=self.eta
        )
+
+        self.eta = max(self.min_eta, self.eta - self.eta_decay * self.eta)
        return loss, predictions
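The added schedule shrinks eta multiplicatively after every optimizer step and floors it at min_eta. A standalone sketch of how the rule behaves, with illustrative starting values that are assumptions rather than values from this commit:

eta, eta_decay, min_eta = 2e-4, 1e-3, 1e-5
for step in range(3):
    # equivalent to eta *= (1 - eta_decay), clipped below at min_eta
    eta = max(min_eta, eta - eta_decay * eta)
# with eta_decay=0.0 (the default), eta never changes and the old behavior is preserved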