refactor: change specials and UNK tokens in vocabs
vejvarm committed Nov 30, 2023
1 parent af52116 commit 0e6f387
Showing 3 changed files with 15 additions and 79 deletions.
17 changes: 14 additions & 3 deletions dataset.py
@@ -18,7 +18,6 @@
CONTEXT_QUESTION, CONTEXT_ENTITIES, ANSWER, RESULTS, PREV_RESULTS, START_TOKEN, CTX_TOKEN,
UNK_TOKEN, END_TOKEN, INPUT, ID, NER, COREF, PREDICATE_POINTER, TYPE_POINTER, B, I, O)


@dataclass
class DataBatch:
"""
@@ -94,8 +93,11 @@ def collate_fn(batch, vocabs: dict, device: str):


class CSQADataset:
UNK = UNK_TOKEN

def __init__(self, args, splits=('train', 'val', 'test')):
Vocab.UNK = UNK_TOKEN

self.data_path = ROOT_PATH.joinpath(args.data_path)
self.source_paths = {split: self.data_path.joinpath(split) for split in splits}
self.splits = splits
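
For context on the `UNK = UNK_TOKEN` and `Vocab.UNK = UNK_TOKEN` assignments above: setting the class-level UNK before any vocabulary is built makes out-of-vocabulary lookups fall back to the project's own UNK token instead of the library default. A minimal sketch of that fallback behaviour, assuming a torchtext-style `Vocab` whose `stoi` defaults to the unknown index (the `Vocab` internals and the `UNK_TOKEN` value below are illustrative assumptions, not code from this repository):

```python
from collections import defaultdict

UNK_TOKEN = "[UNK]"  # assumed value; the real constant comes from the project's constants module


class Vocab:
    UNK = "<unk>"  # library-style default that the commit overrides

    def __init__(self, tokens, specials):
        # specials occupy the lowest indices, in the order they are listed
        self.itos = list(specials) + sorted(set(tokens) - set(specials))
        unk_index = self.itos.index(self.UNK) if self.UNK in self.itos else 0
        self.stoi = defaultdict(lambda: unk_index,
                                {tok: i for i, tok in enumerate(self.itos)})


# Overriding the class attribute (as __init__ now does) changes which token
# unknown lookups resolve to for every vocabulary built afterwards.
Vocab.UNK = UNK_TOKEN
vocab = Vocab(["ship", "river"], specials=[UNK_TOKEN, "[PAD]"])
print(vocab.stoi["never-seen"])  # -> 0, the index of UNK_TOKEN
```
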
@@ -166,7 +168,7 @@ def _build_vocabs(self, from_cache=False):
specials=[O, PAD_TOKEN],
vocab_cache=self.vocab_cache.joinpath("ner_vocab.pkl"))
vocabs[COREF] = self._build_vocab([item[4] for item in data_aggregate],
specials=['0', PAD_TOKEN],
specials=[NA_TOKEN, PAD_TOKEN],
vocab_cache=self.vocab_cache.joinpath("coref_vocab.pkl"))
vocabs[PREDICATE_POINTER] = self._build_vocab([item[5] for item in data_aggregate],
specials=[NA_TOKEN, PAD_TOKEN],
@@ -175,7 +177,7 @@ def _build_vocabs(self, from_cache=False):
specials=[NA_TOKEN, PAD_TOKEN],
vocab_cache=self.vocab_cache.joinpath("type_vocab.pkl"))
vocabs[ENTITY] = self._build_vocab([item[7] for item in data_aggregate],
specials=[PAD_TOKEN, NA_TOKEN],
specials=[NA_TOKEN, PAD_TOKEN],
vocab_cache=self.vocab_cache.joinpath("ent_vocab.pkl"))

return vocabs
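
The specials reordering above is worth spelling out: vocab implementations typically assign the lowest indices to the specials in the order they are listed, so moving from `[PAD_TOKEN, NA_TOKEN]` to `[NA_TOKEN, PAD_TOKEN]` (and from `'0'` to `NA_TOKEN` in the coref vocab) gives `NA_TOKEN` index 0 consistently across the coref, predicate, type and entity vocabularies. A rough illustration under that assumption (the helper and token values below are placeholders, not the project's `_build_vocab`):

```python
from collections import Counter

NA_TOKEN, PAD_TOKEN = "NA", "[PAD]"  # assumed values, for illustration only


def build_vocab(tokens, specials):
    """Toy stand-in for _build_vocab: specials first, then tokens by frequency."""
    counts = Counter(t for t in tokens if t not in specials)
    itos = list(specials) + [t for t, _ in counts.most_common()]
    return {t: i for i, t in enumerate(itos)}


# Before: specials=[PAD_TOKEN, NA_TOKEN]  -> PAD=0, NA=1
# After:  specials=[NA_TOKEN, PAD_TOKEN]  -> NA=0, PAD=1, matching the other vocabs
old = build_vocab(["Q1", "Q2", "Q1"], specials=[PAD_TOKEN, NA_TOKEN])
new = build_vocab(["Q1", "Q2", "Q1"], specials=[NA_TOKEN, PAD_TOKEN])
print(old[NA_TOKEN], new[NA_TOKEN])  # 1 0
```
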
@@ -190,6 +192,14 @@ def _build_vocabs_streaming(self):
raw_data = json.load(json_file)

processed_data, _ = self._prepare_data([raw_data])
# # !DEBUG >>>
# debug_data = {'raw': raw_data, 'processed': processed_data}
# debug_path = ROOT_PATH.joinpath("debug")
# new_path = debug_path.joinpath(file_path.parent.name).joinpath(file_path.name).with_suffix(".pkl")
# new_path.parent.mkdir(exist_ok=True, parents=True)
# with open(new_path, "wb") as debug_file:
# pickle.dump(debug_data, debug_file)
# # <<<
self.update_counters(processed_data)

# Create and save vocabularies
Expand Down Expand Up @@ -292,6 +302,7 @@ def _prepare_data(self, data):
input.append(context[1])
ner_tag.append(f'{context[-1]}-{context[-2]}' if context[-1] in [B, I] else context[-1])

# TODO: understand this and see if we do this correctly (compare with C)
# coref entities - prepare coref values
action_entities = [action[1] for action in system[GOLD_ACTIONS] if action[0] == ENTITY]
for context in reversed(user['context']):
75 changes: 0 additions & 75 deletions inference.py
@@ -113,15 +113,6 @@ def save_meter_to_file(meter_dict: dict[str: AverageMeter], path_to_file: pathli
model.load_state_dict(checkpoint['state_dict'])
print(f"=> loaded checkpoint '{args.model_path}' (epoch {checkpoint['epoch']})")

# predictor = Predictor(model, vocabs)

# acc_calculator = MultiTaskAcc(DEVICE)
# accuracies = {LOGICAL_FORM: AverageMeter(),
# NER: AverageMeter(),
# COREF: AverageMeter(),
# PREDICATE_POINTER: AverageMeter(),
# TYPE_POINTER: AverageMeter()}

acc_calculator = MultiTaskAccTorchmetrics(num_classes, pads=pad, device=DEVICE, averaging_type='micro') # !we use 'micro' to NOT bloat up classes, which don't have much samples (that would be useful for training)
accuracies = {LOGICAL_FORM: AverageMeter(),
NER: AverageMeter(),
@@ -144,21 +135,8 @@ def save_meter_to_file(meter_dict: dict[str: AverageMeter], path_to_file: pathli
Using model to do inference
"""

# ner = batch.ner
# coref = batch.coref
# predicate_t = batch.predicate_pointer
# type_t = batch.type_pointer

# compute output
output = model(batch.input, batch.logical_form[:, :-1])
# use input and NER to extract entity labels and types
# use KG to look for entities with that label and type

# match found entities with expected entities (accuracy)

# match predicate_pointer output (accuracy)
# match type_pointer output (accuracy)
# match logical_form output (accuracy)

target = {
LOGICAL_FORM: batch.logical_form[:, 1:].contiguous().view(-1),
@@ -176,59 +154,6 @@ def save_meter_to_file(meter_dict: dict[str: AverageMeter], path_to_file: pathli
for name, meter in recalls.items():
meter.update(recs[name])

# # ### DEBUG
# """
# Below are the labels
# """
# # Convert tensors to lists
# input_batch = [[vocabs['input'].itos[tok] for tok in sample if tok != pad['input']] for sample in batch.input]
# ner_batch = [[vocabs['ner'].itos[tok] for tok in sample if tok != pad['ner']] for sample in batch.ner]
# coref_batch = [[vocabs['coref'].itos[tok] for tok in sample if tok != pad['coref']] for sample in batch.coref]
# lf_batch = [[vocabs['logical_form'].itos[tok] for tok in sample if tok != pad['logical_form']] for sample in batch.logical_form]
# pp_batch = [[vocabs['predicate_pointer'].itos[tok] for tok in sample if tok != pad['predicate_pointer']] for sample in batch.predicate_pointer]
# tp_batch = [[vocabs['type_pointer'].itos[tok] for tok in sample if tok != pad['type_pointer']] for sample in batch.type_pointer]
#
# batch_results = extract_entities_and_sentences(input_batch, ner_batch, coref_batch)
#
# # TODO: what do we do with [PAD] tokens (Remove/keep and mask?) when calculating accuracy?
# # find all B-'s ... extract the type_id from there
# entities = batch_results[0]['entities']
# sent = batch_results[0]['sent']
#
# input_decoded = input_batch[0]
# ner = ner_batch[0]
# coref = coref_batch[0]
#
# lf_decoded = lf_batch[0]
# pp_decoded = pp_batch[0]
# tp_decoded = tp_batch[0]
#
# df_inp = pandas.DataFrame.from_dict({"input": input_decoded, "ner": ner, "coref": coref})
# df_out = pandas.DataFrame.from_dict({"lf": lf_decoded, "pp": pp_decoded, "tp": tp_decoded})
#
# csv_path = ROOT_PATH.joinpath("csv")
# csv_path.mkdir(exist_ok=True, parents=True)
# with csv_path.joinpath(f'test_{i}-asent.json').open("w") as f:
# json.dump({'sent': sent, 'entities': entities}, f, indent=4)
# with csv_path.joinpath(f"test_{i}-binp.csv").open("w") as f:
# df_inp.to_csv(f)
# with csv_path.joinpath(f"test_{i}-cout.csv").open("w") as f:
# df_out.to_csv(f)
#
# # print(f"### input: {re.sub(PUNCTUATION_PATTERN, '', ' '.join(input_decoded).replace(' ##', ''))}")
# print(f"### input: {sent}")
# print(entities)
# print(df_inp)
# print(df_out)
# print("##########################################\n")
#
# # in lf
# # fill relation with decoded relation_pointer
# # fill type with decoded type_pointer # NOTE: Insert doesn't use type_pointer
# # fill entities with id=search(label, type) but first order them by coref
# # TODO: \O.o/ dont forget our nice extraction code above
# # ### DEBUG

pbar.set_postfix({'lf': f"{accuracies[LOGICAL_FORM].avg:.4f}|{recalls[LOGICAL_FORM].avg:.4f}",
'ner': f"{accuracies[NER].avg:.4f}|{recalls[NER].avg:.4f}",
'coref': f"{accuracies[COREF].avg:.4f}|{recalls[COREF].avg:.4f}",
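
One line kept in inference.py notes that accuracy is computed with 'micro' averaging so that classes with few samples are not given disproportionate weight in the score. The difference is easiest to see side by side; a small sketch using torchmetrics directly rather than the project's `MultiTaskAccTorchmetrics` wrapper (the toy class counts are made up for illustration):

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

# Toy, heavily imbalanced batch: class 0 dominates, classes 1 and 2 are rare.
target = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])
preds = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0])

micro = MulticlassAccuracy(num_classes=3, average="micro")
macro = MulticlassAccuracy(num_classes=3, average="macro")

# micro: fraction of all predictions that are correct -> 7/10 = 0.70
# macro: mean of per-class accuracies (1.0, 1/3, 0.0) -> ~0.44,
#        i.e. the two rare classes carry as much weight as the frequent one
print(micro(preds, target), macro(preds, target))
```
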
2 changes: 1 addition & 1 deletion lab/get_vocab_examples.py
@@ -14,4 +14,4 @@
vocab = pickle.load(pth.open("rb"), encoding="utf8")
vocab_example_map[pth.stem] = [vocab.itos[i] for i in range(10)]

json.dump(vocab_example_map, vocab_folder.joinpath("examples2.json").open("w"))
json.dump(vocab_example_map, vocab_folder.joinpath("examples.json").open("w"))
