feat: calculate accuracy and recall on test set in inference.py
vejvarm committed Nov 22, 2023
1 parent aeb9036 commit ae2d315
Showing 4 changed files with 234 additions and 102 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -56,6 +56,15 @@ For training you will need to adjust the paths in [args](args.py) file. At the s
python train.py
```

## Inference Framework
Calculates accuracy and recall on the test split
- accuracy averaging: 'micro'
- recall averaging: 'macro'
```bash
python inference.py --name csqa --batch-size 50 --data-path data/final/csqa --model-path experiments/models/CARTONNER_csqa_e10_v0.0102_multitask.pth.tar
```
This will save the metric results as JSON files into the `ROOT_PATH/args.path_inference/args.name` folder.
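
For intuition, here is a minimal sketch of what the two averaging modes mean, using `torchmetrics` directly (the repo wraps these in `MultiTaskAccTorchmetrics`/`MultiTaskRecTorchmetrics` from `utils`; the toy tensors below are illustrative only):
```python
# Illustrative sketch of 'micro' vs. 'macro' averaging (not code from this repo).
import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall

preds = torch.tensor([0, 0, 0, 1, 2])   # toy predictions over 3 classes
target = torch.tensor([0, 0, 1, 1, 2])  # toy gold labels

micro_acc = MulticlassAccuracy(num_classes=3, average='micro')
macro_rec = MulticlassRecall(num_classes=3, average='macro')

print(micro_acc(preds, target))  # 4/5 = 0.80: every token counts equally
print(macro_rec(preds, target))  # (1.0 + 0.5 + 1.0) / 3 = 0.83: every class counts equally
```
'micro' pools all tokens before dividing, so frequent classes dominate the score; 'macro' averages the per-class scores, so rare classes weigh in equally.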

## Generate Actions
After the model has finished training, we perform inference in two steps.
First, we generate the actions and save them in a JSON file using the trained model.
6 changes: 3 additions & 3 deletions args.py
@@ -26,7 +26,7 @@ def get_parser():
parser.add_argument('--snapshots', default='experiments/snapshots', type=str)
parser.add_argument('--path_results', default='experiments/results', type=str)
parser.add_argument('--path_error_analysis', default='experiments/error_analysis', type=str)
parser.add_argument('--path_inference', default='experiments/inference', type=str)
parser.add_argument('--path-inference', default='experiments/inference', type=str)

# task
parser.add_argument('--task', default=Task.MULTITASK.value, choices=[tsk.value for tsk in Task], type=str)
@@ -54,11 +54,11 @@ def get_parser():
parser.add_argument('--valfreq', default=1, type=int)
parser.add_argument('--resume', default='', type=str)
parser.add_argument('--clip', default=5, type=int)
parser.add_argument('--batch_size', default=25, type=int) # NOTE: changed from 25
parser.add_argument('--batch-size', default=25, type=int) # NOTE: changed from 25
parser.add_argument('--pool_size', default=100, type=int)

# test and inference
parser.add_argument('--model-path', default='experiments/models/CARTONNER_merged_e10_v0.0153_multitask.pth.tar',
parser.add_argument('--model-path', default='experiments/models/CARTONNER_csqa_e10_v0.0102_multitask.pth.tar',
type=str)
parser.add_argument('--file_path', default='/data/final/csqa/process/test.json', type=str)
parser.add_argument('--inference_partition', default=InferencePartition.TEST.value,
240 changes: 147 additions & 93 deletions inference.py
@@ -1,3 +1,5 @@
import json
import pathlib
import random
import re
import numpy as np
@@ -10,7 +12,7 @@

from dataset import CSQADataset, collate_fn
from model import CARTON
from utils import Predictor, AverageMeter, MultiTaskAcc
from utils import Predictor, AverageMeter, MultiTaskAcc, MultiTaskAccTorchmetrics, MultiTaskRecTorchmetrics

from constants import DEVICE, LOGICAL_FORM, COREF, NER, INPUT, PREDICATE_POINTER, TYPE_POINTER, ROOT_PATH
from args import get_parser
@@ -19,13 +21,11 @@
args = parser.parse_args()

# args.seed = 69 # canada, queen victoria, lefty west
args.seed = 100
args.batch_size = 1
# args.seed = 100
# args.batch_size = 1

# TODO: figure out what I just did :D
# TODO: is the accuracy being calculated correctly? (most certainly not)
# TODO: what would it take to calculate accuracy based on the complete logical form!?
# TODO: save accuracies to files

random.seed(args.seed)
np.random.seed(args.seed)
@@ -44,6 +44,9 @@


def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):

# TODO: fix `extract_entities_and_sentences` (each token immediately following an entity is cut off)

batch_entities_sentences = []
for input_decoded, ner, coref in zip(input_batch, ner_batch, coref_batch):
entities = {"NA": []}
@@ -81,7 +84,18 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
return batch_entities_sentences
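
# NOTE (assumption, inferred from the debug code further below): each element of
# `batch_entities_sentences` appears to be a dict of the form
#   {"sent": <detokenized utterance>, "entities": {"NA": [...], <coref id>: [<entity span>, ...]}}
# with entity spans grouped by their coreference index and "NA" collecting spans
# that have no coreference link.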


def save_meter_to_file(meter_dict: dict[str, AverageMeter], path_to_file: pathlib.Path):
results = {nm: metric.avg.cpu().tolist() for nm, metric in meter_dict.items()}
results["average"] = np.mean(list(results.values()))

with path_to_file.open("w") as f:
json.dump(results, f, indent=4)
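
# Minimal usage sketch (hypothetical values; assumes each AverageMeter.avg holds a
# 0-dim torch tensor, as produced by the metric calculators below):
#   meter = AverageMeter(); meter.update(torch.tensor(0.98))
#   save_meter_to_file({LOGICAL_FORM: meter}, pathlib.Path("acc.json"))
#   # -> e.g. {"logical_form": 0.98, "average": 0.98}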


if __name__ == "__main__":
save_path = ROOT_PATH.joinpath(args.path_inference).joinpath(args.name)
print(f"results will be saved to `{save_path}`.")

# load data
dataset = CSQADataset(args, splits=('test', )) # assuming we already have the correct vocab cache from all splits!
vocabs = dataset.get_vocabs()
@@ -95,6 +109,7 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
total_batches = (len(test_loader.dataset) + args.batch_size - 1) // args.batch_size

pad = {k: v.stoi["[PAD]"] for k, v in vocabs.items() if k != "id"}
num_classes = {k: len(v) for k, v in vocabs.items() if k != "id"}

model = CARTON(vocabs, DEVICE).to(DEVICE)
print(f"=> loading checkpoint '{args.model_path}'")
@@ -105,101 +120,140 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):

predictor = Predictor(model, vocabs)

acc_calculator = MultiTaskAcc(DEVICE)


# acc_calculator = MultiTaskAcc(DEVICE)
# accuracies = {LOGICAL_FORM: AverageMeter(),
# NER: AverageMeter(),
# COREF: AverageMeter(),
# PREDICATE_POINTER: AverageMeter(),
# TYPE_POINTER: AverageMeter()}

acc_calculator = MultiTaskAccTorchmetrics(num_classes, pads=pad, device=DEVICE, averaging_type='micro')  # NOTE: 'micro' averaging, so sparsely populated classes do not skew the score ('macro' would be more informative during training)
accuracies = {LOGICAL_FORM: AverageMeter(),
NER: AverageMeter(),
COREF: AverageMeter(),
PREDICATE_POINTER: AverageMeter(),
TYPE_POINTER: AverageMeter()}

rec_calculator = MultiTaskRecTorchmetrics(num_classes, pads=pad, device=DEVICE)
recalls = {LOGICAL_FORM: AverageMeter(),
NER: AverageMeter(),
COREF: AverageMeter(),
PREDICATE_POINTER: AverageMeter(),
TYPE_POINTER: AverageMeter()}
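# NOTE: no averaging_type is passed here, so MultiTaskRecTorchmetrics presumably
# falls back to the 'macro' recall averaging documented in the README.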

# for i, data in random.sample(test_loader, 5):
with tqdm(total=total_batches, desc=f'Inference') as pbar:
for i, batch in enumerate(test_loader):
"""
Using model to do inference
"""
# TODO

# ner = batch.ner
# coref = batch.coref
# predicate_t = batch.predicate_pointer
# type_t = batch.type_pointer

# compute output
output = model(batch.input, batch.logical_form[:, :-1])
# use input and NER to extract entity labels and types
# use KG to look for entities with that label and type

# match found entities with expected entities (accuracy)

# match predicate_pointer output (accuracy)
# match type_pointer output (accuracy)
# match logical_form output (accuracy)

target = {
LOGICAL_FORM: batch.logical_form[:, 1:].contiguous().view(-1),
NER: batch.ner.contiguous().view(-1),
COREF: batch.coref.contiguous().view(-1),
PREDICATE_POINTER: batch.predicate_pointer[:, 1:].contiguous().view(-1),
TYPE_POINTER: batch.type_pointer[:, 1:].contiguous().view(-1),
}

accs = acc_calculator(output, target)

for name, meter in accuracies.items():
meter.update(accs[name])

# """
# Below are the labels
# """
# # Convert tensors to lists
# input_batch = [[vocabs['input'].itos[tok] for tok in sample if tok != pad['input']] for sample in batch.input]
# ner_batch = [[vocabs['ner'].itos[tok] for tok in sample if tok != pad['ner']] for sample in batch.ner]
# coref_batch = [[vocabs['coref'].itos[tok] for tok in sample if tok != pad['coref']] for sample in batch.coref]
# lf_batch = [[vocabs['logical_form'].itos[tok] for tok in sample if tok != pad['logical_form']] for sample in batch.logical_form]
# pp_batch = [[vocabs['predicate_pointer'].itos[tok] for tok in sample if tok != pad['predicate_pointer']] for sample in batch.predicate_pointer]
# tp_batch = [[vocabs['type_pointer'].itos[tok] for tok in sample if tok != pad['type_pointer']] for sample in batch.type_pointer]
#
# batch_results = extract_entities_and_sentences(input_batch, ner_batch, coref_batch)
#
# # TODO: what do we do with [PAD] tokens (Remove/keep and mask?) when calculating accuracy?
#
# #### DEBUG
# # find all B-'s ... extract the type_id from there
# entities = batch_results[0]['entities']
# sent = batch_results[0]['sent']
#
# input_decoded = input_batch[0]
# ner = ner_batch[0]
# coref = coref_batch[0]
#
# lf_decoded = lf_batch[0]
# pp_decoded = pp_batch[0]
# tp_decoded = tp_batch[0]
#
# df_inp = pandas.DataFrame.from_dict({"input": input_decoded, "ner": ner, "coref": coref})
# df_out = pandas.DataFrame.from_dict({"lf": lf_decoded, "pp": pp_decoded, "tp": tp_decoded})
#
# # print(f"### input: {re.sub(PUNCTUATION_PATTERN, '', ' '.join(input_decoded).replace(' ##', ''))}")
# print(f"### input: {sent}")
# print(entities)
# print(df_inp)
# print(df_out)
# print("##########################################\n")
# ###
# #### DEBUG
#
# # in lf
# # fill relation with decoded relation_pointer
# # fill type with decoded type_pointer # NOTE: Insert doesn't use type_pointer at all

pbar.set_postfix({'lf': accuracies[LOGICAL_FORM].avg, 'ner': accuracies[NER].avg,
'coref': accuracies[COREF].avg, 'pp': accuracies[PREDICATE_POINTER].avg,
'tp': accuracies[TYPE_POINTER].avg})
pbar.update(1)

# if i >= 5:
# break
with torch.no_grad():
with tqdm(total=total_batches, desc=f'Inference') as pbar:
for i, batch in enumerate(test_loader):
"""
Using model to do inference
"""

# ner = batch.ner
# coref = batch.coref
# predicate_t = batch.predicate_pointer
# type_t = batch.type_pointer

# compute output
output = model(batch.input, batch.logical_form[:, :-1])
# use input and NER to extract entity labels and types
# use KG to look for entities with that label and type

# match found entities with expected entities (accuracy)

# match predicate_pointer output (accuracy)
# match type_pointer output (accuracy)
# match logical_form output (accuracy)

target = {
LOGICAL_FORM: batch.logical_form[:, 1:].contiguous().view(-1),
NER: batch.ner.contiguous().view(-1),
COREF: batch.coref.contiguous().view(-1),
PREDICATE_POINTER: batch.predicate_pointer[:, 1:].contiguous().view(-1),
TYPE_POINTER: batch.type_pointer[:, 1:].contiguous().view(-1),
}
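# NOTE (assumption): the [:, 1:] shift on these targets mirrors the [:, :-1] shift
# on the decoder input above, i.e. standard teacher forcing, where each position is
# scored against the next gold token; NER and COREF align with the input tokens and
# therefore stay unshifted.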

accs = acc_calculator(output, target)
for name, meter in accuracies.items():
meter.update(accs[name])

recs = rec_calculator(output, target)
for name, meter in recalls.items():
meter.update(recs[name])

# # ### DEBUG
# """
# Below are the labels
# """
# # Convert tensors to lists
# input_batch = [[vocabs['input'].itos[tok] for tok in sample if tok != pad['input']] for sample in batch.input]
# ner_batch = [[vocabs['ner'].itos[tok] for tok in sample if tok != pad['ner']] for sample in batch.ner]
# coref_batch = [[vocabs['coref'].itos[tok] for tok in sample if tok != pad['coref']] for sample in batch.coref]
# lf_batch = [[vocabs['logical_form'].itos[tok] for tok in sample if tok != pad['logical_form']] for sample in batch.logical_form]
# pp_batch = [[vocabs['predicate_pointer'].itos[tok] for tok in sample if tok != pad['predicate_pointer']] for sample in batch.predicate_pointer]
# tp_batch = [[vocabs['type_pointer'].itos[tok] for tok in sample if tok != pad['type_pointer']] for sample in batch.type_pointer]
#
# batch_results = extract_entities_and_sentences(input_batch, ner_batch, coref_batch)
#
# # TODO: what do we do with [PAD] tokens (Remove/keep and mask?) when calculating accuracy?
# # find all B-'s ... extract the type_id from there
# entities = batch_results[0]['entities']
# sent = batch_results[0]['sent']
#
# input_decoded = input_batch[0]
# ner = ner_batch[0]
# coref = coref_batch[0]
#
# lf_decoded = lf_batch[0]
# pp_decoded = pp_batch[0]
# tp_decoded = tp_batch[0]
#
# df_inp = pandas.DataFrame.from_dict({"input": input_decoded, "ner": ner, "coref": coref})
# df_out = pandas.DataFrame.from_dict({"lf": lf_decoded, "pp": pp_decoded, "tp": tp_decoded})
#
# csv_path = ROOT_PATH.joinpath("csv")
# csv_path.mkdir(exist_ok=True, parents=True)
# with csv_path.joinpath(f'test_{i}-asent.json').open("w") as f:
# json.dump({'sent': sent, 'entities': entities}, f, indent=4)
# with csv_path.joinpath(f"test_{i}-binp.csv").open("w") as f:
# df_inp.to_csv(f)
# with csv_path.joinpath(f"test_{i}-cout.csv").open("w") as f:
# df_out.to_csv(f)
#
# # print(f"### input: {re.sub(PUNCTUATION_PATTERN, '', ' '.join(input_decoded).replace(' ##', ''))}")
# print(f"### input: {sent}")
# print(entities)
# print(df_inp)
# print(df_out)
# print("##########################################\n")
#
# # in lf
# # fill relation with decoded relation_pointer
# # fill type with decoded type_pointer # NOTE: Insert doesn't use type_pointer
# # fill entities with id=search(label, type) but first order them by coref
# # TODO: \O.o/ dont forget our nice extraction code above
# # ### DEBUG

pbar.set_postfix({'lf': f"{accuracies[LOGICAL_FORM].avg:.4f}|{recalls[LOGICAL_FORM].avg:.4f}",
'ner': f"{accuracies[NER].avg:.4f}|{recalls[NER].avg:.4f}",
'coref': f"{accuracies[COREF].avg:.4f}|{recalls[COREF].avg:.4f}",
'pp': f"{accuracies[PREDICATE_POINTER].avg:.4f}|{recalls[PREDICATE_POINTER].avg:.4f}",
'tp': f"{accuracies[TYPE_POINTER].avg:.4f}|{recalls[TYPE_POINTER].avg:.4f}"})
pbar.update(1)

# break

# if i >= 5:
# break

# save metric results
save_path.mkdir(exist_ok=True, parents=True)
path_to_acc = save_path.joinpath("acc.json")
path_to_rec = save_path.joinpath("rec.json")
save_meter_to_file(accuracies, path_to_acc)
save_meter_to_file(recalls, path_to_rec)
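# With the README example (--name csqa) and the default --path-inference, these
# land at experiments/inference/csqa/acc.json and experiments/inference/csqa/rec.json.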


"""