refactor: improve logging and clean up console output
vejvarm committed Nov 20, 2023
1 parent d0ac77a commit def05e9
Showing 4 changed files with 118 additions and 81 deletions.
1 change: 1 addition & 0 deletions args.py
@@ -10,6 +10,7 @@ def get_parser():
    parser.add_argument('--seed', default=1234, type=int)
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--cuda-device', default=0, type=int)
    parser.add_argument('--name', default="", type=str)

    # data
    parser.add_argument('--data-path', default='data/final/csqa')
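The new --name argument gives each run an experiment label that later ends up in the checkpoint filename (see the save_checkpoint change in utils.py below). A minimal sketch of how the flag surfaces on the parsed arguments; the value 'merged' is made up, and the remaining arguments are assumed to keep their defaults, as the visible ones do:

    from args import get_parser

    args = get_parser().parse_args(['--name', 'merged'])
    print(args.name)  # 'merged'; defaults to an empty string when the flag is omitted
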
47 changes: 47 additions & 0 deletions lab/filename-extract-val-losses.py
@@ -0,0 +1,47 @@
import pathlib
import re
from matplotlib import pyplot as plt


# Function to extract epoch and validation loss from checkpoint file names
def extract_data(file_names):
    epochs = []
    losses = []
    for file_name in file_names:
        # str() so both plain strings and pathlib.Path objects are accepted
        match = re.search(r"_e(\d+)_v([\d.]+)_", str(file_name))
        if match:
            epoch, loss = match.groups()
            epochs.append(int(epoch))
            losses.append(float(loss))
    return epochs, losses


def plot(results: dict):
    plt.figure(figsize=(10, 6))
    for name, result in results.items():
        epochs = result[0]
        losses = result[1]
        plt.plot(epochs, losses, label=name)

    plt.xlabel('Epoch')
    plt.ylabel('Validation Loss')
    plt.title('Validation Loss Across Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()


if __name__ == "__main__":
    folders = {"csqa": "pth/to/csqa/results/folder",
               "merged": "pth/to/merged/results/folder"}

    results = {}
    for name, fldr in folders.items():
        files = pathlib.Path(fldr).glob("*.pth.tar")

        # Extracting data (glob order is not guaranteed, so epochs may need sorting before plotting)
        epochs, losses = extract_data(files)

        results[name] = (epochs, losses)

    print(results)
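For reference, a self-contained example of what extract_data returns, assuming checkpoint names that follow the save_checkpoint pattern in utils.py; the model name 'cparse' and the loss values are invented:

    sample_files = ["cparse_merged_e1_v1.2345_multitask.pth.tar",
                    "cparse_merged_e2_v0.9876_multitask.pth.tar"]
    epochs, losses = extract_data(sample_files)
    print(epochs)  # [1, 2]
    print(losses)  # [1.2345, 0.9876]
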
147 changes: 68 additions & 79 deletions train.py
@@ -15,9 +15,7 @@
from torchtext.data import BucketIterator
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, SequentialSampler, BatchSampler, RandomSampler
from utils import (NoamOpt, AverageMeter,
                   SingleTaskLoss, MultiTaskLoss,
                   save_checkpoint, init_weights)
from utils import (NoamOpt, AverageMeter, SingleTaskLoss, MultiTaskLoss, save_checkpoint, init_weights)

from helpers import setup_logger
from constants import *
@@ -208,6 +206,7 @@ def main():
    # initialize model weights
    init_weights(model)

    LOGGER.info(f"Model: `{MODEL_NAME}`, Experiment: `{args.name}`")
    LOGGER.info(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

    # define loss function (criterion)
@@ -257,7 +256,6 @@ def main():
                            shuffle=False,
                            collate_fn=partial(collate_fn, vocabs=vocabs, device=DEVICE))


    LOGGER.info('Loaders prepared.')
    LOGGER.info(f"Training data: {len(train_data)}")
    LOGGER.info(f"Validation data: {len(val_data)}")
@@ -267,8 +265,8 @@
    LOGGER.info(f"Unique tokens in logical form vocabulary: {len(vocabs[LOGICAL_FORM])}")
    LOGGER.info(f"Unique tokens in ner vocabulary: {len(vocabs[NER])}")
    LOGGER.info(f"Unique tokens in coref vocabulary: {len(vocabs[COREF])}")
    LOGGER.info(f'Batch: {args.batch_size}')
    LOGGER.info(f'Epochs: {args.epochs}')
    LOGGER.info(f'Batches: {args.batch_size}')

    # run epochs
    for epoch in range(args.start_epoch, args.epochs):
@@ -280,11 +278,14 @@
        val_loss = validate(val_loader, model, vocabs, val_helper, criterion, single_task_loss)
        best_val = min(val_loss, best_val) # log every validation step
        save_checkpoint({
            EPOCH: epoch + 1,
            STATE_DICT: model.state_dict(),
            BEST_VAL: best_val,
            OPTIMIZER: optimizer.optimizer.state_dict(),
            CURR_VAL: val_loss})
            EPOCH: epoch + 1,
            STATE_DICT: model.state_dict(),
            BEST_VAL: best_val,
            OPTIMIZER: optimizer.optimizer.state_dict(),
            CURR_VAL: val_loss
            },
            experiment=args.name
        )
        LOGGER.info(f'* Val loss: {val_loss:.4f}')


@@ -298,75 +299,63 @@ def train(train_loader, model, vocabs, helper_data, criterion, optimizer, epoch)

    end = time.time()
    batch_progress_old = -1
    for i, batch in tqdm(enumerate(train_loader), total=total_batches, desc=f"Epoch {epoch}"):
        # get inputs
        input = batch.input
        logical_form = batch.logical_form
        ner = batch.ner
        coref = batch.coref
        predicate_t = batch.predicate_pointer
        type_t = batch.type_pointer

        # compute output
        output = model(input, logical_form[:, :-1])
        LOGGER.debug(f'output[NER] in train: ({output[NER].shape}) {output[NER]}')
        LOGGER.debug(f'output[COREF] in train: ({output[COREF].shape}) {output[COREF]}')

        ner_out = output[NER].detach().argmax(1).tolist()
        LOGGER.debug(f'ner_out in train: ({len(ner_out)}) {ner_out}')
        ner_str = [vocabs[NER].itos[i] for i in ner_out][1:-1]
        LOGGER.debug(f'ner_str in train: ({len(ner_str)}) {ner_str}')
        ner_indices = {k: tag.split('-')[-1] for k, tag in enumerate(ner_str) if
                       tag.startswith(B) or tag.startswith(I)} # idx: type_id
        LOGGER.debug(f'ner_indices in train: ({len(ner_indices)}) {ner_indices}')
        # coref_indices = {k: tag for k, tag in enumerate(coref_str) if tag not in ['NA']}
        # create a ner dictionary with index as key and entity as value
        # NOTE: WE ACTUALLY DON'T NEED ANY OF THIS!
        # THE NER MODULE IS NOT LEARNING ANYTHING NEW ... we don't need a specific loss for that
        # ONLY THING WE NEED IS TO ADD NEW ENTRIES TO THE CSQA Dataset!

        # NER module in TRAIN
        # TODO: implement the ner module functionality, as in Inference,
        # goal: missing entities added to index
        # loss: compare added entities and their labels with the args.elastic_index_ent_full (Levenshtein distance? naah, either it's right or not)
        # ner_prediction = output[NER]
        # coref_prediction = output[COREF]
        # ner_indices = OrderedDict({k: tag.split('-')[-1] for k, tag in enumerate(ner_prediction) if
        #                            tag.startswith(B) or tag.startswith(I)}) # idx: type_id
        # coref_indices = OrderedDict({k: tag for k, tag in enumerate(coref_prediction) if tag not in ['NA']})
        # # create a ner dictionary with index as key and entity as value
        # ner_idx_ent = self.create_ner_idx_ent_dict(ner_indices, context_question)
        # output[NER]

        # prepare targets
        target = {
            LOGICAL_FORM: logical_form[:, 1:].contiguous().view(-1),
            NER: ner.contiguous().view(-1),
            COREF: coref.contiguous().view(-1),
            PREDICATE_POINTER: predicate_t[:, 1:].contiguous().view(-1),
            TYPE_POINTER: type_t[:, 1:].contiguous().view(-1),
        }

        # compute loss
        loss = criterion(output, target) if args.task == MULTITASK else criterion(output[args.task], target[args.task])

        # record loss
        losses.update(loss.detach(), input.size(0))

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        batch_progress = int(((i+1)/total_batches)*100) # percentage
        if batch_progress > batch_progress_old:
            LOGGER.info(f'{epoch}: Batch {batch_progress:02d}% - Train loss {losses.val:.4f} ({losses.avg:.4f})')
            batch_progress_old = batch_progress
    with tqdm(total=total_batches, desc=f'Epoch {epoch + 1}/{args.epochs}') as pbar:
        for i, batch in enumerate(train_loader):
            # get inputs
            input = batch.input
            logical_form = batch.logical_form
            ner = batch.ner
            coref = batch.coref
            predicate_t = batch.predicate_pointer
            type_t = batch.type_pointer

            # compute output
            output = model(input, logical_form[:, :-1])
            LOGGER.debug(f'output[NER] in train: ({output[NER].shape}) {output[NER]}')
            LOGGER.debug(f'output[COREF] in train: ({output[COREF].shape}) {output[COREF]}')

            ner_out = output[NER].detach().argmax(1).tolist()
            LOGGER.debug(f'ner_out in train: ({len(ner_out)}) {ner_out}')
            ner_str = [vocabs[NER].itos[i] for i in ner_out][1:-1]
            LOGGER.debug(f'ner_str in train: ({len(ner_str)}) {ner_str}')
            ner_indices = {k: tag.split('-')[-1] for k, tag in enumerate(ner_str) if
                           tag.startswith(B) or tag.startswith(I)} # idx: type_id
            LOGGER.debug(f'ner_indices in train: ({len(ner_indices)}) {ner_indices}')

            # prepare targets
            target = {
                LOGICAL_FORM: logical_form[:, 1:].contiguous().view(-1),
                NER: ner.contiguous().view(-1),
                COREF: coref.contiguous().view(-1),
                PREDICATE_POINTER: predicate_t[:, 1:].contiguous().view(-1),
                TYPE_POINTER: type_t[:, 1:].contiguous().view(-1),
            }

            # compute loss
            loss = criterion(output, target) if args.task == MULTITASK else criterion(output[args.task], target[args.task])

            # record loss
            losses.update(loss.detach(), input.size(0))

            # compute gradient and do Adam step
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            pbar.set_postfix({'loss': losses.val, 'avg': losses.avg})
            pbar.update(1)

            # batch_progress = int(((i+1)/total_batches)*100) # percentage
            # if batch_progress > batch_progress_old:
            #     LOGGER.info(f'{epoch}: Batch {batch_progress:02d}% - Train loss {losses.val:.4f} ({losses.avg:.4f})')
            #     batch_progress_old = batch_progress

    LOGGER.info(f'{epoch}: Train loss: {losses.avg:.4f}')
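The per-percent LOGGER.info progress messages are replaced by a tqdm bar that carries the running loss in its postfix, so the console shows a single updating line per epoch instead of one log record per percent of progress. A minimal, self-contained sketch of that pattern; the loop bounds and the placeholder metric are made up, only the tqdm calls mirror the loop above:

    from tqdm import tqdm

    with tqdm(total=100, desc='Epoch 1/10') as pbar:
        for step in range(100):
            loss = 1.0 / (step + 1)           # placeholder metric
            pbar.set_postfix({'loss': loss})  # shown inline next to the bar
            pbar.update(1)                    # advance the bar by one batch
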


def validate(val_loader, model, vocabs, helper_data, criterion, single_task_loss):
4 changes: 2 additions & 2 deletions utils.py
@@ -456,8 +456,8 @@ def rapidfuzz_query(query, filter_type, kg, res_size=50):
    return filtered_res if filtered_res else unfiltered_res


def save_checkpoint(state):
    filename = f'{ROOT_PATH}/{args.snapshots}/{MODEL_NAME}_e{state[EPOCH]}_v{state[CURR_VAL]:.4f}_{args.task}.pth.tar'
def save_checkpoint(state: dict, experiment: str = ""):
    filename = f'{ROOT_PATH}/{args.snapshots}/{MODEL_NAME}_{experiment}_e{state[EPOCH]}_v{state[CURR_VAL]:.4f}_{args.task}.pth.tar'
    torch.save(state, filename)
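With the new experiment parameter, checkpoint names carry the run label between the model name and the epoch tag, which is exactly the piece that lab/filename-extract-val-losses.py parses back out. An illustrative rendering of the pattern; the model name, snapshot directory and task below are assumed values, not taken from the repository:

    MODEL_NAME, experiment, epoch, val_loss, task = "cparse", "merged", 3, 0.8412, "multitask"
    print(f"snapshots/{MODEL_NAME}_{experiment}_e{epoch}_v{val_loss:.4f}_{task}.pth.tar")
    # snapshots/cparse_merged_e3_v0.8412_multitask.pth.tar

Note that the default empty experiment string leaves a double underscore after the model name (cparse__e3_...), which the extraction regex still matches.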


