Skip to content

Commit 3441c0b

Browse files
committed
small fixes
1 parent 0bce6c6 commit 3441c0b

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

dataset/dataset.py

+1
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ def __iter__(self):
79 79               self.pairs = np.random.permutation(np.array(self.pairs, dtype=object))
80 80           else:
81 81               self.pairs = np.array(self.pairs, dtype=object)
   82 +         self.size = len(self.pairs)
82 83           return self
83 84
84 85       def __next__(self):

models.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -93,5 +93,5 @@ def get_model(args):
93 93       ).to(args.device)
94 94       if args.wandb:
95 95           import wandb
96    -         wandb.watch((encoder.attn_layers, decoder.attn_layers))
   96 +         wandb.watch((encoder.attn_layers, decoder.net.attn_layers))
97 97       return Model(encoder, decoder, args)

train.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -26,10 +26,10 @@ def train(args):
26 26       model = get_model(args)
27 27       encoder, decoder = model.encoder, model.decoder
28 28       opt = optim.Adam(model.parameters(), args.lr)
29    -     scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.05, steps_per_epoch=len(dataloader), epochs=args.epochs)
   29 +     scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.002, steps_per_epoch=len(dataloader), epochs=args.epochs)
30 30
31 31       for e in range(args.epochs):
32    -         dset = tqdm(dataloader)
   32 +         dset = tqdm(iter(dataloader))
33 33           for i, (seq, im) in enumerate(dset):
34 34               opt.zero_grad()
35 35               tgt_seq, tgt_mask = seq['input_ids'].to(device), seq['attention_mask'].bool().to(device)

0 commit comments

Comments
 (0)