From 037d943093dcd54558a35dd5a68edbc6ddb56fca Mon Sep 17 00:00:00 2001 From: "vejvarm@freya" Date: Mon, 27 Nov 2023 13:56:21 +0900 Subject: [PATCH] feat: log training and validation loss in tensorboard --- args.py | 2 +- train.py | 18 ++++++++++++------ utils.py | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/args.py b/args.py index aee8b6a..18b5d72 100644 --- a/args.py +++ b/args.py @@ -24,7 +24,7 @@ def get_parser(): # experiments parser.add_argument('--snapshots', default='experiments/snapshots', type=str) - parser.add_argument('--path_results', default='experiments/results', type=str) + parser.add_argument('--path-results', default='experiments/results', type=str) parser.add_argument('--path_error_analysis', default='experiments/error_analysis', type=str) parser.add_argument('--path-inference', default='experiments/inference', type=str) diff --git a/train.py b/train.py index aad38bb..ded8b07 100644 --- a/train.py +++ b/train.py @@ -1,8 +1,6 @@ import time import random import logging -from typing import Iterator, List -from copy import deepcopy from functools import partial import numpy as np @@ -12,6 +10,7 @@ from model import CARTON from dataset import CSQADataset, collate_fn from torch.utils.data import DataLoader, SequentialSampler, BatchSampler, RandomSampler +from torch.utils.tensorboard import SummaryWriter from utils import (NoamOpt, AverageMeter, SingleTaskLoss, MultiTaskLoss, save_checkpoint, init_weights) from helpers import setup_logger @@ -21,11 +20,13 @@ parser = get_parser() args = parser.parse_args() +LOGDIR = ROOT_PATH.joinpath(args.snapshots).joinpath("logs") +LOGDIR.mkdir(exist_ok=True, parents=True) # set LOGGER LOGGER = setup_logger(__name__, loglevel=logging.INFO, - handlers=[logging.FileHandler(f'{args.path_results}/{MODEL_NAME}_{args.name}_train_{args.task}.log', 'w'), - logging.StreamHandler()]) + handlers=[logging.FileHandler(LOGDIR.joinpath(f"{MODEL_NAME}_{args.name}_train_{args.task}.log"), 'w'), + logging.StreamHandler()]) # set a seed value random.seed(args.seed) @@ -115,6 +116,8 @@ def main(): LOGGER.info(f'Epochs: {args.epochs}') LOGGER.info(f'Batch size: {args.batch_size}') + tb_writer = SummaryWriter(LOGDIR.joinpath("tb")) + # run epochs for epoch in range(args.start_epoch, args.epochs): # evaluate on validation set @@ -131,9 +134,11 @@ def main(): experiment=args.name ) LOGGER.info(f'* Val loss: {val_loss:.4f}') + tb_writer.add_scalar('val loss', val_loss, epoch) # train for one epoch - train(train_loader, model, vocabs, helper_dict['train'], criterion, optimizer, epoch) + train_loss = train(train_loader, model, vocabs, helper_dict['train'], criterion, optimizer, epoch) + tb_writer.add_scalar('training loss', train_loss, epoch+1) # Validate and save the final epoch val_loss = validate(val_loader, model, vocabs, helper_dict['val'], criterion, single_task_loss) @@ -148,6 +153,7 @@ def main(): experiment=args.name ) LOGGER.info(f'* Val loss: {val_loss:.4f}') + tb_writer.add_scalar('val loss', val_loss, args.epochs) def train(train_loader, model, vocabs, helper_data, criterion, optimizer, epoch): @@ -217,6 +223,7 @@ def train(train_loader, model, vocabs, helper_data, criterion, optimizer, epoch) # batch_progress_old = batch_progress LOGGER.info(f'{epoch}: Train loss: {losses.avg:.4f}') + return losses.avg def validate(val_loader, model, vocabs, helper_data, criterion, single_task_loss): @@ -276,7 +283,6 @@ def validate(val_loader, model, vocabs, helper_data, criterion, single_task_loss LOGGER.info(f"Val losses:: LF: {losses_lf.avg} | NER: {losses_ner.avg} | COREF: {losses_coref.avg} | " f"PRED: {losses_pred.avg} | TYPE: {losses_type.avg}") - return losses.avg diff --git a/utils.py b/utils.py index 0391b84..da7f886 100644 --- a/utils.py +++ b/utils.py @@ -468,7 +468,7 @@ def rapidfuzz_query(query, filter_type, kg, res_size=50): def save_checkpoint(state: dict, experiment: str = ""): - filename = f'{ROOT_PATH}/{args.snapshots}/{MODEL_NAME}_{experiment}_e{state[EPOCH]}_v{state[CURR_VAL]:.4f}_{args.task}.pth.tar' + filename = ROOT_PATH.joinpath(args.snapshots).joinpath(experiment).joinpath(f"{MODEL_NAME}_{experiment}_e{state[EPOCH]}_v{state[CURR_VAL]:.4f}_{args.task}.pth.tar") torch.save(state, filename)