From f416767653cfb741a220a0c52042e1c18fd56646 Mon Sep 17 00:00:00 2001 From: "vejvarm@freya" Date: Thu, 11 Jan 2024 10:59:19 +0900 Subject: [PATCH] fix: set model into eval mode during inference --- infer_one.py | 58 ++++++++++++++++++++++++++++++++-------------------- inference.py | 1 + 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/infer_one.py b/infer_one.py index ac8deaa..d91d60e 100644 --- a/infer_one.py +++ b/infer_one.py @@ -95,7 +95,7 @@ def extract_entities_and_sentences(input_batch, ner_batch, coref_batch): return batch_entities_sentences -def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities): +def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities, eid2lab: dict = None, pid2lab: dict = None): inp_str = " ".join(inp) lf = pred_lf @@ -103,17 +103,7 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities): pp = pred_pp tp = pred_tp - # for CSQA it works, but we get coref indexing errors for Merged, as one entity label belongs to more than one lf `entity` slot - # TODO: fix this: - # ['entity', 'relation', 'entity', 'insert', 'entity', 'relation', 'entity'] - # ['0', '1'] - # {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'} - # ['1'] - # {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'} - # [] - # {'NA': [], '0': 'japan national route 415', '1': 'national highway of japan'} - - composed_lf = [] + composed_lf = "" ent_keys = sorted([k for k in entities.keys() if k != "NA"], key=lambda x: int(x)) ent_keys_filled = [] if ent_keys: @@ -125,20 +115,29 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities): for i, act in enumerate(lf): if act == "entity": try: - composed_lf.append(entities[ent_keys_filled.pop(0)]) + composed_lf += entities[ent_keys_filled.pop(0)] except IndexError: # print(f"ent idx: {ent_idx} | {entities}") try: - composed_lf.append(entities["NA"].pop()) + composed_lf += entities["NA"].pop() except IndexError: - print("No more entities to fill in logical form") - composed_lf.append("[UNK]$ENTITY") + # print("No more entities to fill in logical form") + composed_lf += "[UNK]$ENTITY" + composed_lf += ", " elif act == "relation": - composed_lf.append(pp[i]) + if pid2lab is not None: + composed_lf += pid2lab[pp[i]] + else: + composed_lf += pp[i] + composed_lf += ", " elif act == "type": - composed_lf.append(tp[i]) + if eid2lab is not None: + composed_lf += eid2lab[tp[i]] + else: + composed_lf += tp[i] + composed_lf += ", " else: - composed_lf.append(act) + composed_lf += act + "(" return composed_lf @@ -148,10 +147,18 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities): dataset = CSQADataset(args, splits=('test', )) # assuming we already have the correct vocab cache from all splits! vocabs = dataset.build_vocabs(args.stream_data) + # load KG labels + eid2lab_dict = json.load(ROOT_PATH.joinpath("knowledge_graph/items_wikidata_n.json").open()) + pid2lab_dict = json.load(ROOT_PATH.joinpath("knowledge_graph/index_rel_dict.json").open()) + + eid2lab_dict.update({"NA": "[UNK]$TYPE"}) + pid2lab_dict.update({"NA": "[UNK]$RELATION"}) + pad = {k: v.stoi["[PAD]"] for k, v in vocabs.items() if k != "id"} num_classes = {k: len(v) for k, v in vocabs.items() if k != "id"} model = CARTON(vocabs, DEVICE).to(DEVICE) + model.eval() print(f"=> loading checkpoint '{args.model_path}'") checkpoint = torch.load(f'{ROOT_PATH}/{args.model_path}', encoding='latin1', map_location=DEVICE) args.start_epoch = checkpoint['epoch'] @@ -163,8 +170,14 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities): max_lf_len = 10 while True: - # tokenize user input utterance = input("Enter query: ") + if utterance == "exit": + break + if utterance == "": + print("Please enter a sentence or type `exit` to quit.") + continue + + # tokenize user input tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokens = tokenizer(utterance)['input_ids'] tokens = tokenizer.convert_ids_to_tokens(tokens) @@ -198,7 +211,7 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities): k: [[vocabs[k].itos[tok] for tok in sample if tok != pad[k]] for sample in preds[k]] for k in preds.keys() # removing [PAD] tokens } - print(preds_decoded) + # print(preds_decoded) batch_results = extract_entities_and_sentences(i_decoded, preds_decoded[NER], preds_decoded[COREF]) @@ -211,7 +224,7 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities): composed_lf = compose_logical_form(i_decoded[b], preds_decoded[LOGICAL_FORM][b], preds_decoded[COREF][b], preds_decoded[PREDICATE_POINTER][b], - preds_decoded[TYPE_POINTER][b], entities) + preds_decoded[TYPE_POINTER][b], entities, eid2lab_dict, pid2lab_dict) # make into function >>> df_inp = pandas.DataFrame.from_dict({"input": i_decoded[b], @@ -231,5 +244,6 @@ def compose_logical_form(inp, pred_lf, pred_coref, pred_pp, pred_tp, entities): print(f"### input: {sent}") print(preds_decoded[LOGICAL_FORM]) + print(preds_decoded[NER]) print(composed_lf) print("##########################################\n") diff --git a/inference.py b/inference.py index 2811a57..44ecc27 100644 --- a/inference.py +++ b/inference.py @@ -111,6 +111,7 @@ def save_meter_to_file(meter_dict: dict[str: AverageMeter], path_to_file: pathli checkpoint = torch.load(f'{ROOT_PATH}/{args.model_path}', encoding='latin1', map_location=DEVICE) args.start_epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['state_dict']) + model.eval() print(f"=> loaded checkpoint '{args.model_path}' (epoch {checkpoint['epoch']})") # ACCURACY Metric