
Refactor Token and Sentence Positional Properties #3001

Merged
merged 5 commits into from Nov 27, 2022
57 changes: 28 additions & 29 deletions flair/data.py
@@ -504,8 +504,7 @@ def __init__(
self.head_id: Optional[int] = head_id
self.whitespace_after: int = whitespace_after

- self.start_pos = start_position
- self.end_pos = start_position + len(text)
+ self._start_position = start_position

self._embeddings: Dict = {}
self.tags_proba_dist: Dict[str, List[Label]] = {}
@@ -518,7 +517,7 @@ def idx(self) -> int:
raise ValueError

@property
- def text(self):
+ def text(self) -> str:
return self.form

@property
@@ -538,11 +537,15 @@ def get_head(self):

@property
def start_position(self) -> int:
- return self.start_pos
+ return self._start_position

+ @start_position.setter
+ def start_position(self, value: int) -> None:
+ self._start_position = value

@property
def end_position(self) -> int:
- return self.end_pos
+ return self.start_position + len(self.text)

@property
def embedding(self):
@@ -709,8 +712,7 @@ def __init__(

self.language_code: Optional[str] = language_code

- self.start_pos = start_position
- self.end_pos = start_position + len(text)
+ self._start_position = start_position

# the tokenizer used for this sentence
if isinstance(use_tokenizer, Tokenizer):
@@ -730,30 +732,22 @@
words = tokenizer.tokenize(text)
else:
words = text
+ text = " ".join(words)

# determine token positions and whitespace_after flag
- current_offset = 0
- previous_word_offset = -1
- previous_token = None
+ current_offset: int = 0
+ previous_token: Optional[Token] = None
for word in words:
- try:
- word_offset = text.index(word, current_offset)
- start_position = word_offset
- delta_offset = start_position - current_offset
- except ValueError:
- word_offset = previous_word_offset + 1
- start_position = current_offset + 1 if current_offset > 0 else current_offset
- delta_offset = start_position - current_offset

- if word:
- token = Token(text=word, start_position=start_position)
- self.add_token(token)
+ word_start_position: int = text.index(word, current_offset)
+ delta_offset: int = word_start_position - current_offset

+ token: Token = Token(text=word, start_position=word_start_position)
+ self.add_token(token)

if previous_token is not None:
previous_token.whitespace_after = delta_offset

- current_offset = word_offset + len(word)
- previous_word_offset = current_offset - 1
+ current_offset = token.end_position
previous_token = token

# the last token has no whitespace after
@@ -815,8 +809,7 @@ def add_token(self, token: Union[Token, str]):
token.sentence = self
token._internal_index = len(self.tokens) + 1
if token.start_position == 0 and len(self) > 0:
- token.start_pos = len(self.to_original_text()) + self[-1].whitespace_after
- token.end_pos = token.start_pos + len(token.text)
+ token.start_position = len(self.to_original_text()) + self[-1].whitespace_after

# append token to sentence
self.tokens.append(token)
@@ -958,7 +951,7 @@ def to_original_text(self) -> str:
if len(self) == 0:
return ""
# otherwise, return concatenation of tokens with the correct offsets
- return self[0].start_pos * " " + "".join([t.text + t.whitespace_after * " " for t in self.tokens]).strip()
+ return self[0].start_position * " " + "".join([t.text + t.whitespace_after * " " for t in self.tokens]).strip()

def to_dict(self, tag_type: str = None):
labels = []
@@ -1001,11 +994,17 @@ def __repr__(self):

@property
def start_position(self) -> int:
- return 0
+ return self._start_position

+ @start_position.setter
+ def start_position(self, value: int) -> None:
+ self._start_position = value

@property
def end_position(self) -> int:
- return len(self.to_original_text())
+ # The sentence's start position is not propagated to its tokens.
+ # Therefore, we need to add the sentence's start position to its last token's end position, including whitespaces.
+ return self.start_position + self[-1].end_position + self[-1].whitespace_after

def get_language_code(self) -> str:
if self.language_code is None:
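
A minimal usage sketch of the refactored positional API, based on the new tests in tests/test_sentence.py further down (end positions are now derived from start positions rather than stored):

from flair.data import Sentence

# Token offsets are relative to the sentence text; end_position is computed
# as start_position + len(text) instead of being stored as a separate field.
sentence = Sentence("This is a sentence.", start_position=10)
assert (sentence[0].start_position, sentence[0].end_position) == (0, 4)

# A sentence's end position adds its own start offset to the last token's
# end position plus any trailing whitespace.
assert sentence.start_position == 10
assert sentence.end_position == 29  # 10 + len("This is a sentence.")

# start_position is now settable on both Token and Sentence.
sentence[0].start_position = 5
assert sentence[0].end_position == 9  # 5 + len("This")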
6 changes: 3 additions & 3 deletions flair/datasets/biomedical.py
@@ -384,9 +384,9 @@ def write_to_conll(self, dataset: InternalBioNerDataset, output_file: Path):

for flair_token in sentence.tokens:
token = flair_token.text.strip()
- assert sentence.start_pos is not None
- assert flair_token.start_pos is not None
- offset = sentence.start_pos + flair_token.start_pos
+ assert sentence.start_position is not None
+ assert flair_token.start_position is not None
+ offset = sentence.start_position + flair_token.start_position

if current_entity and offset >= current_entity.char_span.stop:
in_entity = False
31 changes: 19 additions & 12 deletions flair/datasets/entity_linking.py
@@ -138,7 +138,7 @@ def __init__(

# sentence splitting and tokenization
sentences = sentence_splitter.split(string)
- sentence_offsets = [sentence.start_pos or 0 for sentence in sentences]
+ sentence_offsets = [sentence.start_position or 0 for sentence in sentences]

# iterate through all annotations and add to corresponding tokens
for mention_start, mention_length, wikiname in zip(indices, lengths, wikinames):
@@ -158,10 +158,10 @@ def __init__(
# set annotation for tokens of entity mention
first = True
for token in sentences[sentence_index].tokens:
- assert token.start_pos is not None
- assert token.end_pos is not None
+ assert token.start_position is not None
+ assert token.end_position is not None
if (
- token.start_pos >= mention_start and token.end_pos <= mention_end
+ token.start_position >= mention_start and token.end_position <= mention_end
): # token belongs to entity mention
if first:
token.set_label(typename="nel", value="B-" + wikiname)
@@ -603,7 +603,7 @@ def __init__(

# split sentences and tokenize
sentences = sentence_splitter.split(text)
- sentence_offsets = [sentence.start_pos or 0 for sentence in sentences]
+ sentence_offsets = [sentence.start_position or 0 for sentence in sentences]

# iterate through all annotations and add to corresponding tokens
for elem in root:
@@ -631,10 +631,10 @@ def __init__(
# set annotation for tokens of entity mention
first = True
for token in sentences[sentence_index].tokens:
- assert token.start_pos is not None
- assert token.end_pos is not None
+ assert token.start_position is not None
+ assert token.end_position is not None
if (
- token.start_pos >= mention_start and token.end_pos <= mention_end
+ token.start_position >= mention_start and token.end_position <= mention_end
): # token belongs to entity mention
assert elem[1].text is not None
if first:
@@ -928,18 +928,25 @@ def _text_to_cols(self, sentence: Sentence, links: list, outfile):
if links:
# Keep track which is the correct corresponding entity link, in cases where there is >1 link in a sentence
link_index = [
- j for j, v in enumerate(links) if (sentence[i].start_pos >= v[0] and sentence[i].end_pos <= v[1])
+ j
+ for j, v in enumerate(links)
+ if (sentence[i].start_position >= v[0] and sentence[i].end_position <= v[1])
]
# Write the token with a corresponding tag to file
try:
- if any(sentence[i].start_pos == v[0] and sentence[i].end_pos == v[1] for j, v in enumerate(links)):
+ if any(
+ sentence[i].start_position == v[0] and sentence[i].end_position == v[1]
+ for j, v in enumerate(links)
+ ):
outfile.writelines(sentence[i].text + "\tS-" + links[link_index[0]][2] + "\n")
elif any(
- sentence[i].start_pos == v[0] and sentence[i].end_pos != v[1] for j, v in enumerate(links)
+ sentence[i].start_position == v[0] and sentence[i].end_position != v[1]
+ for j, v in enumerate(links)
):
outfile.writelines(sentence[i].text + "\tB-" + links[link_index[0]][2] + "\n")
elif any(
- sentence[i].start_pos >= v[0] and sentence[i].end_pos <= v[1] for j, v in enumerate(links)
+ sentence[i].start_position >= v[0] and sentence[i].end_position <= v[1]
+ for j, v in enumerate(links)
):
outfile.writelines(sentence[i].text + "\tI-" + links[link_index[0]][2] + "\n")
else:
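
The mention-to-token alignment in this loader reduces to comparing character spans against each token's start_position and end_position; a rough sketch of that pattern (the sentence text, mention span, and wikiname below are invented for illustration):

from flair.data import Sentence

# Label every token whose character span lies inside the mention span.
sentence = Sentence("Berlin is the capital of Germany .")
mention_start, mention_end, wikiname = 0, 6, "Berlin"

first = True
for token in sentence.tokens:
    if token.start_position >= mention_start and token.end_position <= mention_end:
        # The first token of a mention gets the B- prefix, the rest get I-.
        token.set_label(typename="nel", value=("B-" if first else "I-") + wikiname)
        first = False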
12 changes: 6 additions & 6 deletions flair/datasets/relation_extraction.py
@@ -687,10 +687,10 @@ def drugprot_document_to_tokenlists(
(abstract_offset, abstract_sentences),
]:
for sent in sents:
- assert sent.start_pos is not None
- assert sent.end_pos is not None
- sent_char_start = sent.start_pos + offset
- sent_char_end = sent.end_pos + offset
+ assert sent.start_position is not None
+ assert sent.end_position is not None
+ sent_char_start = sent.start_position + offset
+ sent_char_end = sent.end_position + offset

entities_in_sent = set()
for entity_id, (_, char_start, char_end, _) in entities.items():
@@ -701,8 +701,8 @@

token_offsets = [
(
- sent.start_pos + (token.start_pos or 0) + offset,
- sent.start_pos + (token.end_pos or 0) + offset,
+ sent.start_position + (token.start_position or 0) + offset,
+ sent.start_position + (token.end_position or 0) + offset,
)
for token in sent.tokens
]
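
The absolute character offsets computed above stack three layers: the section's offset in the document, the sentence's offset within the section, and the token's offset within the sentence. A small sketch of that arithmetic (the offset and text values are assumed for illustration):

from flair.data import Sentence

offset = 120  # assumed offset of the abstract within the whole document
sentence = Sentence("Aspirin inhibits COX-1 .", start_position=30)

# absolute token span = section offset + sentence offset + token offset
token_offsets = [
    (
        sentence.start_position + token.start_position + offset,
        sentence.start_position + token.end_position + offset,
    )
    for token in sentence.tokens
]
assert token_offsets[0] == (150, 157)  # "Aspirin" spans 0-7 within the sentence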
11 changes: 5 additions & 6 deletions flair/datasets/sequence_labeling.py
@@ -216,10 +216,10 @@ def _add_label_to_sentence(self, text: str, sentence: Sentence, start: int, end:
start_idx = -1
end_idx = -1
for token in sentence:
- if token.start_pos <= start <= token.end_pos and start_idx == -1:
+ if token.start_position <= start <= token.end_position and start_idx == -1:
start_idx = token.idx - 1

- if token.start_pos <= end <= token.end_pos and end_idx == -1:
+ if token.start_position <= end <= token.end_position and end_idx == -1:
end_idx = token.idx - 1

# If end index is not found set to last token
@@ -740,12 +740,11 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O
if last_token is None:
start = 0
else:
- assert last_token.end_pos is not None
- start = last_token.end_pos
+ assert last_token.end_position is not None
+ start = last_token.end_position
if last_token.whitespace_after > 0:
start += last_token.whitespace_after
- token.start_pos = start
- token.end_pos = token.start_pos + len(token.text)
+ token.start_position = start
return token

def _remap_label(self, tag):
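
The column-corpus parser now chains offsets instead of storing end positions: each token starts at the previous token's end plus its trailing whitespace, and end_position follows from the token text. Roughly (a sketch only; the token texts are invented and the Token constructor arguments are assumed to match this branch):

from flair.data import Token

tokens = []
last_token = None
for text in ["German", "-", "owned"]:
    token = Token(text=text, whitespace_after=0)
    if last_token is None:
        token.start_position = 0
    else:
        start = last_token.end_position
        if last_token.whitespace_after > 0:
            start += last_token.whitespace_after
        token.start_position = start
    tokens.append(token)
    last_token = token

# With no whitespace between them, the spans are contiguous.
assert [(t.start_position, t.end_position) for t in tokens] == [(0, 6), (6, 7), (7, 12)]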
4 changes: 2 additions & 2 deletions flair/models/regexp_tagger.py
@@ -19,8 +19,8 @@ class TokenCollection:

def __post_init__(self):
for token in self.tokens:
- self.__tokens_start_pos.append(token.start_pos)
- self.__tokens_end_pos.append(token.end_pos)
+ self.__tokens_start_pos.append(token.start_position)
+ self.__tokens_end_pos.append(token.end_position)

@property
def tokens(self) -> List[Token]:
4 changes: 2 additions & 2 deletions tests/test_datasets.py
@@ -88,8 +88,8 @@ def test_load_sequence_labeling_whitespace_after(tasks_base_path):
assert corpus.train[0].to_tokenized_string() == "It is a German - owned firm ."
assert corpus.train[0].to_plain_string() == "It is a German-owned firm."
for token in corpus.train[0]:
- assert token.start_pos is not None
- assert token.end_pos is not None
+ assert token.start_position is not None
+ assert token.end_position is not None


def test_load_column_corpus_options(tasks_base_path):
27 changes: 27 additions & 0 deletions tests/test_sentence.py
@@ -46,3 +46,30 @@ def test_token_labeling():
sentence[0].add_label("pos", "erstes")
assert [label.value for label in sentence.get_labels("pos")] == ["first", "primero", "erstes"]
assert sentence[0].get_label("pos").value == "first"


def test_start_end_position_untokenized() -> None:
sentence: Sentence = Sentence("This is a sentence.", start_position=10)
assert sentence.start_position == 10
assert sentence.end_position == 29
assert [(token.start_position, token.end_position) for token in sentence] == [
(0, 4),
(5, 7),
(8, 9),
(10, 18),
(18, 19),
]


def test_start_end_position_pretokenized() -> None:
# Initializing a Sentence this way assumes that there is a space after each token
sentence: Sentence = Sentence(["This", "is", "a", "sentence", "."], start_position=10)
assert sentence.start_position == 10
assert sentence.end_position == 30
assert [(token.start_position, token.end_position) for token in sentence] == [
(0, 4),
(5, 7),
(8, 9),
(10, 18),
(19, 20),
]