FIX: NER predict return
flaviussn committed May 9, 2020
1 parent fce538d commit 236ded9
Showing 3 changed files with 91 additions and 8 deletions.
README.md (2 changes: 1 addition, 1 deletion)
@@ -732,7 +732,7 @@ Args:

Returns:
* preds: A Python list of lists with dicts containing each word mapped to its NER tag.
- * model_outputs: A python list of the raw model outputs for each text.
+ * model_outputs: A Python list of lists with dicts containing each word mapped to its list of raw model outputs.


**`train(self, train_dataset, output_dir)`**
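For reference, a minimal sketch of the two return values described above, for one three-word sentence and a three-label model (the numbers are illustrative, not real model output):

    preds = [
        [{"Some": "O"}, {"arbitrary": "O"}, {"sentence": "O"}]
    ]
    model_outputs = [
        [
            {"Some": [[2.1, -0.8, -1.0]]},  # one logit vector per sub-token
            {"arbitrary": [[1.9, -0.5, -0.9], [1.7, -0.6, -0.8]]},  # two sub-tokens
            {"sentence": [[2.0, -0.7, -1.1]]},
        ]
    ]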
examples/named_entity_recognition/named_entity_recognition.py (16 changes: 15 additions, 1 deletion)
@@ -1,5 +1,7 @@
import pandas as pd
from simpletransformers.ner import NERModel
import numpy as np
from scipy.special import softmax

# Creating train_df and eval_df for demonstration
train_data = [
@@ -45,7 +47,19 @@
# Evaluate the model
result, model_outputs, predictions = model.eval_model(eval_df)


# Predictions on arbitrary text strings
-predictions, raw_outputs = model.predict(["Some arbitary sentence"])
+sentences = ["Some arbitrary sentence", "Simple Transformers sentence"]
+predictions, raw_outputs = model.predict(sentences)

print(predictions)

# More detailed predictions
for n, (preds, outs) in enumerate(zip(predictions, raw_outputs)):
    print("\n___________________________")
    print("Sentence: ", sentences[n])
    for pred, out in zip(preds, outs):
        key = list(pred.keys())[0]
        new_out = out[key]
        # Average the logits over the word's sub-tokens, then softmax into probabilities
        probs = list(softmax(np.mean(new_out, axis=0)))
        print(key, pred[key], probs[np.argmax(probs)], probs)
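The per-word confidence idea in the loop above can also be packaged as a small helper; a sketch under the assumption that raw_outputs has the structure predict now returns (to_word_probabilities is a hypothetical name, not a library function):

    import numpy as np
    from scipy.special import softmax

    def to_word_probabilities(raw_outputs):
        # For each word, average the logits over its sub-tokens and softmax them
        results = []
        for sentence_outs in raw_outputs:
            sentence_probs = []
            for word_out in sentence_outs:
                word, logits = next(iter(word_out.items()))
                sentence_probs.append({word: softmax(np.mean(logits, axis=0))})
            results.append(sentence_probs)
        return results

    word_probs = to_word_probabilities(raw_outputs)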
simpletransformers/ner/ner_model.py (81 changes: 75 additions, 6 deletions)
@@ -681,7 +681,7 @@ def predict(self, to_predict, split_on_space=True):
    Returns:
        preds: A Python list of lists with dicts containing each word mapped to its NER tag.
-       model_outputs: A python list of the raw model outputs for each text.
+       model_outputs: A Python list of lists with dicts containing each word mapped to its list of raw model outputs.
    """

    device = self.device
@@ -693,7 +693,9 @@

    if split_on_space:
        predict_examples = [
-           InputExample(i, sentence.split(), [self.labels[0] for word in sentence.split()])
+           InputExample(
+               i, sentence.split(), [self.labels[0] for word in sentence.split()]
+           )
            for i, sentence in enumerate(to_predict)
        ]
    else:
@@ -705,7 +707,9 @@
    eval_dataset = self.load_and_cache_examples(None, to_predict=predict_examples)

    eval_sampler = SequentialSampler(eval_dataset)
-   eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args["eval_batch_size"])
+   eval_dataloader = DataLoader(
+       eval_dataset, sampler=eval_sampler, batch_size=args["eval_batch_size"]
+   )

    eval_loss = 0.0
    nb_eval_steps = 0
@@ -735,12 +739,24 @@
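        # Accumulate logits, label ids, input ids, and attention masks across
        # batches so the raw model outputs can be mapped back to words after the loop.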
        if preds is None:
            preds = logits.detach().cpu().numpy()
            out_label_ids = inputs["labels"].detach().cpu().numpy()
            out_input_ids = inputs["input_ids"].detach().cpu().numpy()
            out_attention_mask = inputs["attention_mask"].detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
-           out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
+           out_label_ids = np.append(
+               out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0
+           )
            out_input_ids = np.append(
                out_input_ids, inputs["input_ids"].detach().cpu().numpy(), axis=0
            )
            out_attention_mask = np.append(
                out_attention_mask,
                inputs["attention_mask"].detach().cpu().numpy(),
                axis=0,
            )

    eval_loss = eval_loss / nb_eval_steps
-   model_outputs = preds
+   token_logits = preds
    preds = np.argmax(preds, axis=2)

    label_map = {i: label for i, label in enumerate(self.labels)}
@@ -755,12 +771,65 @@
                preds_list[i].append(label_map[preds[i][j]])

    preds = [
-       [{word: preds_list[i][j]} for j, word in enumerate(sentence.split()[: len(preds_list[i])])]
+       [
+           {word: preds_list[i][j]}
+           for j, word in enumerate(sentence.split()[: len(preds_list[i])])
+       ]
        for i, sentence in enumerate(to_predict)
    ]

    # Map the accumulated sub-token logits back onto the original words,
    # one list of logit vectors per word
    word_tokens = []
    for n, sentence in enumerate(to_predict):
        w_log = self._convert_tokens_to_word_logits(
            out_input_ids[n],
            out_label_ids[n],
            out_attention_mask[n],
            token_logits[n],
        )
        word_tokens.append(w_log)

    model_outputs = [
        [
            {word: word_tokens[i][j]}
            for j, word in enumerate(sentence.split()[: len(preds_list[i])])
        ]
        for i, sentence in enumerate(to_predict)
    ]
    return preds, model_outputs
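Both comprehensions above iterate the same truncated word list, so preds and model_outputs stay aligned word for word. A quick sanity sketch of how a caller might rely on that (assuming model is an initialized NERModel; not part of the library):

    predictions, raw_outputs = model.predict(["Some arbitrary sentence"])
    for sent_preds, sent_outs in zip(predictions, raw_outputs):
        for tag_dict, logit_dict in zip(sent_preds, sent_outs):
            assert list(tag_dict) == list(logit_dict)  # identical word keys, same order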

def _convert_tokens_to_word_logits(
self, input_ids, label_ids, attention_mask, logits
):

    ignore_ids = [
        self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token),
        self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token),
        self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token),
    ]

    # Remove padding and special-token (pad/sep/cls) positions
    masked_ids = input_ids[attention_mask == 1]
    masked_labels = label_ids[attention_mask == 1]
    masked_logits = logits[attention_mask == 1]
    for ignore_id in ignore_ids:
        masked_labels = masked_labels[masked_ids != ignore_id]
        masked_logits = masked_logits[masked_ids != ignore_id]
        masked_ids = masked_ids[masked_ids != ignore_id]

    # Group sub-token logits into per-word lists: any label other than
    # pad_token_label_id marks the first sub-token of a new word
    word_logits = []
    tmp = []
    for n, lab in enumerate(masked_labels):
        if lab != self.pad_token_label_id:
            if n != 0:
                word_logits.append(tmp)
            tmp = [list(masked_logits[n])]
        else:
            tmp.append(list(masked_logits[n]))
    word_logits.append(tmp)

    return word_logits
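A toy walk-through of that grouping rule, as a standalone sketch with made-up numbers (in the library, the labels and logits come from the tokenizer and model):

    PAD = -100  # stands in for self.pad_token_label_id
    masked_labels = [0, PAD, 1]  # "hello" -> 2 sub-tokens, "world" -> 1
    masked_logits = [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]]

    word_logits, tmp = [], []
    for n, lab in enumerate(masked_labels):
        if lab != PAD:  # first sub-token of a word
            if n != 0:
                word_logits.append(tmp)
            tmp = [list(masked_logits[n])]
        else:  # continuation sub-token joins the current word
            tmp.append(list(masked_logits[n]))
    word_logits.append(tmp)

    print(word_logits)  # [[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7]]]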

def load_and_cache_examples(self, data, evaluate=False, no_cache=False, to_predict=None):
"""
Reads data_file and generates a TensorDataset containing InputFeatures. Caches the InputFeatures.
