Commit 99398ca

chg: 💄
1 parent 730259d commit 99398ca

12 files changed: +438 additions, -335 deletions

tests/common.py

Lines changed: 7 additions & 9 deletions
@@ -5,13 +5,11 @@
 
 
 def get_path(p):
-    return os.path.join(
-        os.path.dirname(__file__),
-        p
-    )
+    return os.path.join(os.path.dirname(__file__), p)
 
-TEST_CFG = get_path('test_data/test_config.ini')
-TEST_JSONL = get_path('test_data/test_jsonl.jsonl')
-TEST_REFERENCES = get_path('test_data/test_references.txt')
-TEST_TSV_PREDICT = get_path('test_data/test_tsv_predict.tsv')
-TEST_TSV_TRAIN = get_path('test_data/test_tsv_train.tsv')
+
+TEST_CFG = get_path("test_data/test_config.ini")
+TEST_JSONL = get_path("test_data/test_jsonl.jsonl")
+TEST_REFERENCES = get_path("test_data/test_references.txt")
+TEST_TSV_PREDICT = get_path("test_data/test_tsv_predict.tsv")
+TEST_TSV_TRAIN = get_path("test_data/test_tsv_train.tsv")

tests/prodigy/common.py

Lines changed: 7 additions & 9 deletions
@@ -5,13 +5,11 @@
 
 
 def get_path(p):
-    return os.path.join(
-        os.path.dirname(__file__),
-        p
-    )
+    return os.path.join(os.path.dirname(__file__), p)
 
-TEST_TOKENS = get_path('test_data/test_tokens_to_tsv_tokens.jsonl')
-TEST_SPANS = get_path('test_data/test_tokens_to_tsv_spans.jsonl')
-TEST_REF_TOKENS = get_path('test_data/test_reference_to_token_tokens.jsonl')
-TEST_REF_SPANS = get_path('test_data/test_reference_to_token_spans.jsonl')
-TEST_REF_EXPECTED_SPANS = get_path('test_data/test_reference_to_token_expected.jsonl')
+
+TEST_TOKENS = get_path("test_data/test_tokens_to_tsv_tokens.jsonl")
+TEST_SPANS = get_path("test_data/test_tokens_to_tsv_spans.jsonl")
+TEST_REF_TOKENS = get_path("test_data/test_reference_to_token_tokens.jsonl")
+TEST_REF_SPANS = get_path("test_data/test_reference_to_token_spans.jsonl")
+TEST_REF_EXPECTED_SPANS = get_path("test_data/test_reference_to_token_expected.jsonl")

tests/prodigy/test_numbered_reference_annotator.py

Lines changed: 32 additions & 13 deletions
@@ -3,7 +3,10 @@
 
 import pytest
 import spacy
-from deep_reference_parser.prodigy.numbered_reference_annotator import NumberedReferenceAnnotator
+from deep_reference_parser.prodigy.numbered_reference_annotator import (
+    NumberedReferenceAnnotator,
+)
+
 
 @pytest.fixture(scope="function")
 def nra():
@@ -111,20 +114,30 @@ def test_numbered_reference_splitter(nra):
             {"text": "\n", "start": 470, "end": 471, "id": 92},
             {"text": "3", "start": 471, "end": 472, "id": 92},
             {"text": ".", "start": 472, "end": 473, "id": 92},
-        ]
+        ],
     }
 
     docs = list(nra.run([numbered_reference]))
     text = docs[0]["text"]
     spans = docs[0]["spans"]
-    ref_1 = text[spans[0]["start"]:spans[0]["end"]]
-    ref_2 = text[spans[1]["start"]:spans[1]["end"]]
-    ref_3 = text[spans[2]["start"]:spans[2]["end"]]
+    ref_1 = text[spans[0]["start"] : spans[0]["end"]]
+    ref_2 = text[spans[1]["start"] : spans[1]["end"]]
+    ref_3 = text[spans[2]["start"] : spans[2]["end"]]
 
     assert len(spans) == 3
-    assert ref_1 == "Global update on the health sector response to HIV, 2014. Geneva: World Health Organization; \n 2014:168."
-    assert ref_2.strip() == "WHO, UNICEF, UNAIDS. Global update on HIV treatment 2013: results, impact and \n opportunities. Geneva: World Health Organization; 2013:126."
-    assert ref_3.strip() == "Consolidated guidelines on the use of antiretroviral drugs for treating and preventing HIV infection: \n recommendations for a public health approach. Geneva: World Health Organization; 2013:272."
+    assert (
+        ref_1
+        == "Global update on the health sector response to HIV, 2014. Geneva: World Health Organization; \n 2014:168."
+    )
+    assert (
+        ref_2.strip()
+        == "WHO, UNICEF, UNAIDS. Global update on HIV treatment 2013: results, impact and \n opportunities. Geneva: World Health Organization; 2013:126."
+    )
+    assert (
+        ref_3.strip()
+        == "Consolidated guidelines on the use of antiretroviral drugs for treating and preventing HIV infection: \n recommendations for a public health approach. Geneva: World Health Organization; 2013:272."
+    )
+
 
 def test_numbered_reference_splitter_line_endings(nra):
     """
@@ -196,15 +209,21 @@ def test_numbered_reference_splitter_line_endings(nra):
             {"text": "\n\n", "start": 261, "end": 263, "id": 58},
             {"text": "3", "start": 262, "end": 264, "id": 59},
             {"text": ".", "start": 263, "end": 265, "id": 60},
-        ]
+        ],
     }
 
     docs = list(nra.run([numbered_reference]))
    text = docs[0]["text"]
     spans = docs[0]["spans"]
-    ref_1 = text[spans[0]["start"]:spans[0]["end"]]
-    ref_2 = text[spans[1]["start"]:spans[1]["end"]]
+    ref_1 = text[spans[0]["start"] : spans[0]["end"]]
+    ref_2 = text[spans[1]["start"] : spans[1]["end"]]
 
     assert len(spans) == 2
-    assert ref_1.strip() == "Global update on the health sector response to HIV, 2014. Geneva: World Health Organization; \n 2014:168."
-    assert ref_2.strip() == "WHO, UNICEF, UNAIDS. Global update on HIV treatment 2013: results, impact and \n opportunities. Geneva: World Health Organization; 2013:126"
+    assert (
+        ref_1.strip()
+        == "Global update on the health sector response to HIV, 2014. Geneva: World Health Organization; \n 2014:168."
+    )
+    assert (
+        ref_2.strip()
+        == "WHO, UNICEF, UNAIDS. Global update on HIV treatment 2013: results, impact and \n opportunities. Geneva: World Health Organization; 2013:126"
+    )
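The assertions above all rely on the same pattern: each span carries character offsets into the document text, and slicing the text with those offsets recovers one reference per span. A self-contained sketch of that step, with a made-up doc dict standing in for what nra.run() yields in the tests (the text, offsets, and label value here are placeholders, not data from the repository):

# Stand-in for one document produced by the annotator: "spans" holds
# character offsets into "text", one span per numbered reference.
doc = {
    "text": "1. First reference.\n2. Second reference.",
    "spans": [
        {"start": 0, "end": 19, "label": "REFERENCE"},   # placeholder label
        {"start": 20, "end": 40, "label": "REFERENCE"},
    ],
}

# Slice each span out of the text, as the tests do for ref_1, ref_2, ref_3.
references = [doc["text"][s["start"] : s["end"]].strip() for s in doc["spans"]]
assert references == ["1. First reference.", "2. Second reference."]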

tests/prodigy/test_prodigy_to_tsv.py

Lines changed: 0 additions & 5 deletions
@@ -740,9 +740,4 @@ def test_reference_spans_real_example(doc):
 
     import pprint
 
-    #pp = pprint.PrettyPrinter()
-    #pp.pprint(actual)
-    #for token, span in zip(doc["tokens"], doc["spans"]):
-    #    print({token["text"]:span["label"]})
-
     assert actual == expected

tests/prodigy/test_reach_to_prodigy.py

Lines changed: 11 additions & 8 deletions
@@ -4,37 +4,40 @@
 import pytest
 from deep_reference_parser.prodigy.reach_to_prodigy import ReachToProdigy
 
+
 @pytest.fixture(scope="function")
 def stp():
     ref_sections = [{}, {}, {}]
     return ReachToProdigy(ref_sections)
 
+
 def test_combine_n_rows(stp):
 
     doc = list(range(100, 200))
     out = stp.combine_n_rows(doc, n=5, join_char=" ")
 
-    last_in_doc = doc[len(doc) -1]
+    last_in_doc = doc[len(doc) - 1]
     last_in_out = int(out[-1].split(" ")[-1])
 
     assert last_in_doc == last_in_out
 
-    assert out[0] == '100 101 102 103 104'
-    assert out[-2] == '190 191 192 193 194'
-    assert out[-1] == '195 196 197 198 199'
+    assert out[0] == "100 101 102 103 104"
+    assert out[-2] == "190 191 192 193 194"
+    assert out[-1] == "195 196 197 198 199"
+
 
 def test_combine_n_rows_uneven_split(stp):
 
     doc = list(range(100, 200))
     out = stp.combine_n_rows(doc, n=7, join_char=" ")
 
-    last_in_doc = doc[len(doc) -1]
+    last_in_doc = doc[len(doc) - 1]
     last_in_out = int(out[-1].split(" ")[-1])
 
     assert last_in_doc == last_in_out
     assert len(out[-1].split(" ")) == 2
    assert len(out[-2].split(" ")) == 7
 
-    assert out[0] == '100 101 102 103 104 105 106'
-    assert out[-2] == '191 192 193 194 195 196 197'
-    assert out[-1] == '198 199'
+    assert out[0] == "100 101 102 103 104 105 106"
+    assert out[-2] == "191 192 193 194 195 196 197"
+    assert out[-1] == "198 199"
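Taken together, these tests pin down the behaviour expected of combine_n_rows: chunk a list into groups of n items, join each group with join_char, and leave a shorter final group when the length does not divide evenly. A minimal sketch consistent with those assertions (not ReachToProdigy's actual implementation, just the contract the tests describe):

def combine_n_rows(doc, n=5, join_char=" "):
    # Walk the list in steps of n; each chunk becomes one joined string,
    # and the last chunk is shorter when len(doc) % n != 0.
    return [
        join_char.join(str(item) for item in doc[i : i + n])
        for i in range(0, len(doc), n)
    ]

# Matches the uneven-split test above: 100 items in chunks of 7 leaves "198 199".
assert combine_n_rows(list(range(100, 200)), n=7)[-1] == "198 199"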
