Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,24 @@ def __len__(self) -> int:
"""Length of the data point (e.g., number of tokens for Sentence)."""
raise NotImplementedError

# Default implementation for simpler classes
def _get_dynamic_embedding_names(self) -> set[str]:
"""
Internal helper to find names of embeddings with requires_grad=True.
Default implementation checks only direct embeddings. Subclasses override
for recursive checks if needed.
"""
return {name for name, vec in self._embeddings.items() if vec.requires_grad}

# Default implementation for simpler classes
def _get_all_embedding_names(self) -> set[str]:
"""
Internal helper to find names of all embeddings.
Default implementation checks only direct embeddings. Subclasses override
for recursive checks if needed.
"""
return set(self._embeddings.keys())


class EntityCandidate:
"""Represents a potential candidate entity from a knowledge base for entity linking."""
Expand Down Expand Up @@ -1034,6 +1052,24 @@ def to_dict(self, tag_type: Optional[str] = None):
"labels": [label.to_dict() for label in self.get_labels(tag_type)],
}

# Keep overridden implementation for Span
def _get_dynamic_embedding_names(self) -> set[str]:
    """Collect dynamic embedding names from this span and each of its tokens."""
    # The base implementation covers the span's own _embeddings dict.
    collected = super()._get_dynamic_embedding_names()
    for span_token in self.tokens:
        # Tokens use the default (non-recursive) implementation.
        collected |= span_token._get_dynamic_embedding_names()
    return collected

# Keep overridden implementation for Span
def _get_all_embedding_names(self) -> set[str]:
    """Collect all embedding names from this span and each of its tokens."""
    # The base implementation covers the span's own _embeddings dict.
    collected = super()._get_all_embedding_names()
    for span_token in self.tokens:
        # Tokens use the default (non-recursive) implementation.
        collected |= span_token._get_all_embedding_names()
    return collected


class Relation(_PartOfSentence):
"""Represents a directed relationship between two Spans in the same Sentence.
Expand Down Expand Up @@ -1989,6 +2025,27 @@ def retokenize(self, new_tokenizer: Tokenizer) -> None:
# The .tokens property would also do this on next access, but this forces it now.
self._perform_retokenization_with_annotation_preservation(new_tokenizer)

# Keep overridden implementation for Sentence
def _get_dynamic_embedding_names(self) -> set[str]:
    """Collect dynamic embedding names from the sentence and, if already
    tokenized, from each of its tokens."""
    # The base implementation covers the sentence's own _embeddings dict.
    collected = super()._get_dynamic_embedding_names()
    # Only look at tokens when tokenization already happened, so that this
    # inspection does not itself trigger tokenization.
    if self._is_tokenized():
        for sentence_token in self.tokens:
            collected |= sentence_token._get_dynamic_embedding_names()
    return collected

# Keep overridden implementation for Sentence
def _get_all_embedding_names(self) -> set[str]:
    """Collect all embedding names from the sentence and, if already
    tokenized, from each of its tokens."""
    # The base implementation covers the sentence's own _embeddings dict.
    collected = super()._get_all_embedding_names()
    # Avoid triggering tokenization just to look for embeddings.
    if self._is_tokenized():
        for sentence_token in self.tokens:
            collected |= sentence_token._get_all_embedding_names()
    return collected


class DataPair(DataPoint, typing.Generic[DT, DT2]):
"""Represents a pair of DataPoints, often used for sentence-pair tasks."""
Expand Down Expand Up @@ -2038,6 +2095,24 @@ def end_position(self) -> int:
def text(self):
return self.first.text + " || " + self.second.text

# Keep overridden implementation for DataPair
def _get_dynamic_embedding_names(self) -> set[str]:
    """Collect dynamic embedding names from the pair and both of its elements."""
    # The base implementation covers the pair's own _embeddings dict.
    collected = super()._get_dynamic_embedding_names()
    for component in (self.first, self.second):
        # Recurse: components may themselves be composite data points.
        collected |= component._get_dynamic_embedding_names()
    return collected

# Keep overridden implementation for DataPair
def _get_all_embedding_names(self) -> set[str]:
    """Collect all embedding names from the pair and both of its elements."""
    # The base implementation covers the pair's own _embeddings dict.
    collected = super()._get_all_embedding_names()
    for component in (self.first, self.second):
        # Recurse: components may themselves be composite data points.
        collected |= component._get_all_embedding_names()
    return collected


TextPair = DataPair[Sentence, Sentence]
"""Type alias for a DataPair consisting of two Sentences."""
Expand Down Expand Up @@ -2092,6 +2167,26 @@ def end_position(self) -> int:
def text(self):
return self.first.text + " || " + self.second.text + "||" + self.third.text

# Keep overridden implementation for DataTriple
def _get_dynamic_embedding_names(self) -> set[str]:
    """Collect dynamic embedding names from the triple and its three elements."""
    # The base implementation covers the triple's own _embeddings dict.
    collected = super()._get_dynamic_embedding_names()
    for component in (self.first, self.second, self.third):
        # Recurse: components may themselves be composite data points.
        collected |= component._get_dynamic_embedding_names()
    return collected

# Keep overridden implementation for DataTriple
def _get_all_embedding_names(self) -> set[str]:
    """Collect all embedding names from the triple and its three elements."""
    # The base implementation covers the triple's own _embeddings dict.
    collected = super()._get_all_embedding_names()
    for component in (self.first, self.second, self.third):
        # Recurse: components may themselves be composite data points.
        collected |= component._get_all_embedding_names()
    return collected


TextTriple = DataTriple[Sentence, Sentence, Sentence]
"""Type alias for a DataTriple consisting of three Sentences."""
Expand Down
48 changes: 31 additions & 17 deletions flair/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,25 +411,39 @@ def store_embeddings(


def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]:
    """Identify the names of all dynamic embeddings across a list of data points.

    A "dynamic" embedding is one whose tensor has ``requires_grad=True``. The
    check is delegated to each data point's ``_get_dynamic_embedding_names`` /
    ``_get_all_embedding_names`` helpers, which recurse into composite types
    (Sentence tokens, DataPair/DataTriple elements, Span tokens, ...).

    Args:
        data_points: A list of Flair DataPoints (Token, Sentence, DataPair, etc.).

    Returns:
        A list of unique dynamic embedding names; an empty list if embeddings
        exist but none are dynamic; or None if no embeddings (dynamic or
        static) are found on any of the data points.
    """
    # NOTE: this span previously contained interleaved remnants of the old
    # implementation (undefined names, unreachable duplicate returns); this is
    # the cleaned-up version using the DataPoint helper methods.
    all_dynamic_embeddings: set[str] = set()
    any_embeddings_found = False

    for data_point in data_points:
        # Helper handles recursion for composite types like Sentence, DataPair etc.
        all_dynamic_embeddings.update(data_point._get_dynamic_embedding_names())

        # Track whether *any* embedding (dynamic or static) exists at all, to
        # distinguish "no embeddings" (-> None) from "only static" (-> []).
        if not any_embeddings_found and data_point._get_all_embedding_names():
            any_embeddings_found = True

    # Return None only if no embeddings whatsoever were found.
    if not any_embeddings_found and not all_dynamic_embeddings:
        return None

    # Otherwise return the unique dynamic embedding names (possibly empty).
    return list(all_dynamic_embeddings)


class TokenEntity(NamedTuple):
Expand Down
191 changes: 191 additions & 0 deletions tests/test_training_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import torch

from flair.data import Token, Sentence, DataPair, Span, DataTriple, Image
from flair.training_utils import identify_dynamic_embeddings

# Helper tensors
# Shared fixture tensors: one participating in autograd, one detached
# (requires_grad defaults to False for freshly constructed tensors).
dynamic_tensor = torch.tensor([1.0, 2.0], requires_grad=True)
static_tensor = torch.tensor([3.0, 4.0])


def test_identify_dynamic_embeddings_empty_list():
    """An empty batch yields None (no embeddings seen at all)."""
    result = identify_dynamic_embeddings([])
    assert result is None


def test_identify_dynamic_embeddings_no_embeddings():
    """Data points without any embeddings yield None, alone or combined."""
    bare_token = Token("hello")
    bare_sentence = Sentence("world")  # tokens exist but carry no embeddings
    bare_pair = DataPair(Token("a"), Token("b"))
    batches = ([bare_token], [bare_sentence], [bare_pair], [bare_token, bare_sentence, bare_pair])
    for batch in batches:
        assert identify_dynamic_embeddings(batch) is None


def test_identify_dynamic_embeddings_only_static():
    """Static-only embeddings yield an empty list rather than None."""
    token = Token("hello")
    token.set_embedding("static_tok", static_tensor.clone())

    sentence = Sentence("world .")  # creates tokens
    sentence.set_embedding("static_sent", static_tensor.clone())
    sentence.tokens[0].set_embedding("static_sent_tok", static_tensor.clone())

    left = Token("a")
    left.set_embedding("static_pair_tok1", static_tensor.clone())
    right = Token("b")
    right.set_embedding("static_pair_tok2", static_tensor.clone())
    pair = DataPair(left, right)
    pair.set_embedding("static_pair", static_tensor.clone())

    for batch in ([token], [sentence], [pair], [token, sentence, pair]):
        assert identify_dynamic_embeddings(batch) == []


def test_identify_dynamic_embeddings_token():
    """A Token with mixed embeddings reports only the dynamic one."""
    token = Token("test")
    for name, tensor in (("dynamic_1", dynamic_tensor), ("static_1", static_tensor)):
        token.set_embedding(name, tensor.clone())
    found = identify_dynamic_embeddings([token])
    assert isinstance(found, list)
    assert set(found) == {"dynamic_1"}


def test_identify_dynamic_embeddings_sentence_direct():
    """Sentence-level embeddings are inspected even when no token has any."""
    sentence = Sentence("test sentence")
    sentence.set_embedding("dynamic_sent", dynamic_tensor.clone())
    sentence.set_embedding("static_sent", static_tensor.clone())
    # Tokens exist but intentionally carry no embeddings here.
    found = identify_dynamic_embeddings([sentence])
    assert isinstance(found, list)
    assert set(found) == {"dynamic_sent"}


def test_identify_dynamic_embeddings_sentence_with_tokens():
    """Dynamic names are gathered from the sentence and its tokens alike."""
    sentence = Sentence("test sentence")  # creates tokens
    sentence.set_embedding("dynamic_sent", dynamic_tensor.clone())
    sentence.set_embedding("static_sent", static_tensor.clone())
    first, second = sentence.tokens[0], sentence.tokens[1]
    first.set_embedding("dynamic_tok_0", dynamic_tensor.clone())
    first.set_embedding("static_tok_0", static_tensor.clone())
    second.set_embedding("static_tok_1", static_tensor.clone())

    found = identify_dynamic_embeddings([sentence])
    assert isinstance(found, list)
    assert set(found) == {"dynamic_sent", "dynamic_tok_0"}


def test_identify_dynamic_embeddings_span():
    """A Span reports dynamic embeddings of its constituent tokens only."""
    sentence = Sentence("This is a span test")  # creates tokens
    sentence.tokens[2].set_embedding("dynamic_tok_2", dynamic_tensor.clone())  # inside span
    sentence.tokens[3].set_embedding("static_tok_3", static_tensor.clone())  # inside span
    sentence.tokens[0].set_embedding("static_tok_0", static_tensor.clone())  # outside span

    span = sentence[2:4]  # covers "a span"

    # Checking the span directly finds the dynamic token embedding inside it.
    found_span = identify_dynamic_embeddings([span])
    assert isinstance(found_span, list)
    assert set(found_span) == {"dynamic_tok_2"}

    # The more typical use case: checking the sentence containing the span.
    found_sentence = identify_dynamic_embeddings([sentence])
    assert isinstance(found_sentence, list)
    assert set(found_sentence) == {"dynamic_tok_2"}


def test_identify_dynamic_embeddings_datapair():
    """A DataPair reports dynamic embeddings of itself and both elements."""
    first = Token("first")
    first.set_embedding("dynamic_tok1", dynamic_tensor.clone())
    first.set_embedding("static_tok1", static_tensor.clone())

    second = Token("second")
    second.set_embedding("static_tok2", static_tensor.clone())

    pair = DataPair(first, second)
    pair.set_embedding("dynamic_pair", dynamic_tensor.clone())
    pair.set_embedding("static_pair", static_tensor.clone())

    found = identify_dynamic_embeddings([pair])
    assert isinstance(found, list)
    assert set(found) == {"dynamic_tok1", "dynamic_pair"}


def test_identify_dynamic_embeddings_datatriple():
    """A DataTriple reports dynamic embeddings of itself and all three elements."""
    first = Token("first")
    first.set_embedding("dynamic_tok1", dynamic_tensor.clone())

    second = Token("second")
    second.set_embedding("static_tok2", static_tensor.clone())

    third = Token("third")
    third.set_embedding("dynamic_tok3", dynamic_tensor.clone())

    triple = DataTriple(first, second, third)
    triple.set_embedding("dynamic_triple", dynamic_tensor.clone())
    triple.set_embedding("static_triple", static_tensor.clone())

    found = identify_dynamic_embeddings([triple])
    assert isinstance(found, list)
    assert set(found) == {"dynamic_tok1", "dynamic_tok3", "dynamic_triple"}


def test_identify_dynamic_embeddings_image():
    """An Image data point is handled via the default (non-recursive) helpers."""
    image = Image()
    image.set_embedding("dynamic_img", dynamic_tensor.clone())
    image.set_embedding("static_img", static_tensor.clone())

    found = identify_dynamic_embeddings([image])
    assert isinstance(found, list)
    assert set(found) == {"dynamic_img"}


def test_identify_dynamic_embeddings_mixed_list():
    """A heterogeneous batch aggregates dynamic names across all point types."""
    lone_token = Token("just a token")
    lone_token.set_embedding("dynamic_mixed_tok", dynamic_tensor.clone())

    sentence = Sentence("a sentence")
    sentence.set_embedding("dynamic_mixed_sent", dynamic_tensor.clone())
    sentence.tokens[0].set_embedding("dynamic_mixed_sent_tok", dynamic_tensor.clone())

    left = Token("pair_a")
    left.set_embedding("dynamic_mixed_pair_tok", dynamic_tensor.clone())
    right = Token("pair_b")
    pair = DataPair(left, right)
    pair.set_embedding("dynamic_mixed_pair", dynamic_tensor.clone())

    image = Image()
    image.set_embedding("dynamic_mixed_img", dynamic_tensor.clone())

    # A static-only point must not contribute a name, and an embedding-free
    # point must not flip the overall result to None.
    static_token = Token("static only")
    static_token.set_embedding("static_mixed", static_tensor.clone())
    empty_token = Token("empty")

    batch = [lone_token, sentence, pair, image, static_token, empty_token]
    found = identify_dynamic_embeddings(batch)

    assert isinstance(found, list)
    assert set(found) == {
        "dynamic_mixed_tok",
        "dynamic_mixed_sent",
        "dynamic_mixed_sent_tok",
        "dynamic_mixed_pair_tok",
        "dynamic_mixed_pair",
        "dynamic_mixed_img",
    }
Loading