
Commit 04e1343

hyperparameter tuning arguments added
1 parent 695e9d8 commit 04e1343

File tree

2 files changed: +18 −6 lines changed


ngclearn/utils/analysis/attentive_probe.py

Lines changed: 11 additions & 4 deletions
@@ -219,7 +219,9 @@ class AttentiveProbe(Probe):
     """
     def __init__(
         self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
-        target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_LN_input=False, use_softmax=True, **kwargs
+        target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32,
+        use_LN=True, use_LN_input=False, use_softmax=True, dropout=0.5, eta=0.0002,
+        eta_decay=0.0, min_eta=1e-5, **kwargs
     ):
         super().__init__(dkey, batch_size, **kwargs)
         assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
@@ -232,9 +234,9 @@ def __init__(
         self.use_softmax = use_softmax
         self.use_LN = use_LN
         self.use_LN_input = use_LN_input
-        self.dropout = 0.5
+        self.dropout = dropout

-        sigma = 0.05
+        sigma = 0.02
         ## cross-attention parameters
         Wq = random.normal(subkeys[0], (learnable_query_dim, attn_dim)) * sigma
         bq = random.normal(subkeys[1], (1, attn_dim)) * sigma
@@ -287,7 +289,10 @@ def __init__(
         self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=1, has_aux=True) #, allow_int=True)
         ## set up update rule/optimizer
         self.optim_params = adam.adam_init(self.probe_params)
-        self.eta = 0.0002 #0.001
+        # Learning rate scheduling
+        self.eta = eta #0.001
+        self.eta_decay = eta_decay
+        self.min_eta = min_eta

         # Finally, the dkey for the noise_key
         self.noise_key = subkeys[24]
@@ -319,5 +324,7 @@ def update(self, embeddings, labels, dkey=None):
         self.optim_params, self.probe_params = adam.adam_step(
             self.optim_params, self.probe_params, grads, eta=self.eta
         )
+
+        self.eta = max(self.min_eta, self.eta - self.eta_decay * self.eta)
         return loss, predictions
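The new constructor arguments expose the probe's dropout rate and optimizer schedule at the call site instead of hard-coding them. Below is a minimal usage sketch; the import path follows the file path shown above, and all dimension and hyperparameter values are illustrative placeholders, not values taken from this commit.

from jax import random
from ngclearn.utils.analysis.attentive_probe import AttentiveProbe  # path per this diff

dkey = random.PRNGKey(1234)
probe = AttentiveProbe(
    dkey, source_seq_length=16, input_dim=256, out_dim=10,  # illustrative dims
    num_heads=8, attn_dim=64, batch_size=32, hid_dim=32,
    dropout=0.25,     # previously hard-coded to 0.5
    eta=0.0002,       # initial Adam step size
    eta_decay=0.01,   # update() applies eta <- max(min_eta, eta - eta_decay * eta)
    min_eta=1e-5      # floor on the decayed step size
)

With the default eta_decay=0.0, the decay step leaves eta unchanged, so existing behavior (a fixed learning rate) is preserved.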

ngclearn/utils/analysis/probe.py

Lines changed: 7 additions & 2 deletions
@@ -150,9 +150,14 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
                 L = (_L * x_mb.shape[0]) + L ## we remove the batch division from loss w.r.t. x_mb/y_mb

             if dev_data is not None:
-                print(f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f} Dev.Acc = {best_acc:.2f}", end="")
+                print_string = f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f} Dev.Acc = {best_acc:.2f}"
             else:
-                print(f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f}", end="")
+                print_string = f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f}"
+
+            if hasattr(self, "eta"):
+                print_string += f" LR = {getattr(self, 'eta'):.6f}"
+
+            print(print_string, end = "")

             acc = acc / Ns
             L = L / Ns ## compute current loss over (train) dataset
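Because fit() only appends the LR field when the probe defines self.eta, the printed value tracks the decay rule added in update() above. A minimal sketch of that schedule, assuming one decay step per update() call and illustrative hyperparameter values:

eta, eta_decay, min_eta = 2e-4, 0.01, 1e-5   # illustrative values

for step in range(500):
    # same rule as AttentiveProbe.update(): multiplicative decay with a floor
    eta = max(min_eta, eta - eta_decay * eta)

# before the floor binds, this follows the closed form eta_t = eta_0 * (1 - eta_decay) ** t
print(f"LR = {eta:.6f}")   # same formatting as the new progress-line field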
