-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_inference_from_model.py
150 lines (134 loc) · 7.68 KB
/
get_inference_from_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import json
import logging
from pprint import pprint
import sys
import torch
from biomedical_bert_ner.utils.utilities import predictions_from_model
from biomedical_bert_ner.utils.utilities import align_predicted_labels_with_original_sentence_tokens
from biomedical_bert_ner.utils.utilities import load_and_cache_examples
from biomedical_bert_ner.utils.utilities import get_labels, convert_to_ents
from biomedical_bert_ner.models.models import *
from transformers import BertTokenizer, BertConfig
# Run inference on the GPU whenever torch can see one; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device Being used as {DEVICE} \n")

# NOTE(review): this configures the *root* logger as an import-time side
# effect: records go to inference_logs.txt (truncated on every import because
# of filemode="w"), and the root level is lowered to DEBUG.
logging.basicConfig(filemode="w", filename="inference_logs.txt")
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
class NERTagger:
    """Load a fine-tuned BERT NER model described by a JSON config and tag sentences.

    The JSON file at ``model_config_path`` is expected to provide the keys read
    by this class: ``final_model_saving_dir``, ``tokenizer_do_lower_case``,
    ``model_type`` (``"crf"``, ``"token_classification"`` or ``"lstm_crf"``),
    ``max_seq_length`` and ``num_special_tokens``, plus the model-type-specific
    keys used in ``_build_model`` (``classification_layer_sizes`` or
    ``lstm_hidden_dim`` / ``num_lstm_layers`` / ``bidirectional``).
    """

    def __init__(self, labels_file, model_config_path, device):
        """Read the config, validate required files, and load tokenizer and model.

        Args:
            labels_file: Path to the labels file, parsed with ``get_labels``.
            model_config_path: Path to the JSON model/inference config.
            device: Device the model is moved to and on which predictions run.

        Raises:
            FileNotFoundError: If the config file, the directory named by
                ``final_model_saving_dir``, or the labels file does not exist.
            ValueError: If the config's ``model_type`` is not supported.
        """
        self.model_config_path = model_config_path
        self.labels_file = labels_file
        self.device = device

        self.model_config_dict = self._load_model_config(self.model_config_path)

        model_dir = self.model_config_dict["final_model_saving_dir"]
        if not os.path.exists(model_dir):
            raise FileNotFoundError(f"model_saving_dir doesn't exist: {model_dir}")
        # os.path.join works whether or not the configured directory ends with
        # a separator; plain "+" concatenation produced e.g.
        # "modelspytorch_model.bin" when the trailing "/" was missing.
        self.model_file = os.path.join(model_dir, "pytorch_model.bin")
        self.config_file = os.path.join(model_dir, "bert_config.json")
        self.vocab_file = os.path.join(model_dir, "vocab.txt")

        if not os.path.exists(self.labels_file):
            raise FileNotFoundError(f"labels_file doesn't exist: {self.labels_file}")
        print("Labels file exist")

        self.bert_config = BertConfig.from_json_file(self.config_file)
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            self.vocab_file,
            config=self.bert_config,
            do_lower_case=self.model_config_dict["tokenizer_do_lower_case"],
        )
        self.labels = get_labels(self.labels_file)
        self.label2idx = {label: idx for idx, label in enumerate(self.labels)}

        self.model = self._build_model()
        self.model.to(self.device)
        # This class only does inference: make sure dropout and similar
        # training-only behavior is disabled.
        self.model.eval()
        print("Model loaded successfully from the config provided.")

    @staticmethod
    def _load_model_config(model_config_path):
        """Return the parsed JSON config at ``model_config_path``.

        Raises:
            FileNotFoundError: If ``model_config_path`` does not exist.
        """
        if not os.path.exists(model_config_path):
            raise FileNotFoundError(f"model_config_path doesn't exist: {model_config_path}")
        with open(model_config_path, "r", encoding="utf-8") as reader:
            return json.load(reader)

    def _build_model(self):
        """Instantiate the model class selected by the config's ``model_type``.

        Raises:
            ValueError: If ``model_type`` is not one of the supported values.
        """
        model_type = self.model_config_dict["model_type"]
        num_labels = len(self.labels)
        if model_type == "crf":
            return BertCrfForNER.from_pretrained(
                self.model_file,
                config=self.bert_config,
                pad_idx=self.bert_tokenizer.pad_token_id,
                sep_idx=self.bert_tokenizer.sep_token_id,
                num_labels=num_labels,
            )
        if model_type == "token_classification":
            return BertForTokenClassification.from_pretrained(
                self.model_file,
                config=self.bert_config,
                num_labels=num_labels,
                classification_layer_sizes=self.model_config_dict["classification_layer_sizes"],
            )
        if model_type == "lstm_crf":
            return BertLstmCrf.from_pretrained(
                self.model_file,
                config=self.bert_config,
                num_labels=num_labels,
                pad_idx=self.bert_tokenizer.pad_token_id,
                lstm_hidden_dim=self.model_config_dict["lstm_hidden_dim"],
                num_lstm_layers=self.model_config_dict["num_lstm_layers"],
                bidirectional=self.model_config_dict["bidirectional"],
            )
        # Previously an unknown type fell through and failed later with an
        # AttributeError on self.model; fail here with an explicit message.
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of "
            "'crf', 'token_classification', 'lstm_crf'."
        )

    def tag_sentences(self, sentence_list, logger, batch_size):
        """Run the loaded model over ``sentence_list`` and return entity results.

        Args:
            sentence_list: Sentences to tag, passed to ``load_and_cache_examples``
                in inference mode. The sample inputs in this file are
                whitespace-separated tokens; presumably that is the expected
                format (TODO confirm against ``load_and_cache_examples``).
            logger: Logger forwarded to ``load_and_cache_examples``.
            batch_size: Batch size used by ``predictions_from_model``.

        Returns:
            A list with one ``convert_to_ents(example.words, label_tags)``
            result per example produced by ``load_and_cache_examples``.
        """
        max_seq_length = self.model_config_dict["max_seq_length"]
        dataset, examples, features = load_and_cache_examples(
            max_seq_length=max_seq_length,
            tokenizer=self.bert_tokenizer,
            label_map=self.label2idx,
            # The id of "O" is used as the pad label id; this assumes the
            # labels file contains an "O" label (KeyError otherwise).
            pad_token_label_id=self.label2idx["O"],
            mode="inference",
            data_dir=None,
            logger=logger,
            sentence_list=sentence_list,
            return_features_and_examples=True,
        )
        label_predictions = predictions_from_model(
            model=self.model,
            tokenizer=self.bert_tokenizer,
            dataset=dataset,
            batch_size=batch_size,
            label2idx=self.label2idx,
            device=self.device,
        )
        # Map sub-token predictions back onto the original sentence tokens.
        aligned_predicted_labels, _ = align_predicted_labels_with_original_sentence_tokens(
            label_predictions,
            examples,
            features,
            max_seq_length=max_seq_length,
            num_special_tokens=self.model_config_dict["num_special_tokens"],
        )
        return [
            convert_to_ents(example.words, label_tags)
            for label_tags, example in zip(aligned_predicted_labels, examples)
        ]
if __name__ == "__main__":
    import argparse

    # Command-line overrides for the paths and batch size this demo used to
    # hard-code; the defaults are the original values, so running the script
    # without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Tag a fixed set of demo sentences with a trained NER model."
    )
    parser.add_argument(
        "--labels-file",
        default="/media/rabbit/Work_data/all_nlp_datasets/bio_ner_datasets/jnlpba/labels_file.txt",
        help="Path to the labels file passed to NERTagger.",
    )
    parser.add_argument(
        "--model-config",
        default="configs/crf_ner_config.json",
        help="Path to the JSON model/inference config.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size used when running the model.",
    )
    args = parser.parse_args()

    # Demo inputs: mostly biomedical sentences, plus one out-of-domain sentence
    # ("I love data science").
    sentence_list = [
        "Number of glucocorticoid receptors in lymphocytes and their sensitivity to hormone action .",
        "The study demonstrated a decreased level of glucocorticoid receptors ( GR ) in peripheral blood lymphocytes from hypercholesterolemic subjects , and an elevated level in patients with acute myocardial infarction .",
        "In the lymphocytes with a high GR number , dexamethasone inhibited [ 3H ] -thymidine and [ 3H ] -acetate incorporation into DNA and cholesterol , respectively , in the same manner as in the control cells .",
        "On the other hand , a decreased GR number resulted in a less efficient dexamethasone inhibition of the incorporation of labeled compounds .",
        "These data showed that the sensitivity of lymphocytes to glucocorticoids changed only with a decrease of GR level .",
        "Treatment with I-hydroxyvitamin D3 ( 1-1.5 mg daily , within 4 weeks ) led to normalization of total and ionized form of Ca2+ and of 25 ( OH ) D , but did not affect the PTH content in blood .",
        "The data obtained suggest that under conditions of glomerulonephritis only high content of receptors to 1.25 ( OH ) 2D3 in lymphocytes enabled to perform the cell response to the hormone effect .",
        "To investigate whether the tumor expression of beta-2-microglobulin ( beta 2-M ) could serve as a marker of tumor biologic behavior , the authors studied specimens of breast carcinomas from 60 consecutive female patients .",
        "Presence of beta 2-M was analyzed by immunohistochemistry .",
        "I love data science",
        "Humira showed better results than Cimzia for treating psoriasis .",
        "Important advancements in the treatment of non - small cell lung cancer (NSCLC) have been achieved over the past two decades, increasing our understanding of the disease biology and mechanisms of tumour progression, and advancing early detection and multimodal care .",
        "The use of small molecule tyrosine kinase inhibitors and immunotherapy has led to unprecedented survival benefits in selected patients .",
        "However, the overall cure and survival rates for NSCLC remain low, particularly in metastatic disease .",
        "Therefore, continued research into new drugs and combination therapies is required to expand the clinical benefit to a broader patient population and to improve outcomes in NSCLC .",
        "The non-small cell lung cancer immune contexture. A major determinant of tumor characteristics and patient outcome .",
    ]
    tagger = NERTagger(
        labels_file=args.labels_file,
        model_config_path=args.model_config,
        device=DEVICE,
    )
    pprint(tagger.tag_sentences(sentence_list, logger=logger, batch_size=args.batch_size))