refactor: add pre-padding with [START] token to logical_form during training/validation/inference
vejvarm committed Jan 1, 2024
1 parent 87fe44c commit 4445516
Showing 5 changed files with 122 additions and 19 deletions.
90 changes: 84 additions & 6 deletions lab/build_sentences.py → build_sentences.py
@@ -12,7 +12,7 @@
import torch
from tqdm import tqdm

from dataset import CSQADataset, collate_fn
from dataset import CSQADataset, collate_fn, prepad_tensors_with_start_tokens
from model import CARTON
from utils import Predictor, AverageMeter, MultiTaskAcc, MultiTaskAccTorchmetrics, MultiTaskRecTorchmetrics

@@ -25,6 +25,7 @@
# args.seed = 69 # canada, queen victoria, lefty west
# args.seed = 100
args.batch_size = 1
assert args.batch_size == 1, "batch_size must be 1 for building filled logical forms"

# TODO: what would it take to calculate accuracy based on the complete logical form!?

@@ -102,6 +103,51 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
return batch_entities_sentences


def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities):
inp_str = " ".join(inp)

lf = pred_lf
# ner = preds[NER][0]
coref = pred_coref
pp = pred_pp
tp = pred_tp

# for CSQA it works, but we get coref indexing errors for Merged, as one entity label belongs to more than one lf `entity` slot
# TODO: fix this:
# ['entity', 'relation', 'entity', 'insert', 'entity', 'relation', 'entity']
# ['0', '1']
# {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'}
# ['1']
# {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'}
# []
# {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'}

composed_lf = []
ent_keys = sorted([k for k in entities.keys() if k != "NA"], key=lambda x: int(x))
ent_keys_filled = []
if ent_keys:
for i in range(int(ent_keys[-1]) + 1):
if str(i) in ent_keys:
ent_keys_filled.append(str(i))
else:
ent_keys_filled.append(ent_keys[0])
for i, act in enumerate(lf):
if act == "entity":
try:
composed_lf.append(entities[ent_keys_filled.pop(0)])
except IndexError:
# print(f"ent idx: {ent_idx} | {entities}")
composed_lf.append(entities["NA"].pop())
elif act == "relation":
composed_lf.append(pp[i])
elif act == "type":
composed_lf.append(tp[i])
else:
composed_lf.append(act)

return composed_lf

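A minimal usage sketch of the new compose_logical_form helper, with made-up decoded predictions (the slot values and ids below are purely illustrative, not taken from the dataset):

```python
# Hypothetical decoded outputs for a single example (batch_size == 1).
entities = {"NA": [], "0": "japan", "1": "national highway of japan"}
pred_lf = ["find", "entity", "relation", "type"]
pred_pp = ["NA", "NA", "P17", "NA"]      # predicate-pointer labels, aligned with pred_lf
pred_tp = ["NA", "NA", "NA", "Q6256"]    # type-pointer labels, aligned with pred_lf
pred_coref = ["0", "0", "0", "0"]
inp = ["which", "country", "does", "this", "road", "belong", "to"]

composed = compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities)
# composed == ["find", "japan", "P17", "Q6256"]
# "entity" slots are filled from `entities` in ascending coref-key order,
# "relation"/"type" slots from the pointer prediction at the same position.
```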

if __name__ == "__main__":
save_path = ROOT_PATH.joinpath(args.path_inference).joinpath(args.name)
print(f"BATCH SIZE: {args.batch_size}")
@@ -133,23 +179,39 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
with torch.no_grad():
with tqdm(total=total_batches, desc=f'Inference') as pbar:
for i, batch in enumerate(test_loader):
logical_form, predicate_t, type_t = prepad_tensors_with_start_tokens(batch, vocabs, device=DEVICE)

# print(vocabs[LOGICAL_FORM].stoi)
# print(logical_form.shape)
# tg_lf = torch.zeros(logical_form.shape[0], 1, dtype=torch.long).to(DEVICE)
tg_lf = logical_form[:, :1]
# tg_lf = torch.hstack([tg_lf, logical_form[:, :-1]])
# print(tg_lf)
# exit()

# infer predictions from model
output = model(batch.input, batch.logical_form[:, :-1]) # dict
for j in range(logical_form.shape[1] - 1):
output = model(batch.input, tg_lf) # dict

pred = torch.argmax(output[LOGICAL_FORM], dim=1).view(args.batch_size, -1)

tg_lf = torch.hstack([tg_lf, pred[:, j:j+1]])
print(tg_lf)
preds = {
k: torch.argmax(output[k], dim=1).view(args.batch_size, -1) for k in [LOGICAL_FORM, NER,
COREF, PREDICATE_POINTER,
TYPE_POINTER]
}
print(preds[LOGICAL_FORM])
print(preds[LOGICAL_FORM].shape)

# get labels from data
target = {
LOGICAL_FORM: batch.logical_form[:, 1:].contiguous().view(args.batch_size, -1),
LOGICAL_FORM: logical_form[:, 1:].contiguous().view(args.batch_size, -1),
NER: batch.ner.contiguous().view(args.batch_size, -1),
COREF: batch.coref.contiguous().view(args.batch_size, -1),
PREDICATE_POINTER: batch.predicate_pointer[:, 1:].contiguous().view(args.batch_size, -1),
TYPE_POINTER: batch.type_pointer[:, 1:].contiguous().view(args.batch_size, -1),
PREDICATE_POINTER: predicate_t[:, 1:].contiguous().view(args.batch_size, -1),
TYPE_POINTER: type_t[:, 1:].contiguous().view(args.batch_size, -1),
}

# Convert batches of tensors to lists
@@ -166,14 +228,24 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
# k: [[vocabs[k].itos[tok] for tok in preds[k][i] if tok != pad[k]] for i in range(len(t_decoded[k]))] for k in preds.keys()
}

print(t_decoded)
print(preds_decoded)

exit()

batch_results = extract_entities_and_sentences(i_decoded, t_decoded[NER], t_decoded[COREF])

# TODO: what do we do with [PAD] tokens (Remove/keep and mask?) when calculating accuracy?
# find all B-'s ... extract the type_id from there
composed_lfs = []
for b in range(args.batch_size):
entities = batch_results[b]['entities']
sent = batch_results[b]['sent']

composed_lf = compose_logical_form(i_decoded[b], preds_decoded[LOGICAL_FORM][b],
preds_decoded[COREF][b], preds_decoded[PREDICATE_POINTER][b],
preds_decoded[TYPE_POINTER][b], entities)

# make into function >>>
df_inp = pandas.DataFrame.from_dict({"input": i_decoded[b],
"ner (p)": preds_decoded[NER][b],
@@ -197,8 +269,14 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):

# print(f"### input: {re.sub(PUNCTUATION_PATTERN, '', ' '.join(input_decoded).replace(' ##', ''))}")
print(f"### input: {sent}")
# print(preds_decoded[LOGICAL_FORM])
print(t_decoded[LOGICAL_FORM])
print(preds_decoded[LOGICAL_FORM])
# print(entities)
print(composed_lf)
# print(preds_decoded[NER])
# print(t_decoded[NER])
# print(preds_decoded[COREF])
# print(t_decoded[COREF])
# print(entities)
# print(df_inp)
# print(df_out)
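Condensed, the inference loop added above now decodes the logical form greedily, one token at a time, starting from the pre-padded [START] position (the same pattern the TODO in inference.py below alludes to). A minimal sketch with the debug prints stripped out; the final itos decoding line is an assumption about turning ids back into tokens, mirroring the decoding dicts used later in the loop:

```python
# Greedy autoregressive decoding of the logical form (batch_size == 1).
logical_form, predicate_t, type_t = prepad_tensors_with_start_tokens(batch, vocabs, device=DEVICE)

tg_lf = logical_form[:, :1]                          # start from the [START] token only
for j in range(logical_form.shape[1] - 1):
    output = model(batch.input, tg_lf)               # re-run the decoder on the current prefix
    pred = torch.argmax(output[LOGICAL_FORM], dim=1).view(args.batch_size, -1)
    tg_lf = torch.hstack([tg_lf, pred[:, j:j + 1]])  # append the newly predicted token

# assumed id -> token decoding, as in the preds_decoded/t_decoded dicts above
pred_lf_tokens = [vocabs[LOGICAL_FORM].itos[tok] for tok in tg_lf[0, 1:].tolist()]
```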
25 changes: 25 additions & 0 deletions dataset.py
@@ -677,3 +677,28 @@ def update_counters(self, processed_data):
self.counters[ENTITY].update(item[7])

self.id += 1


def prepad_tensors_with_start_tokens(batch, vocabs: dict, device):
lf = batch.logical_form
pp = batch.predicate_pointer
tp = batch.type_pointer

# pad first position of Decoder output with `[START]` token and PP and TP with `NA` token
lf_pad = torch.hstack([torch.full((lf.shape[0], 1), vocabs[LOGICAL_FORM].stoi['[START]']).to(device), lf])
pp_pad = torch.hstack([torch.full((pp.shape[0], 1), vocabs[PREDICATE_POINTER].stoi['NA']).to(device), pp])
tp_pad = torch.hstack([torch.full((tp.shape[0], 1), vocabs[TYPE_POINTER].stoi['NA']).to(device), tp])

return lf_pad, pp_pad, tp_pad


if __name__ == "__main__":
# to check vocabularies
from args import get_parser
parser = get_parser()
args = parser.parse_args()

dataset = CSQADataset(args,
splits=('test',)) # assuming we already have the correct vocab cache from all splits!
data_dict, helper_dict = dataset.preprocess_data()
vocabs = dataset.build_vocabs(args.stream_data)
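A toy illustration of the pre-padding pattern the new helper applies (token ids here are hypothetical; only the shape and alignment matter):

```python
import torch

START_ID, NA_ID = 1, 0                     # hypothetical vocab ids for '[START]' and 'NA'
lf = torch.tensor([[5, 7, 7, 2],           # toy logical-form batch: 2 examples, length 4
                   [5, 8, 2, 3]])
lf_pad = torch.hstack([torch.full((lf.shape[0], 1), START_ID), lf])
# lf_pad == [[1, 5, 7, 7, 2],
#            [1, 5, 8, 2, 3]]
# The predicate/type pointer tensors get an NA_ID column prepended the same way, so
# decoder input lf_pad[:, :-1] and target lf_pad[:, 1:] stay aligned across all tasks.
```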
1 change: 0 additions & 1 deletion helpers.py
@@ -14,7 +14,6 @@
from args import parse_and_get_args
args = parse_and_get_args()


def extract_individual_losses_from_train_log(file_name):
with open(f"{ROOT_PATH}/{args.path_results}/{file_name}", 'r') as f:
col_names = []
12 changes: 7 additions & 5 deletions inference.py
@@ -10,7 +10,7 @@
import torch
from tqdm import tqdm

from dataset import CSQADataset, collate_fn
from dataset import CSQADataset, collate_fn, prepad_tensors_with_start_tokens
from model import CARTON
from utils import Predictor, AverageMeter, MultiTaskAcc, MultiTaskAccTorchmetrics, MultiTaskRecTorchmetrics

@@ -137,15 +137,17 @@ def save_meter_to_file(meter_dict: dict[str: AverageMeter], path_to_file: pathli
Using model to do inference
"""

logical_form, predicate_t, type_t = prepad_tensors_with_start_tokens(batch, vocabs, device=DEVICE)

# compute output
output = model(batch.input, batch.logical_form[:, :-1])
output = model(batch.input, logical_form[:, :-1]) # TODO: we should feed one lf token at a time

target = {
LOGICAL_FORM: batch.logical_form[:, 1:].contiguous().view(-1),
LOGICAL_FORM: logical_form[:, 1:].contiguous().view(-1),
NER: batch.ner.contiguous().view(-1),
COREF: batch.coref.contiguous().view(-1),
PREDICATE_POINTER: batch.predicate_pointer[:, 1:].contiguous().view(-1),
TYPE_POINTER: batch.type_pointer[:, 1:].contiguous().view(-1),
PREDICATE_POINTER: predicate_t[:, 1:].contiguous().view(-1),
TYPE_POINTER: type_t[:, 1:].contiguous().view(-1),
}

accs = acc_calculator(output, target)
13 changes: 6 additions & 7 deletions train.py
Expand Up @@ -8,7 +8,7 @@
from tqdm import tqdm

from model import CARTON
from dataset import CSQADataset, collate_fn
from dataset import CSQADataset, collate_fn, prepad_tensors_with_start_tokens
from torch.utils.data import DataLoader, SequentialSampler, BatchSampler, RandomSampler
from torch.utils.tensorboard import SummaryWriter
from utils import (NoamOpt, AverageMeter, MultiTaskLoss, save_checkpoint, init_weights,
@@ -179,11 +179,11 @@ def train(train_loader, model, vocabs, helper_data, criterion, optimizer, epoch)
for i, batch in enumerate(train_loader):
# get inputs
input = batch.input
logical_form = batch.logical_form
ner = batch.ner
coref = batch.coref
predicate_t = batch.predicate_pointer
type_t = batch.type_pointer

# pad first position of Decoder output with `[START]` token and PP and TP with `NA` token
logical_form, predicate_t, type_t = prepad_tensors_with_start_tokens(batch, vocabs, device=DEVICE)

# compute output
output = model(input, logical_form[:, :-1])
@@ -261,11 +261,10 @@ def validate(val_loader, model, vocabs, helper_data, criterion):
for _, batch in tqdm(enumerate(val_loader), desc="\tvalidation", total=len(val_loader)):
# get inputs
input = batch.input
logical_form = batch.logical_form
ner = batch.ner
coref = batch.coref
predicate_t = batch.predicate_pointer
type_t = batch.type_pointer

logical_form, predicate_t, type_t = prepad_tensors_with_start_tokens(batch, vocabs, device=DEVICE)

# compute output
output = model(input, logical_form[:, :-1])
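Putting the pieces together, one teacher-forced step in train/validate now looks roughly like this (a sketch based on the hunks above and on the updated target dict in inference.py; the loss and metric calls below these hunks are unchanged and omitted):

```python
# Pre-pad, then shift: the decoder consumes [START] + the gold prefix and is scored
# against the gold sequence shifted left by one position.
logical_form, predicate_t, type_t = prepad_tensors_with_start_tokens(batch, vocabs, device=DEVICE)

output = model(batch.input, logical_form[:, :-1])

target = {
    LOGICAL_FORM: logical_form[:, 1:].contiguous().view(-1),
    NER: batch.ner.contiguous().view(-1),
    COREF: batch.coref.contiguous().view(-1),
    PREDICATE_POINTER: predicate_t[:, 1:].contiguous().view(-1),
    TYPE_POINTER: type_t[:, 1:].contiguous().view(-1),
}
```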
