Skip to content

Commit

Permalink
fix: set model into eval mode during inference
Browse files — browse the repository at this point in the history
  • Loading branch information
vejvarm committed Jan 11, 2024
1 parent 7d4e6f3 commit f416767
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
58 changes: 36 additions & 22 deletions infer_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,15 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch):
return batch_entities_sentences


def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities):
def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities, eid2lab: dict = None, pid2lab: dict = None):
inp_str = " ".join(inp)

lf = pred_lf
coref = pred_coref
pp = pred_pp
tp = pred_tp

# for CSQA it works, but we get coref indexing errors for Merged, as one entity label belongs to more than one lf `entity` slot
# TODO: fix this:
# ['entity', 'relation', 'entity', 'insert', 'entity', 'relation', 'entity']
# ['0', '1']
# {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'}
# ['1']
# {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'}
# []
# {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'}

composed_lf = []
composed_lf = ""
ent_keys = sorted([k for k in entities.keys() if k != "NA"], key=lambda x: int(x))
ent_keys_filled = []
if ent_keys:
Expand All @@ -125,20 +115,29 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities):
for i, act in enumerate(lf):
if act == "entity":
try:
composed_lf.append(entities[ent_keys_filled.pop(0)])
composed_lf += entities[ent_keys_filled.pop(0)]
except IndexError:
# print(f"ent idx: {ent_idx} | {entities}")
try:
composed_lf.append(entities["NA"].pop())
composed_lf += entities["NA"].pop()
except IndexError:
print("No more entities to fill in logical form")
composed_lf.append("[UNK]$ENTITY")
# print("No more entities to fill in logical form")
composed_lf += "[UNK]$ENTITY"
composed_lf += ", "
elif act == "relation":
composed_lf.append(pp[i])
if pid2lab is not None:
composed_lf += pid2lab[pp[i]]
else:
composed_lf += pp[i]
composed_lf += ", "
elif act == "type":
composed_lf.append(tp[i])
if eid2lab is not None:
composed_lf += eid2lab[tp[i]]
else:
composed_lf += tp[i]
composed_lf += ", "
else:
composed_lf.append(act)
composed_lf += act + "("

return composed_lf

Expand All @@ -148,10 +147,18 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities):
dataset = CSQADataset(args, splits=('test', )) # assuming we already have the correct vocab cache from all splits!
vocabs = dataset.build_vocabs(args.stream_data)

# load KG labels
eid2lab_dict = json.load(ROOT_PATH.joinpath("knowledge_graph/items_wikidata_n.json").open())
pid2lab_dict = json.load(ROOT_PATH.joinpath("knowledge_graph/index_rel_dict.json").open())

eid2lab_dict.update({"NA": "[UNK]$TYPE"})
pid2lab_dict.update({"NA": "[UNK]$RELATION"})

pad = {k: v.stoi["[PAD]"] for k, v in vocabs.items() if k != "id"}
num_classes = {k: len(v) for k, v in vocabs.items() if k != "id"}

model = CARTON(vocabs, DEVICE).to(DEVICE)
model.eval()
print(f"=> loading checkpoint '{args.model_path}'")
checkpoint = torch.load(f'{ROOT_PATH}/{args.model_path}', encoding='latin1', map_location=DEVICE)
args.start_epoch = checkpoint['epoch']
Expand All @@ -163,8 +170,14 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities):

max_lf_len = 10
while True:
# tokenize user input
utterance = input("Enter query: ")
if utterance == "exit":
break
if utterance == "":
print("Please enter a sentence or type `exit` to quit.")
continue

# tokenize user input
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens = tokenizer(utterance)['input_ids']
tokens = tokenizer.convert_ids_to_tokens(tokens)
Expand Down Expand Up @@ -198,7 +211,7 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities):
k: [[vocabs[k].itos[tok] for tok in sample if tok != pad[k]] for sample in preds[k]] for k in preds.keys() # removing [PAD] tokens
}

print(preds_decoded)
# print(preds_decoded)

batch_results = extract_entities_and_sentences(i_decoded, preds_decoded[NER], preds_decoded[COREF])

Expand All @@ -211,7 +224,7 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities):

composed_lf = compose_logical_form(i_decoded[b], preds_decoded[LOGICAL_FORM][b],
preds_decoded[COREF][b], preds_decoded[PREDICATE_POINTER][b],
preds_decoded[TYPE_POINTER][b], entities)
preds_decoded[TYPE_POINTER][b], entities, eid2lab_dict, pid2lab_dict)

# make into function >>>
df_inp = pandas.DataFrame.from_dict({"input": i_decoded[b],
Expand All @@ -231,5 +244,6 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities):

print(f"### input: {sent}")
print(preds_decoded[LOGICAL_FORM])
print(preds_decoded[NER])
print(composed_lf)
print("##########################################\n")
1 change: 1 addition & 0 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def save_meter_to_file(meter_dict: dict[str: AverageMeter], path_to_file: pathli
checkpoint = torch.load(f'{ROOT_PATH}/{args.model_path}', encoding='latin1', map_location=DEVICE)
args.start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])
model.eval()
print(f"=> loaded checkpoint '{args.model_path}' (epoch {checkpoint['epoch']})")

# ACCURACY Metric
Expand Down

0 comments on commit f416767

Please sign in to comment.