38 changes: 33 additions & 5 deletions deep_reference_parser/prodigy/prodigy_to_tsv.py
@@ -12,17 +12,20 @@
import numpy as np
import plac

from ..io import read_jsonl
from wasabi import Printer

from ..io import read_jsonl
from ..logger import logger

msg = Printer()


class TokenLabelPairs:
"""
Convert prodigy format docs or list of lists into tuples of (token, label).
"""

def __init__(self, line_limit=73, respect_line_endings=True, respect_doc_endings=True):
def __init__(self, line_limit=250, respect_line_endings=False, respect_doc_endings=True):
"""
Args:
line_limit(int): Maximum number of tokens allowed per training
@@ -191,19 +194,45 @@ def yield_token_label_pair(self, doc, lists=False):
"positional",
None,
str
),
respect_lines=(
"Respect line endings? Or parse entire document in a single string?",
"flag",
"r",
bool
),
respect_docs=(
"Respect doc endings or parse corpus in single string?",
"flag",
"d",
bool
),
line_limit=(
"Number of characters to include on a line",
"option",
"l",
int
)
)
def prodigy_to_tsv(input_file, output_file):
def prodigy_to_tsv(input_file, output_file, respect_lines, respect_docs, line_limit=250):
"""
Convert token-annotated JSONL to token-annotated TSV, ready for use in
the Rodrigues model.
"""

msg.info(f"Respect line endings: {respect_lines}")
msg.info(f"Respect doc endings: {respect_docs}")
msg.info(f"Line limit: {line_limit}")

annotated_data = read_jsonl(input_file)

logger.info("Loaded %s prodigy docs", len(annotated_data))

tlp = TokenLabelPairs()
tlp = TokenLabelPairs(
respect_doc_endings=respect_docs,
respect_line_endings=respect_lines,
line_limit=line_limit
)
token_label_pairs = list(tlp.run(annotated_data))

with open(output_file, 'w') as fb:
@@ -214,4 +243,3 @@ def prodigy_to_tsv(input_file, output_file):

logger.info("Wrote %s token/label pairs to %s", len(token_label_pairs),
output_file)
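
For context, a minimal sketch of how the updated entry point might be called directly from Python, mirroring the new signature above (the file paths are hypothetical, and the keyword names come straight from the diff):

```python
# Minimal sketch (hypothetical paths); mirrors the new prodigy_to_tsv signature.
from deep_reference_parser.prodigy.prodigy_to_tsv import prodigy_to_tsv

prodigy_to_tsv(
    input_file="annotations.jsonl",  # token-annotated prodigy export
    output_file="train.tsv",         # TSV destination for the Rodrigues model
    respect_lines=False,             # parse each document as a single string
    respect_docs=True,               # keep end-of-document markers
    line_limit=250,                  # new default: max tokens per training line
)
```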

41 changes: 32 additions & 9 deletions deep_reference_parser/prodigy/reference_to_token_annotations.py
@@ -10,7 +10,7 @@


class TokenTagger:
def __init__(self, task="splitting", lowercase=True):
def __init__(self, task="splitting", lowercase=True, text=True):
"""
Converts data in prodigy format with full reference spans to per-token
spans
@@ -20,6 +20,8 @@ def __init__(self, task="splitting", lowercase=True):
explanation.
lowercase (bool): Automatically convert upper case annotations to
lowercase under the parsing scenario.
text (bool): Include the token text in the output span (very useful
for debugging).

Since the parsing, splitting, and classification tasks have quite
different labelling requirements, this class behaves differently
@@ -48,6 +50,7 @@ def __init__(self, task="splitting", lowercase=True):
self.out = []
self.task = task
self.lowercase = lowercase
self.text = text

def tag_doc(self, doc):
"""
@@ -177,19 +180,33 @@ def create_span(self, tokens, index, label):
"label": label,
}

if self.text:
span["text"] = token["text"]

return span

def split_long_span(self, tokens, span, start_label, end_label, inside_label):
"""
Split a multi-token span into `n` spans of length `1`, where `n=len(tokens)`
"""

spans = []
spans.append(self.create_span(tokens, span["token_start"], start_label))
spans.append(self.create_span(tokens, span["token_end"], end_label))
Review thread on the two lines above:

Contributor: ah, so was this line the cause of the bug in the case where span_size = 0, then?

Author: yep, precisely! It was a case that had never come up before, because references were always longer than one token.

(A standalone sketch of the fixed logic follows this file's diff.)

start = span["token_start"]
end = span["token_end"]

span_size = end - start

# Case when there is only one token in the span
if span_size == 0:
spans.append(self.create_span(tokens, start, start_label))
# Case when there are two or more tokens in the span
else:
spans.append(self.create_span(tokens, start, start_label))
spans.append(self.create_span(tokens, end, end_label))

for index in range(span["token_start"] + 1, span["token_end"]):
spans.append(self.create_span(tokens, index, inside_label))
if span_size > 1:

for index in range(start + 1, end):
spans.append(self.create_span(tokens, index, inside_label))

spans = sorted(spans, key=lambda k: k["token_start"])

@@ -221,9 +238,15 @@ def split_long_span(self, tokens, span, start_label, end_label, inside_label):
"f",
bool,
),
text=(
"Output the token text in the span (useful for debugging).",
"flag",
"t",
bool,
),
)
def reference_to_token_annotations(
input_file, output_file, task="splitting", lowercase=False
input_file, output_file, task="splitting", lowercase=False, text=False
):
"""
Creates a span for every token from existing multi-token spans
@@ -268,10 +291,10 @@ def reference_to_token_annotations(
"Loaded %s documents with no reference annotations", len(not_annotated_docs)
)

annotator = TokenTagger(task=task, lowercase=lowercase)
annotator = TokenTagger(task=task, lowercase=lowercase, text=text)

token_annotated_docs = annotator.run(ref_annotated_docs)
all_docs = token_annotated_docs + token_annotated_docs
all_docs = token_annotated_docs + not_annotated_docs

write_jsonl(all_docs, output_file=output_file)

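To make the one-token bug from the review thread concrete, here is a standalone sketch of the fixed splitting logic. The bare-dict spans are simplified stand-ins for what `TokenTagger.create_span` produces, and the label names (`b-r`/`i-r`/`e-r`) are illustrative:

```python
# Standalone sketch of the fixed split_long_span logic; span dicts are
# simplified stand-ins for the output of TokenTagger.create_span.
def split_long_span(span, start_label, end_label, inside_label):
    start, end = span["token_start"], span["token_end"]
    span_size = end - start
    spans = []

    if span_size == 0:
        # One-token reference: emit only the start label. The old code
        # appended both a start and an end span for the same token.
        spans.append({"token_start": start, "label": start_label})
    else:
        spans.append({"token_start": start, "label": start_label})
        spans.append({"token_start": end, "label": end_label})
        for index in range(start + 1, end):
            spans.append({"token_start": index, "label": inside_label})

    return sorted(spans, key=lambda k: k["token_start"])

# Before the fix, a one-token span yielded both b-r and e-r on the same
# token; now it yields a single b-r span:
print(split_long_span({"token_start": 7, "token_end": 7}, "b-r", "e-r", "i-r"))
# -> [{'token_start': 7, 'label': 'b-r'}]
```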
16 changes: 7 additions & 9 deletions tests/common.py
@@ -5,13 +5,11 @@


def get_path(p):
return os.path.join(
os.path.dirname(__file__),
p
)
return os.path.join(os.path.dirname(__file__), p)

TEST_CFG = get_path('test_data/test_config.ini')
TEST_JSONL = get_path('test_data/test_jsonl.jsonl')
TEST_REFERENCES = get_path('test_data/test_references.txt')
TEST_TSV_PREDICT = get_path('test_data/test_tsv_predict.tsv')
TEST_TSV_TRAIN = get_path('test_data/test_tsv_train.tsv')

TEST_CFG = get_path("test_data/test_config.ini")
TEST_JSONL = get_path("test_data/test_jsonl.jsonl")
TEST_REFERENCES = get_path("test_data/test_references.txt")
TEST_TSV_PREDICT = get_path("test_data/test_tsv_predict.tsv")
TEST_TSV_TRAIN = get_path("test_data/test_tsv_train.tsv")
15 changes: 15 additions & 0 deletions tests/prodigy/common.py
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
# coding: utf-8

import os


def get_path(p):
return os.path.join(os.path.dirname(__file__), p)


TEST_TOKENS = get_path("test_data/test_tokens_to_tsv_tokens.jsonl")
TEST_SPANS = get_path("test_data/test_tokens_to_tsv_spans.jsonl")
TEST_REF_TOKENS = get_path("test_data/test_reference_to_token_tokens.jsonl")
TEST_REF_SPANS = get_path("test_data/test_reference_to_token_spans.jsonl")
TEST_REF_EXPECTED_SPANS = get_path("test_data/test_reference_to_token_expected.jsonl")