Skip to content

Commit

Permalink
feat: log training and validation loss in TensorBoard
Browse files Browse the repository at this point in the history
  • Loading branch information
vejvarm committed Nov 27, 2023
1 parent d16d075 commit 037d943
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion args.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_parser():

# experiments
parser.add_argument('--snapshots', default='experiments/snapshots', type=str)
parser.add_argument('--path_results', default='experiments/results', type=str)
parser.add_argument('--path-results', default='experiments/results', type=str)
parser.add_argument('--path_error_analysis', default='experiments/error_analysis', type=str)
parser.add_argument('--path-inference', default='experiments/inference', type=str)

Expand Down
18 changes: 12 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import time
import random
import logging
from typing import Iterator, List
from copy import deepcopy
from functools import partial

import numpy as np
Expand All @@ -12,6 +10,7 @@
from model import CARTON
from dataset import CSQADataset, collate_fn
from torch.utils.data import DataLoader, SequentialSampler, BatchSampler, RandomSampler
from torch.utils.tensorboard import SummaryWriter
from utils import (NoamOpt, AverageMeter, SingleTaskLoss, MultiTaskLoss, save_checkpoint, init_weights)

from helpers import setup_logger
Expand All @@ -21,11 +20,13 @@
parser = get_parser()
args = parser.parse_args()

LOGDIR = ROOT_PATH.joinpath(args.snapshots).joinpath("logs")
LOGDIR.mkdir(exist_ok=True, parents=True)
# set LOGGER
LOGGER = setup_logger(__name__,
loglevel=logging.INFO,
handlers=[logging.FileHandler(f'{args.path_results}/{MODEL_NAME}_{args.name}_train_{args.task}.log', 'w'),
logging.StreamHandler()])
handlers=[logging.FileHandler(LOGDIR.joinpath(f"{MODEL_NAME}_{args.name}_train_{args.task}.log"), 'w'),
logging.StreamHandler()])

# set a seed value
random.seed(args.seed)
Expand Down Expand Up @@ -115,6 +116,8 @@ def main():
LOGGER.info(f'Epochs: {args.epochs}')
LOGGER.info(f'Batch size: {args.batch_size}')

tb_writer = SummaryWriter(LOGDIR.joinpath("tb"))

# run epochs
for epoch in range(args.start_epoch, args.epochs):
# evaluate on validation set
Expand All @@ -131,9 +134,11 @@ def main():
experiment=args.name
)
LOGGER.info(f'* Val loss: {val_loss:.4f}')
tb_writer.add_scalar('val loss', val_loss, epoch)

# train for one epoch
train(train_loader, model, vocabs, helper_dict['train'], criterion, optimizer, epoch)
train_loss = train(train_loader, model, vocabs, helper_dict['train'], criterion, optimizer, epoch)
tb_writer.add_scalar('training loss', train_loss, epoch+1)

# Validate and save the final epoch
val_loss = validate(val_loader, model, vocabs, helper_dict['val'], criterion, single_task_loss)
Expand All @@ -148,6 +153,7 @@ def main():
experiment=args.name
)
LOGGER.info(f'* Val loss: {val_loss:.4f}')
tb_writer.add_scalar('val loss', val_loss, args.epochs)


def train(train_loader, model, vocabs, helper_data, criterion, optimizer, epoch):
Expand Down Expand Up @@ -217,6 +223,7 @@ def train(train_loader, model, vocabs, helper_data, criterion, optimizer, epoch)
# batch_progress_old = batch_progress

LOGGER.info(f'{epoch}: Train loss: {losses.avg:.4f}')
return losses.avg


def validate(val_loader, model, vocabs, helper_data, criterion, single_task_loss):
Expand Down Expand Up @@ -276,7 +283,6 @@ def validate(val_loader, model, vocabs, helper_data, criterion, single_task_loss

LOGGER.info(f"Val losses:: LF: {losses_lf.avg} | NER: {losses_ner.avg} | COREF: {losses_coref.avg} | "
f"PRED: {losses_pred.avg} | TYPE: {losses_type.avg}")

return losses.avg


Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def rapidfuzz_query(query, filter_type, kg, res_size=50):


def save_checkpoint(state: dict, experiment: str = ""):
filename = f'{ROOT_PATH}/{args.snapshots}/{MODEL_NAME}_{experiment}_e{state[EPOCH]}_v{state[CURR_VAL]:.4f}_{args.task}.pth.tar'
filename = ROOT_PATH.joinpath(args.snapshots).joinpath(experiment).joinpath(f"{MODEL_NAME}_{experiment}_e{state[EPOCH]}_v{state[CURR_VAL]:.4f}_{args.task}.pth.tar")
torch.save(state, filename)


Expand Down

0 comments on commit 037d943

Please sign in to comment.