
Commit

Organize pytest fixture
vmenger committed Oct 10, 2023
1 parent d635516 commit 3d2bd31
Showing 3 changed files with 73 additions and 79 deletions.
30 changes: 30 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,30 @@
import json
import pickle

import pytest

from clin_nlp_metrics import Dataset


@pytest.fixture
def mctrainer_data():
    with open("tests/data/medcattrainer_export.json", "rb") as f:
        return json.load(f)


@pytest.fixture
def mctrainer_dataset(mctrainer_data):
    return Dataset.from_medcattrainer(data=mctrainer_data)


@pytest.fixture
def clinlp_docs():
    with open("tests/data/clinlp_docs.pickle", "rb") as f:
        return pickle.load(f)


@pytest.fixture
def clinlp_dataset(clinlp_docs):
    ids = list(f"doc_{x}" for x in range(0, 15))

    return Dataset.from_clinlp_docs(nlp_docs=clinlp_docs, ids=ids)
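
These fixtures use pytest's default function scope, so the MedCATTrainer export and the pickled clinlp docs are re-read for every test that requests them. If that ever became a bottleneck, a session-scoped variant would be possible; below is a minimal sketch, not part of this commit (the name mctrainer_dataset_session is hypothetical, and it assumes tests treat the shared Dataset as read-only).

import json

import pytest

from clin_nlp_metrics import Dataset


# Hypothetical session-scoped variant: the export is parsed once per test run
# and the resulting Dataset instance is shared by all tests that request it.
@pytest.fixture(scope="session")
def mctrainer_dataset_session():
    with open("tests/data/medcattrainer_export.json", "rb") as f:
        return Dataset.from_medcattrainer(data=json.load(f))
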
99 changes: 42 additions & 57 deletions tests/test_dataset.py
@@ -1,29 +1,8 @@
import json
import pickle

import clinlp # noqa: F401
import pytest

from clin_nlp_metrics.dataset import Annotation, Dataset, Document


@pytest.fixture
def mctrainer_data():
    with open("tests/data/medcattrainer_export.json", "rb") as f:
        return json.load(f)


@pytest.fixture
def dataset(mctrainer_data):
    return Dataset.from_medcattrainer(data=mctrainer_data)


@pytest.fixture
def clinlp_docs():
    with open("tests/data/clinlp_docs.pickle", "rb") as f:
        return pickle.load(f)


class TestAnnotation:
    def test_annotation_nervaluate(self):
        ann = Annotation(text="test", start=0, end=5, label="test")
@@ -243,19 +222,19 @@ def test_dataset_nervaluate(self):
[{"text": "test2", "start": 0, "end": 5, "label": "test2"}],
]

def test_dataset_to_nervaluate_with_filter(self, dataset):
def test_dataset_to_nervaluate_with_filter(self, mctrainer_dataset):
def ann_filter(ann):
return any(not qualifier["is_default"] for qualifier in ann.qualifiers)

to_nervaluate = dataset.to_nervaluate(ann_filter=ann_filter)
to_nervaluate = mctrainer_dataset.to_nervaluate(ann_filter=ann_filter)

assert to_nervaluate[0] == [
{"end": 23, "label": "C0002871_anemie", "start": 17, "text": "anemie"}
]
assert to_nervaluate[1] == []

def test_infer_default_qualifiers(self, dataset):
default_qualifiers = dataset.infer_default_qualifiers()
def test_infer_default_qualifiers(self, mctrainer_dataset):
default_qualifiers = mctrainer_dataset.infer_default_qualifiers()

assert default_qualifiers == {
"Negation": "Affirmed",
@@ -264,71 +243,77 @@ def test_infer_default_qualifiers(self, dataset):
"Plausibility": "Plausible",
}

def test_num_docs(self, dataset):
assert dataset.num_docs() == 14
def test_num_docs(self, mctrainer_dataset):
assert mctrainer_dataset.num_docs() == 14

def test_num_annotations(self, dataset):
assert dataset.num_annotations() == 13
def test_num_annotations(self, mctrainer_dataset):
assert mctrainer_dataset.num_annotations() == 13

def test_span_counts(self, dataset):
assert len(dataset.span_counts()) == 11
def test_span_counts(self, mctrainer_dataset):
assert len(mctrainer_dataset.span_counts()) == 11

def test_span_counts_n_spans(self, dataset):
assert dataset.span_counts(n_spans=3) == {
def test_span_counts_n_spans(self, mctrainer_dataset):
assert mctrainer_dataset.span_counts(n_spans=3) == {
"anemie": 2,
"bloeding": 2,
"prematuriteit": 1,
}

def test_span_counts_callback(self, dataset):
assert dataset.span_counts(n_spans=3, span_callback=lambda x: x.upper()) == {
def test_span_counts_callback(self, mctrainer_dataset):
assert mctrainer_dataset.span_counts(
n_spans=3, span_callback=lambda x: x.upper()
) == {
"ANEMIE": 2,
"BLOEDING": 2,
"PREMATURITEIT": 1,
}

    def test_label_counts(self, dataset):
        assert len(dataset.label_counts()) == 9
    def test_label_counts(self, mctrainer_dataset):
        assert len(mctrainer_dataset.label_counts()) == 9

    def test_label_counts_n_labels(self, dataset):
        assert dataset.label_counts(n_labels=3) == {
    def test_label_counts_n_labels(self, mctrainer_dataset):
        assert mctrainer_dataset.label_counts(n_labels=3) == {
            "C0002871_anemie": 2,
            "C0151526_prematuriteit": 2,
            "C0270191_intraventriculaire_bloeding": 2,
        }

    def test_label_counts_callback(self, dataset):
        assert dataset.label_counts(
    def test_label_counts_callback(self, mctrainer_dataset):
        assert mctrainer_dataset.label_counts(
            n_labels=3, label_callback=lambda x: x[x.index("_") + 1 :]
        ) == {"anemie": 2, "prematuriteit": 2, "intraventriculaire_bloeding": 2}

    def test_qualifier_counts(self, dataset):
        assert dataset.qualifier_counts() == {
    def test_qualifier_counts(self, mctrainer_dataset):
        assert mctrainer_dataset.qualifier_counts() == {
            "Experiencer": {"Patient": 12, "Other": 1},
            "Negation": {"Affirmed": 11, "Negated": 2},
            "Plausibility": {"Plausible": 11, "Hypothetical": 2},
            "Temporality": {"Current": 11, "Historical": 2},
        }

    def test_stats(self, dataset):
        stats = dataset.stats()
    def test_stats(self, mctrainer_dataset):
        stats = mctrainer_dataset.stats()

        assert stats["num_docs"] == dataset.num_docs()
        assert stats["num_annotations"] == dataset.num_annotations()
        assert stats["span_counts"] == dataset.span_counts()
        assert stats["label_counts"] == dataset.label_counts()
        assert stats["qualifier_counts"] == dataset.qualifier_counts()
        assert stats["num_docs"] == mctrainer_dataset.num_docs()
        assert stats["num_annotations"] == mctrainer_dataset.num_annotations()
        assert stats["span_counts"] == mctrainer_dataset.span_counts()
        assert stats["label_counts"] == mctrainer_dataset.label_counts()
        assert stats["qualifier_counts"] == mctrainer_dataset.qualifier_counts()

    def test_stats_with_kwargs(self, dataset):
    def test_stats_with_kwargs(self, mctrainer_dataset):
        n_labels = 1
        span_callback = lambda x: x.upper()  # noqa: E731

        stats = dataset.stats(
        stats = mctrainer_dataset.stats(
            n_labels=n_labels, span_callback=span_callback, unused_argument=None
        )

        assert stats["num_docs"] == dataset.num_docs()
        assert stats["num_annotations"] == dataset.num_annotations()
        assert stats["span_counts"] == dataset.span_counts(span_callback=span_callback)
        assert stats["label_counts"] == dataset.label_counts(n_labels=n_labels)
        assert stats["qualifier_counts"] == dataset.qualifier_counts()
        assert stats["num_docs"] == mctrainer_dataset.num_docs()
        assert stats["num_annotations"] == mctrainer_dataset.num_annotations()
        assert stats["span_counts"] == mctrainer_dataset.span_counts(
            span_callback=span_callback
        )
        assert stats["label_counts"] == mctrainer_dataset.label_counts(
            n_labels=n_labels
        )
        assert stats["qualifier_counts"] == mctrainer_dataset.qualifier_counts()
23 changes: 1 addition & 22 deletions tests/test_metrics.py
@@ -1,27 +1,6 @@
import json
import pickle

import pytest

from clin_nlp_metrics import Dataset, Metrics


@pytest.fixture
def mctrainer_dataset():
    with open("tests/data/medcattrainer_export.json", "rb") as f:
        mctrainer_data = json.load(f)

    return Dataset.from_medcattrainer(data=mctrainer_data)


@pytest.fixture
def clinlp_dataset():
    with open("tests/data/clinlp_docs.pickle", "rb") as f:
        data = pickle.load(f)

    ids = list(f"doc_{x}" for x in range(0, 15))

    return Dataset.from_clinlp_docs(nlp_docs=data, ids=ids)
from clin_nlp_metrics import Metrics


class TestMetrics:
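
With its local fixtures gone, test_metrics.py relies on the shared ones defined in tests/conftest.py: pytest resolves fixtures by parameter name, so no import is needed. A hypothetical sketch of how a TestMetrics method can receive them (the method name and assertions below are illustrative, not taken from the collapsed class body above):

# Hypothetical sketch, not part of this commit: both shared fixtures are
# injected from tests/conftest.py purely by naming them as parameters.
class TestMetrics:
    def test_shared_fixtures_available(self, mctrainer_dataset, clinlp_dataset):
        # Counts for the MedCATTrainer dataset match tests/test_dataset.py.
        assert mctrainer_dataset.num_docs() == 14
        assert mctrainer_dataset.num_annotations() == 13
        assert clinlp_dataset.num_docs() > 0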
