From ae2d315d8c9f9f6a5442fa3be5a698e59c3de1fd Mon Sep 17 00:00:00 2001
From: "vejvarm@freya"
Date: Wed, 22 Nov 2023 20:43:24 +0900
Subject: [PATCH] feat: calculate accuracy and recall on test set in
 `inference.py`

---
 README.md    |   9 ++
 args.py      |   6 +-
 inference.py | 240 +++++++++++++++++++++++++++++++--------------------
 utils.py     |  81 +++++++++++++++--
 4 files changed, 234 insertions(+), 102 deletions(-)

diff --git a/README.md b/README.md
index bf7cb0a..b1a8565 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,15 @@ For training you will need to adjust the paths in [args](args.py) file. At the s
 python train.py
 ```
 
+## Inference Framework
+Calculates accuracy and recall on the test split
+- accuracy averaging: 'micro'
+- recall averaging: 'macro'
+```bash
+python inference.py --name csqa --batch-size 50 --data-path data/final/csqa --model-path experiments/models/CARTONNER_csqa_e10_v0.0102_multitask.pth.tar
+```
+This will save the metric results as JSON files into the `ROOT_PATH/args.path_inference/args.name` folder.
+
 ## Generate Actions
 After the model has finished training we perform the inference in 2 steps.
 First, we generate the actions and save them in JSON file using the trained model.
diff --git a/args.py b/args.py
index e0bf3b0..9549cfd 100644
--- a/args.py
+++ b/args.py
@@ -26,7 +26,7 @@ def get_parser():
     parser.add_argument('--snapshots', default='experiments/snapshots', type=str)
     parser.add_argument('--path_results', default='experiments/results', type=str)
     parser.add_argument('--path_error_analysis', default='experiments/error_analysis', type=str)
-    parser.add_argument('--path_inference', default='experiments/inference', type=str)
+    parser.add_argument('--path-inference', default='experiments/inference', type=str)
 
     # task
     parser.add_argument('--task', default=Task.MULTITASK.value, choices=[tsk.value for tsk in Task], type=str)
@@ -54,11 +54,11 @@ def get_parser():
     parser.add_argument('--valfreq', default=1, type=int)
     parser.add_argument('--resume', default='', type=str)
     parser.add_argument('--clip', default=5, type=int)
-    parser.add_argument('--batch_size', default=25, type=int)  # NOTE: changed from 25
+    parser.add_argument('--batch-size', default=25, type=int)  # NOTE: changed from 25
     parser.add_argument('--pool_size', default=100, type=int)
 
     # test and inference
-    parser.add_argument('--model-path', default='experiments/models/CARTONNER_merged_e10_v0.0153_multitask.pth.tar',
+    parser.add_argument('--model-path', default='experiments/models/CARTONNER_csqa_e10_v0.0102_multitask.pth.tar',
                         type=str)
     parser.add_argument('--file_path', default='/data/final/csqa/process/test.json', type=str)
     parser.add_argument('--inference_partition', default=InferencePartition.TEST.value,
diff --git a/inference.py b/inference.py
index a86966d..63a6178 100644
--- a/inference.py
+++ b/inference.py
@@ -1,3 +1,5 @@
+import json
+import pathlib
 import random
 import re
 import numpy as np
@@ -10,7 +12,7 @@
 from dataset import CSQADataset, collate_fn
 from model import CARTON
 
-from utils import Predictor, AverageMeter, MultiTaskAcc
+from utils import Predictor, AverageMeter, MultiTaskAcc, MultiTaskAccTorchmetrics, MultiTaskRecTorchmetrics
 from constants import DEVICE, LOGICAL_FORM, COREF, NER, INPUT, PREDICATE_POINTER, TYPE_POINTER, ROOT_PATH
 from args import get_parser
 
@@ -19,13 +21,11 @@
 args = parser.parse_args()
 
 # args.seed = 69 # canada, queen victoria, lefty west
-args.seed = 100
-args.batch_size = 1
+# args.seed = 100
+# args.batch_size = 1
 
 # TODO: figure out what I just did :D
-# TODO: is the accuracy being calcualted correctly? (most certainly not)
 # TODO: what would it take to calculate accuracy based on completel logical form!?
-# TODO: save accuracies to files
 
 random.seed(args.seed)
 np.random.seed(args.seed)
@@ -44,6 +44,9 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
+
+    # TODO: fix `extract_entities_and_sentences` (each token following after entity is cut off)
+
     batch_entities_sentences = []
     for input_decoded, ner, coref in zip(input_batch, ner_batch, coref_batch):
         entities = {"NA": []}
@@ -81,7 +84,18 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
     return batch_entities_sentences
 
 
+def save_meter_to_file(meter_dict: dict[str, AverageMeter], path_to_file: pathlib.Path):
+    results = {nm: metric.avg.cpu().tolist() for nm, metric in meter_dict.items()}
+    results["average"] = np.mean([v for v in results.values()])
+
+    with path_to_file.open("w") as f:
+        json.dump(results, f, indent=4)
+
+
 if __name__ == "__main__":
+    save_path = ROOT_PATH.joinpath(args.path_inference).joinpath(args.name)
+    print(f"results will be saved to `{save_path}`.")
+
     # load data
     dataset = CSQADataset(args, splits=('test', ))  # assuming we already have the correct vocab cache from all splits!
     vocabs = dataset.get_vocabs()
@@ -95,6 +109,7 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
     total_batches = (len(test_loader.dataset) + args.batch_size - 1) // args.batch_size
 
     pad = {k: v.stoi["[PAD]"] for k, v in vocabs.items() if k != "id"}
+    num_classes = {k: len(v) for k, v in vocabs.items() if k != "id"}
 
     model = CARTON(vocabs, DEVICE).to(DEVICE)
     print(f"=> loading checkpoint '{args.model_path}'")
@@ -105,101 +120,140 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
 
     predictor = Predictor(model, vocabs)
 
-    acc_calculator = MultiTaskAcc(DEVICE)
+
+
+    # acc_calculator = MultiTaskAcc(DEVICE)
+    # accuracies = {LOGICAL_FORM: AverageMeter(),
+    #               NER: AverageMeter(),
+    #               COREF: AverageMeter(),
+    #               PREDICATE_POINTER: AverageMeter(),
+    #               TYPE_POINTER: AverageMeter()}
+
+    acc_calculator = MultiTaskAccTorchmetrics(num_classes, pads=pad, device=DEVICE, averaging_type='micro')  # NOTE: 'micro' averaging so that classes with few samples are not over-weighted ('macro' would be more useful for training)
     accuracies = {LOGICAL_FORM: AverageMeter(),
                   NER: AverageMeter(),
                   COREF: AverageMeter(),
                   PREDICATE_POINTER: AverageMeter(),
                   TYPE_POINTER: AverageMeter()}
+    rec_calculator = MultiTaskRecTorchmetrics(num_classes, pads=pad, device=DEVICE)
+    recalls = {LOGICAL_FORM: AverageMeter(),
+               NER: AverageMeter(),
+               COREF: AverageMeter(),
+               PREDICATE_POINTER: AverageMeter(),
+               TYPE_POINTER: AverageMeter()}
+
     # for i, data in random.sample(test_loader, 5):
-    with tqdm(total=total_batches, desc=f'Inference') as pbar:
-        for i, batch in enumerate(test_loader):
-            """
-            Using model to do inference
-            """
-            # TODO
-
-            # ner = batch.ner
-            # coref = batch.coref
-            # predicate_t = batch.predicate_pointer
-            # type_t = batch.type_pointer
-
-            # compute output
-            output = model(batch.input, batch.logical_form[:, :-1])
-            # use input and NER to extract entity labels and types
-            # use KG to look for entities with that label and type
-
-            # match found entities with expected entities (accuracy)
-
-            # match predicate_pointer output (accuracy)
-            # match type_pointer output (accuracy)
-            # match logical_form output (accuracy)
-
-            target = {
-                LOGICAL_FORM: batch.logical_form[:, 1:].contiguous().view(-1),
-                NER: batch.ner.contiguous().view(-1),
-                COREF: batch.coref.contiguous().view(-1),
-                PREDICATE_POINTER: batch.predicate_pointer[:, 1:].contiguous().view(-1),
-                TYPE_POINTER: batch.type_pointer[:, 1:].contiguous().view(-1),
-            }
-
-            accs = acc_calculator(output, target)
-
-            for name, meter in accuracies.items():
-                meter.update(accs[name])
-
-            # """
-            # Below are the labels
-            # """
-            # # Convert tensors to lists
-            # input_batch = [[vocabs['input'].itos[tok] for tok in sample if tok != pad['input']] for sample in batch.input]
-            # ner_batch = [[vocabs['ner'].itos[tok] for tok in sample if tok != pad['ner']] for sample in batch.ner]
-            # coref_batch = [[vocabs['coref'].itos[tok] for tok in sample if tok != pad['coref']] for sample in batch.coref]
-            # lf_batch = [[vocabs['logical_form'].itos[tok] for tok in sample if tok != pad['logical_form']] for sample in batch.logical_form]
-            # pp_batch = [[vocabs['predicate_pointer'].itos[tok] for tok in sample if tok != pad['predicate_pointer']] for sample in batch.predicate_pointer]
-            # tp_batch = [[vocabs['type_pointer'].itos[tok] for tok in sample if tok != pad['type_pointer']] for sample in batch.type_pointer]
-            #
-            # batch_results = extract_entities_and_sentences(input_batch, ner_batch, coref_batch)
-            #
-            # # TODO: what do we do with [PAD] tokens (Remove/keep and mask?) when calculating accuracy?
-            #
-            # #### DEBUG
-            # # find all B-'s ... extract the type_id from there
-            # entities = batch_results[0]['entities']
-            # sent = batch_results[0]['sent']
-            #
-            # input_decoded = input_batch[0]
-            # ner = ner_batch[0]
-            # coref = coref_batch[0]
-            #
-            # lf_decoded = lf_batch[0]
-            # pp_decoded = pp_batch[0]
-            # tp_decoded = tp_batch[0]
-            #
-            # df_inp = pandas.DataFrame.from_dict({"input": input_decoded, "ner": ner, "coref": coref})
-            # df_out = pandas.DataFrame.from_dict({"lf": lf_decoded, "pp": pp_decoded, "tp": tp_decoded})
-            #
-            # # print(f"### input: {re.sub(PUNCTUATION_PATTERN, '', ' '.join(input_decoded).replace(' ##', ''))}")
-            # print(f"### input: {sent}")
-            # print(entities)
-            # print(df_inp)
-            # print(df_out)
-            # print("##########################################\n")
-            # ###
-            # #### DEBUG
-            #
-            # # in lf
-            # # fill relation with decoded relation_pointer
-            # # fill type with decoded type_pointer  # NOTE: Insert doesn't use type_pointer at all
-
-            pbar.set_postfix({'lf': accuracies[LOGICAL_FORM].avg, 'ner': accuracies[NER].avg,
-                              'coref': accuracies[COREF].avg, 'pp': accuracies[PREDICATE_POINTER].avg,
-                              'tp': accuracies[TYPE_POINTER].avg})
-            pbar.update(1)
-
-            # if i >= 5:
-            #     break
+    with torch.no_grad():
+        with tqdm(total=total_batches, desc=f'Inference') as pbar:
+            for i, batch in enumerate(test_loader):
+                """
+                Using model to do inference
+                """
+
+                # ner = batch.ner
+                # coref = batch.coref
+                # predicate_t = batch.predicate_pointer
+                # type_t = batch.type_pointer
+
+                # compute output
+                output = model(batch.input, batch.logical_form[:, :-1])
+                # use input and NER to extract entity labels and types
+                # use KG to look for entities with that label and type
+
+                # match found entities with expected entities (accuracy)
+
+                # match predicate_pointer output (accuracy)
+                # match type_pointer output (accuracy)
+                # match logical_form output (accuracy)
+
+                target = {
+                    LOGICAL_FORM: batch.logical_form[:, 1:].contiguous().view(-1),
+                    NER: batch.ner.contiguous().view(-1),
+                    COREF: batch.coref.contiguous().view(-1),
+                    PREDICATE_POINTER: batch.predicate_pointer[:, 1:].contiguous().view(-1),
+                    TYPE_POINTER: batch.type_pointer[:, 1:].contiguous().view(-1),
+                }
+
+                accs = acc_calculator(output, target)
+                for name, meter in accuracies.items():
+                    meter.update(accs[name])
+
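+                # recall counterpart: macro-averaged over classes, with [PAD] ignored via ignore_index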
+                recs = rec_calculator(output, target)
+                for name, meter in recalls.items():
+                    meter.update(recs[name])
+
+                # # ### DEBUG
+                # """
+                # Below are the labels
+                # """
+                # # Convert tensors to lists
+                # input_batch = [[vocabs['input'].itos[tok] for tok in sample if tok != pad['input']] for sample in batch.input]
+                # ner_batch = [[vocabs['ner'].itos[tok] for tok in sample if tok != pad['ner']] for sample in batch.ner]
+                # coref_batch = [[vocabs['coref'].itos[tok] for tok in sample if tok != pad['coref']] for sample in batch.coref]
+                # lf_batch = [[vocabs['logical_form'].itos[tok] for tok in sample if tok != pad['logical_form']] for sample in batch.logical_form]
+                # pp_batch = [[vocabs['predicate_pointer'].itos[tok] for tok in sample if tok != pad['predicate_pointer']] for sample in batch.predicate_pointer]
+                # tp_batch = [[vocabs['type_pointer'].itos[tok] for tok in sample if tok != pad['type_pointer']] for sample in batch.type_pointer]
+                #
+                # batch_results = extract_entities_and_sentences(input_batch, ner_batch, coref_batch)
+                #
+                # # TODO: what do we do with [PAD] tokens (Remove/keep and mask?) when calculating accuracy?
+                # # find all B-'s ... extract the type_id from there
+                # entities = batch_results[0]['entities']
+                # sent = batch_results[0]['sent']
+                #
+                # input_decoded = input_batch[0]
+                # ner = ner_batch[0]
+                # coref = coref_batch[0]
+                #
+                # lf_decoded = lf_batch[0]
+                # pp_decoded = pp_batch[0]
+                # tp_decoded = tp_batch[0]
+                #
+                # df_inp = pandas.DataFrame.from_dict({"input": input_decoded, "ner": ner, "coref": coref})
+                # df_out = pandas.DataFrame.from_dict({"lf": lf_decoded, "pp": pp_decoded, "tp": tp_decoded})
+                #
+                # csv_path = ROOT_PATH.joinpath("csv")
+                # csv_path.mkdir(exist_ok=True, parents=True)
+                # with csv_path.joinpath(f'test_{i}-asent.json').open("w") as f:
+                #     json.dump({'sent': sent, 'entities': entities}, f, indent=4)
+                # with csv_path.joinpath(f"test_{i}-binp.csv").open("w") as f:
+                #     df_inp.to_csv(f)
+                # with csv_path.joinpath(f"test_{i}-cout.csv").open("w") as f:
+                #     df_out.to_csv(f)
+                #
+                # # print(f"### input: {re.sub(PUNCTUATION_PATTERN, '', ' '.join(input_decoded).replace(' ##', ''))}")
+                # print(f"### input: {sent}")
+                # print(entities)
+                # print(df_inp)
+                # print(df_out)
+                # print("##########################################\n")
+                #
+                # # in lf
+                # # fill relation with decoded relation_pointer
+                # # fill type with decoded type_pointer  # NOTE: Insert doesn't use type_pointer
+                # # fill entities with id=search(label, type) but first order them by coref
+                # # TODO: \O.o/ dont forget our nice extraction code above
+                # # ### DEBUG
+
+                pbar.set_postfix({'lf': f"{accuracies[LOGICAL_FORM].avg:.4f}|{recalls[LOGICAL_FORM].avg:.4f}",
+                                  'ner': f"{accuracies[NER].avg:.4f}|{recalls[NER].avg:.4f}",
+                                  'coref': f"{accuracies[COREF].avg:.4f}|{recalls[COREF].avg:.4f}",
+                                  'pp': f"{accuracies[PREDICATE_POINTER].avg:.4f}|{recalls[PREDICATE_POINTER].avg:.4f}",
+                                  'tp': f"{accuracies[TYPE_POINTER].avg:.4f}|{recalls[TYPE_POINTER].avg:.4f}"})
+                pbar.update(1)
+
+                # break
+
+            # if i >= 5:
+            #     break
+
+    # save metric results
+    save_path.mkdir(exist_ok=True, parents=True)
+    path_to_acc = save_path.joinpath("acc.json")
+    path_to_rec = save_path.joinpath("rec.json")
+    save_meter_to_file(accuracies, path_to_acc)
+    save_meter_to_file(recalls, path_to_rec)
 
     """
diff --git a/utils.py b/utils.py
index 8f689d2..0391b84 100644
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,7 @@
 import logging
 import torch.nn as nn
 from tqdm import tqdm
+from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
 
 import helpers
 from action_executor.actions import search_by_label, create_entity
@@ -55,10 +56,12 @@ def rate(self, step = None):
     def zero_grad(self):
         self.optimizer.zero_grad()
 
+
 # meter class for storing results
 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self):
+    def __init__(self, name="meter"):
+        self.name = name
         self.reset()
 
     def reset(self):
@@ -73,6 +76,7 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count
 
+
 class Predictor(object):
     """Predictor class"""
     def __init__(self, model, vocabs):
@@ -479,7 +483,7 @@ def forward(self, output, target):
 
 
 class SingleTaskAccuracy(nn.Module):
-    '''Single Task Accuracy'''
+    '''Single Task Accuracy (equivalent to "micro"-averaged accuracy)'''
    def __init__(self, device=DEVICE):
         super().__init__()
         self.device = device
@@ -488,14 +492,16 @@ def forward(self, output, target):
         # Assuming outputs and labels are torch tensors.
         # Outputs could be raw logits or probabilities from the last layer of a neural network
         # Convert outputs to predicted class indices if they are not already
-        # TODO: maybe do .to(self.device)
-        preds = output.argmax(dim=1) if output.ndim > 1 else output
+        preds = output.argmax(dim=1) if output.ndim > 1 else output  # get token with max probability
+        # print(f"output ({output.ndim}|{preds.ndim}): {target}|{preds}")
         correct = preds.eq(target).sum()
+
         return correct.float() / target.size(0)
 
 
 class MultiTaskAcc(nn.Module):
-    '''Multi Task Learning Accuracy Calculation'''
+    """Multi Task Learning Accuracy Calculation"""
+
     def __init__(self, device=DEVICE):
         super().__init__()
         self.device = device
@@ -528,8 +534,68 @@ def forward(self, output, target):
         }
 
 
+class MultiTaskAccTorchmetrics(nn.Module):
+    """Multi Task Learning Accuracy Calculation implemented via TorchMetrics."""
+
+    def __init__(self, num_classes: dict, pads: dict = None, device=DEVICE, averaging_type="macro",
+                 module_names=(LOGICAL_FORM, NER, COREF, PREDICATE_POINTER, TYPE_POINTER)):
+        """
+        :param averaging_type: if "micro": equivalent to the MultiTaskAcc class (good for eval)
+                               if "macro": gives equal weight to all classes (good for training, not good for eval)
+                               if "weighted": macro, but classes are weighted by their support (TODO: understand better)
+        """
+        super().__init__()
+        self.module_names = module_names
+        self.multi_acc = {}
+        for name in self.module_names:
+            n_classes = num_classes[name]
+            if pads is not None:
+                ignore_idx = pads[name]
+            else:
+                ignore_idx = None
+            self.multi_acc[name] = MulticlassAccuracy(average=averaging_type, multidim_average='global',
+                                                      num_classes=n_classes, ignore_index=ignore_idx).to(device)
+
+    def forward(self, output, target):
+        # per-task accuracies
+        accs = torch.stack([self.multi_acc[mn](output[mn], target[mn]) for mn in self.module_names])
+
+        results = {mn: accs[i] for i, mn in enumerate(self.module_names)}
+        results[MULTITASK] = accs.mean()
+
+        return results
+
+
+class MultiTaskRecTorchmetrics(nn.Module):
+    """Multi Task Learning Macro-averaged Recall Calculation implemented via TorchMetrics."""
+
+    def __init__(self, num_classes: dict, pads: dict = None, device=DEVICE, averaging_type="macro",
+                 module_names=(LOGICAL_FORM, NER, COREF, PREDICATE_POINTER, TYPE_POINTER)):
+        super().__init__()
+        self.module_names = module_names
+        self.multi_rec = {}
+        for name in self.module_names:
+            n_classes = num_classes[name]
+            if pads is not None:
+                ignore_idx = pads[name]
+            else:
+                ignore_idx = None
+            self.multi_rec[name] = MulticlassRecall(average=averaging_type, multidim_average='global',
+                                                    num_classes=n_classes, ignore_index=ignore_idx).to(device)
+
+    def forward(self, output, target):
+        # per-task recalls
+        recalls = torch.stack([self.multi_rec[mn](output[mn], target[mn]) for mn in self.module_names])
+
+        results = {mn: recalls[i] for i, mn in enumerate(self.module_names)}
+        results[MULTITASK] = recalls.mean()
+        # for mn in self.module_names:
+        #     print(f"{output[mn].shape}|{target[mn].shape}")
+        return results
+
+
 class MultiTaskLoss(nn.Module):
-    '''Multi Task Learning Loss'''
+    """Multi Task Learning Loss"""
     def __init__(self, ignore_index, device=DEVICE):
         super().__init__()
         self.device = device
@@ -567,12 +633,14 @@ def forward(self, output, target):
             MULTITASK: losses.mean()
         }[args.task]
 
+
 def init_weights(model):
     # initialize model parameters with Glorot / fan_avg
     for p in model.parameters():
         if p.dim() > 1:
             nn.init.xavier_uniform_(p)
 
+
 # ANCHOR LASAGNE parameter initialisation
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     """Embedding layer"""
@@ -581,6 +649,7 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     nn.init.constant_(m.weight[padding_idx], 0)
     return m
 
+
 def Linear(in_features, out_features, bias=True):
     """Linear layer"""
     m = nn.Linear(in_features, out_features, bias=bias)
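
For reference, a minimal standalone sketch of the torchmetrics setup that the new `MultiTaskAccTorchmetrics` / `MultiTaskRecTorchmetrics` classes wrap for a single task: micro-averaged accuracy and macro-averaged recall with the task's `[PAD]` index ignored. The class count, pad index, and tensor shapes below are illustrative placeholders, not values taken from the repository.

```python
import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall

# Placeholder vocabulary size and [PAD] index for one task (e.g. NER);
# in the patch these come from the `num_classes` and `pad` dicts built in inference.py.
NUM_CLASSES = 12
PAD_IDX = 0

acc = MulticlassAccuracy(num_classes=NUM_CLASSES, average='micro',
                         multidim_average='global', ignore_index=PAD_IDX)
rec = MulticlassRecall(num_classes=NUM_CLASSES, average='macro',
                       multidim_average='global', ignore_index=PAD_IDX)

# Logits of shape (batch * seq_len, num_classes) and flattened integer targets,
# mirroring how inference.py flattens each batch tensor with .view(-1).
logits = torch.randn(8 * 10, NUM_CLASSES)
targets = torch.randint(0, NUM_CLASSES, (8 * 10,))

print(acc(logits, targets))  # micro-averaged accuracy, [PAD] positions ignored
print(rec(logits, targets))  # macro-averaged recall over classes
```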