Commit

Merge branch '188-seqeval-label-remap' into 'main'
Resolve "Allow remapping of label when computing NER metrics with SeqEval"

Closes #188

See merge request heka/medkit!234

changelog: Resolve "Allow remapping of label when computing NER metrics with SeqEval"
ghisvail committed Nov 23, 2023
2 parents bd6a928 + af01377 commit 9f9f50f
Showing 3 changed files with 51 additions and 1 deletion.
1 change: 1 addition & 0 deletions medkit/text/__init__.py
@@ -1,5 +1,6 @@
__all__ = [
"context",
"metrics",
"ner",
"postprocessing",
"preprocessing",
10 changes: 9 additions & 1 deletion medkit/text/metrics/ner.py
@@ -83,6 +83,7 @@ def __init__(
return_metrics_by_label: bool = True,
average: Literal["macro", "weighted"] = "macro",
tokenizer: Optional[Any] = None,
labels_remapping: Optional[Dict[str, str]] = None,
):
"""
Parameters
@@ -96,15 +97,20 @@ def __init__(
Type of average to be performed in metrics.
- `macro`, unweighted mean (default)
- `weighted`, weighted average by support (number of true instances by label)
tokenizer:
Optional Fast Tokenizer to convert text into tokens.
If not provided, the text is tokenized by character.
labels_remapping:
Optional remapping of labels, useful when there is a mismatch
between the predicted labels and the reference labels to evaluate
against. If the label of a reference or predicted entity is found in
this dict, the corresponding value will be used as the label instead.
"""
self.tokenizer = tokenizer
self.tagging_scheme = tagging_scheme
self.return_metrics_by_label = return_metrics_by_label
self.average = average
self.labels_remapping = labels_remapping

def compute(
self, documents: List[TextDocument], predicted_entities: List[List[Entity]]
@@ -160,6 +166,8 @@ def _tag_text_with_entities(self, text: str, entities: List[Entity]):
tags = ["O"] * len(text)
for ent in entities:
label = ent.label
if self.labels_remapping:
label = self.labels_remapping.get(label, label)
ent_spans = span_utils.normalize_spans(ent.spans)
# skip if all spans were ModifiedSpans and we are
# not able to refer back to text
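
For context, here is a minimal usage sketch of the new labels_remapping option. The import paths, the TextDocument construction, and the doc.anns.add calls are assumptions about the broader medkit API rather than part of this commit; the SeqEvalEvaluator and compute signatures follow the diff and tests below. Reference entities carry long labels while the model predicts abbreviated ones, and the remapping aligns them before tagging and scoring.

# Minimal sketch: align abbreviated predicted labels with the reference labels.
# Imports and document construction are assumed; evaluator usage mirrors the tests below.
from medkit.core.text import Entity, Span, TextDocument
from medkit.text.metrics.ner import SeqEvalEvaluator

# Reference document whose gold entities carry the long labels
doc = TextDocument(text="medkit is a python library")
doc.anns.add(Entity(label="corporation", spans=[Span(start=0, end=6)], text="medkit"))
doc.anns.add(Entity(label="language", spans=[Span(start=12, end=18)], text="python"))

# Predictions from a model that emits abbreviated labels at the same spans
predicted = [
    Entity(label="CORP", spans=[Span(start=0, end=6)], text="medkit"),
    Entity(label="LANG", spans=[Span(start=12, end=18)], text="python"),
]

# Remap the abbreviated labels onto the reference labels before scoring
evaluator = SeqEvalEvaluator(
    labels_remapping={"CORP": "corporation", "LANG": "language"},
    return_metrics_by_label=False,
)
metrics = evaluator.compute(documents=[doc], predicted_entities=[predicted])
print(metrics["macro_f1-score"])  # expected to be 1.0 once the labels are aligned
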
41 changes: 41 additions & 0 deletions tests/unit/text/metrics/test_seqeval_evaluator.py
@@ -202,3 +202,44 @@ def test_modified_spans():
"support": 1,
"accuracy": 0.7,
}


def test_labels_remapping(document):
# identical to reference entities but with abbreviated labels
predicted_entities = [
Entity(label="CORP", spans=[Span(start=0, end=6)], text="medkit"),
Entity(label="LANG", spans=[Span(start=12, end=18)], text="python"),
]

expected_metrics = {
"macro_precision": 1.0,
"macro_recall": 1.0,
"macro_f1-score": 1.0,
"support": 2,
"accuracy": 1.0,
}

# remap only predicted entities
evaluator = SeqEvalEvaluator(
labels_remapping={"CORP": "corporation", "LANG": "language"},
return_metrics_by_label=False,
)
metrics = evaluator.compute(
documents=[document], predicted_entities=[predicted_entities]
)
assert metrics == expected_metrics

# remap all entities (predicted and reference) to unique label
evaluator = SeqEvalEvaluator(
labels_remapping={
"CORP": "ent",
"LANG": "ent",
"corporation": "ent",
"language": "ent",
},
return_metrics_by_label=False,
)
metrics = evaluator.compute(
documents=[document], predicted_entities=[predicted_entities]
)
assert metrics == expected_metrics
