Skip to content

Commit 2d87fb0

Browse files
author
Benedikt Fuchs
committed
fix extern dataset update
1 parent 270ef05 commit 2d87fb0

File tree

2 files changed

+114
-113
lines changed

2 files changed

+114
-113
lines changed

flair/embeddings/transformer.py

+113-112
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def truncate_hidden_states(hidden_states: torch.Tensor, input_ids: torch.Tensor)
7070

7171
@torch.jit.script_if_tracing
7272
def combine_strided_tensors(
73-
hidden_states: torch.Tensor,
74-
overflow_to_sample_mapping: torch.Tensor,
75-
half_stride: int,
76-
max_length: int,
77-
default_value: int,
73+
hidden_states: torch.Tensor,
74+
overflow_to_sample_mapping: torch.Tensor,
75+
half_stride: int,
76+
max_length: int,
77+
default_value: int,
7878
) -> torch.Tensor:
7979
_, counts = torch.unique(overflow_to_sample_mapping, sorted=True, return_counts=True)
8080
sentence_count = int(overflow_to_sample_mapping.max().item() + 1)
@@ -94,9 +94,9 @@ def combine_strided_tensors(
9494
selected_sentences = hidden_states[overflow_to_sample_mapping == sentence_id]
9595
if selected_sentences.size(0) > 1:
9696
start_part = selected_sentences[0, : half_stride + 1]
97-
mid_part = selected_sentences[:, half_stride + 1: max_length - 1 - half_stride]
97+
mid_part = selected_sentences[:, half_stride + 1 : max_length - 1 - half_stride]
9898
mid_part = torch.reshape(mid_part, (mid_part.shape[0] * mid_part.shape[1],) + mid_part.shape[2:])
99-
end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1:]
99+
end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1 :]
100100
sentence_hidden_state = torch.cat((start_part, mid_part, end_part), dim=0)
101101
sentence_hidden_states[sentence_id, : sentence_hidden_state.shape[0]] = torch.cat(
102102
(start_part, mid_part, end_part), dim=0
@@ -109,11 +109,11 @@ def combine_strided_tensors(
109109

110110
@torch.jit.script_if_tracing
111111
def fill_masked_elements(
112-
all_token_embeddings: torch.Tensor,
113-
sentence_hidden_states: torch.Tensor,
114-
mask: torch.Tensor,
115-
word_ids: torch.Tensor,
116-
lengths: torch.LongTensor,
112+
all_token_embeddings: torch.Tensor,
113+
sentence_hidden_states: torch.Tensor,
114+
mask: torch.Tensor,
115+
word_ids: torch.Tensor,
116+
lengths: torch.LongTensor,
117117
):
118118
for i in torch.arange(int(all_token_embeddings.shape[0])):
119119
r = insert_missing_embeddings(sentence_hidden_states[i][mask[i] & (word_ids[i] >= 0)], word_ids[i], lengths[i])
@@ -123,7 +123,7 @@ def fill_masked_elements(
123123

124124
@torch.jit.script_if_tracing
125125
def insert_missing_embeddings(
126-
token_embeddings: torch.Tensor, word_id: torch.Tensor, length: torch.LongTensor
126+
token_embeddings: torch.Tensor, word_id: torch.Tensor, length: torch.LongTensor
127127
) -> torch.Tensor:
128128
# in some cases we need to insert zero vectors for tokens without embedding.
129129
if token_embeddings.shape[0] == 0:
@@ -166,10 +166,10 @@ def insert_missing_embeddings(
166166

167167
@torch.jit.script_if_tracing
168168
def fill_mean_token_embeddings(
169-
all_token_embeddings: torch.Tensor,
170-
sentence_hidden_states: torch.Tensor,
171-
word_ids: torch.Tensor,
172-
token_lengths: torch.Tensor,
169+
all_token_embeddings: torch.Tensor,
170+
sentence_hidden_states: torch.Tensor,
171+
word_ids: torch.Tensor,
172+
token_lengths: torch.Tensor,
173173
):
174174
for i in torch.arange(all_token_embeddings.shape[0]):
175175
for _id in torch.arange(token_lengths[i]): # type: ignore[call-overload]
@@ -196,7 +196,7 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths:
196196

197197

198198
def _legacy_reconstruct_word_ids(
199-
embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]]
199+
embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]]
200200
) -> List[List[Optional[int]]]:
201201
word_ids_list = []
202202
max_len = 0
@@ -307,25 +307,25 @@ class TransformerBaseEmbeddings(Embeddings[Sentence]):
307307
"""
308308

309309
def __init__(
310-
self,
311-
name: str,
312-
tokenizer: PreTrainedTokenizer,
313-
embedding_length: int,
314-
context_length: int,
315-
context_dropout: float,
316-
respect_document_boundaries: bool,
317-
stride: int,
318-
allow_long_sentences: bool,
319-
fine_tune: bool,
320-
truncate: bool,
321-
use_lang_emb: bool,
322-
is_document_embedding: bool = False,
323-
is_token_embedding: bool = False,
324-
force_device: Optional[torch.device] = None,
325-
force_max_length: bool = False,
326-
feature_extractor: Optional[FeatureExtractionMixin] = None,
327-
needs_manual_ocr: Optional[bool] = None,
328-
use_context_separator: bool = True,
310+
self,
311+
name: str,
312+
tokenizer: PreTrainedTokenizer,
313+
embedding_length: int,
314+
context_length: int,
315+
context_dropout: float,
316+
respect_document_boundaries: bool,
317+
stride: int,
318+
allow_long_sentences: bool,
319+
fine_tune: bool,
320+
truncate: bool,
321+
use_lang_emb: bool,
322+
is_document_embedding: bool = False,
323+
is_token_embedding: bool = False,
324+
force_device: Optional[torch.device] = None,
325+
force_max_length: bool = False,
326+
feature_extractor: Optional[FeatureExtractionMixin] = None,
327+
needs_manual_ocr: Optional[bool] = None,
328+
use_context_separator: bool = True,
329329
) -> None:
330330
self.name = name
331331
super().__init__()
@@ -473,32 +473,32 @@ def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.devi
473473

474474
# random check some tokens to save performance.
475475
if (self.needs_manual_ocr or self.tokenizer_needs_ocr_boxes) and not all(
476-
[
477-
flair_tokens[0][0].has_metadata("bbox"),
478-
flair_tokens[0][-1].has_metadata("bbox"),
479-
flair_tokens[-1][0].has_metadata("bbox"),
480-
flair_tokens[-1][-1].has_metadata("bbox"),
481-
]
476+
[
477+
flair_tokens[0][0].has_metadata("bbox"),
478+
flair_tokens[0][-1].has_metadata("bbox"),
479+
flair_tokens[-1][0].has_metadata("bbox"),
480+
flair_tokens[-1][-1].has_metadata("bbox"),
481+
]
482482
):
483483
raise ValueError(f"The embedding '{self.name}' requires the ocr 'bbox' set as metadata on all tokens.")
484484

485485
if self.feature_extractor is not None and not all(
486-
[
487-
sentences[0].has_metadata("image"),
488-
sentences[-1].has_metadata("image"),
489-
]
486+
[
487+
sentences[0].has_metadata("image"),
488+
sentences[-1].has_metadata("image"),
489+
]
490490
):
491491
raise ValueError(f"The embedding '{self.name}' requires the 'image' set as metadata for all sentences.")
492492

493493
return self.__build_transformer_model_inputs(sentences, offsets, lengths, flair_tokens, device)
494494

495495
def __build_transformer_model_inputs(
496-
self,
497-
sentences: List[Sentence],
498-
offsets: List[int],
499-
sentence_lengths: List[int],
500-
flair_tokens: List[List[Token]],
501-
device: torch.device,
496+
self,
497+
sentences: List[Sentence],
498+
offsets: List[int],
499+
sentence_lengths: List[int],
500+
flair_tokens: List[List[Token]],
501+
device: torch.device,
502502
):
503503
tokenizer_kwargs: Dict[str, Any] = {}
504504
if self.tokenizer_needs_ocr_boxes:
@@ -559,7 +559,7 @@ def __build_transformer_model_inputs(
559559
sentence_idx = 0
560560
for sentence, part_length in zip(sentences, sentence_part_lengths):
561561
lang_id = lang2id.get(sentence.get_language_code(), 0)
562-
model_kwargs["langs"][sentence_idx: sentence_idx + part_length] = lang_id
562+
model_kwargs["langs"][sentence_idx : sentence_idx + part_length] = lang_id
563563
sentence_idx += part_length
564564

565565
if "bbox" in batch_encoding:
@@ -801,12 +801,12 @@ def collect_dynamic_axes(cls, embedding: "TransformerEmbeddings", tensors):
801801

802802
@classmethod
803803
def export_from_embedding(
804-
cls,
805-
path: Union[str, Path],
806-
embedding: "TransformerEmbeddings",
807-
example_sentences: List[Sentence],
808-
opset_version: int = 14,
809-
providers: Optional[List] = None,
804+
cls,
805+
path: Union[str, Path],
806+
embedding: "TransformerEmbeddings",
807+
example_sentences: List[Sentence],
808+
opset_version: int = 14,
809+
providers: Optional[List] = None,
810810
):
811811
path = str(path)
812812
example_tensors = embedding.prepare_tensors(example_sentences)
@@ -899,7 +899,7 @@ def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbe
899899

900900
@classmethod
901901
def parameter_to_list(
902-
cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: List[Sentence]
902+
cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: List[Sentence]
903903
) -> Tuple[List[str], List[torch.Tensor]]:
904904
tensors = embedding.prepare_tensors(sentences)
905905
param_names = list(inspect.signature(wrapper.forward).parameters.keys())
@@ -912,35 +912,35 @@ def parameter_to_list(
912912
@register_embeddings
913913
class TransformerJitWordEmbeddings(TokenEmbeddings, TransformerJitEmbeddings):
914914
def __init__(
915-
self,
916-
**kwargs,
915+
self,
916+
**kwargs,
917917
) -> None:
918918
TransformerJitEmbeddings.__init__(self, **kwargs)
919919

920920

921921
@register_embeddings
922922
class TransformerJitDocumentEmbeddings(DocumentEmbeddings, TransformerJitEmbeddings):
923923
def __init__(
924-
self,
925-
**kwargs,
924+
self,
925+
**kwargs,
926926
) -> None:
927927
TransformerJitEmbeddings.__init__(self, **kwargs)
928928

929929

930930
@register_embeddings
931931
class TransformerOnnxWordEmbeddings(TokenEmbeddings, TransformerOnnxEmbeddings):
932932
def __init__(
933-
self,
934-
**kwargs,
933+
self,
934+
**kwargs,
935935
) -> None:
936936
TransformerOnnxEmbeddings.__init__(self, **kwargs)
937937

938938

939939
@register_embeddings
940940
class TransformerOnnxDocumentEmbeddings(DocumentEmbeddings, TransformerOnnxEmbeddings):
941941
def __init__(
942-
self,
943-
**kwargs,
942+
self,
943+
**kwargs,
944944
) -> None:
945945
TransformerOnnxEmbeddings.__init__(self, **kwargs)
946946

@@ -950,27 +950,27 @@ class TransformerEmbeddings(TransformerBaseEmbeddings):
950950
onnx_cls: Type[TransformerOnnxEmbeddings] = TransformerOnnxEmbeddings
951951

952952
def __init__(
953-
self,
954-
model: str = "bert-base-uncased",
955-
fine_tune: bool = True,
956-
layers: str = "-1",
957-
layer_mean: bool = True,
958-
subtoken_pooling: str = "first",
959-
cls_pooling: str = "cls",
960-
is_token_embedding: bool = True,
961-
is_document_embedding: bool = True,
962-
allow_long_sentences: bool = False,
963-
use_context: Union[bool, int] = False,
964-
respect_document_boundaries: bool = True,
965-
context_dropout: float = 0.5,
966-
saved_config: Optional[PretrainedConfig] = None,
967-
tokenizer_data: Optional[BytesIO] = None,
968-
feature_extractor_data: Optional[BytesIO] = None,
969-
name: Optional[str] = None,
970-
force_max_length: bool = False,
971-
needs_manual_ocr: Optional[bool] = None,
972-
use_context_separator: bool = True,
973-
**kwargs,
953+
self,
954+
model: str = "bert-base-uncased",
955+
fine_tune: bool = True,
956+
layers: str = "-1",
957+
layer_mean: bool = True,
958+
subtoken_pooling: str = "first",
959+
cls_pooling: str = "cls",
960+
is_token_embedding: bool = True,
961+
is_document_embedding: bool = True,
962+
allow_long_sentences: bool = False,
963+
use_context: Union[bool, int] = False,
964+
respect_document_boundaries: bool = True,
965+
context_dropout: float = 0.5,
966+
saved_config: Optional[PretrainedConfig] = None,
967+
tokenizer_data: Optional[BytesIO] = None,
968+
feature_extractor_data: Optional[BytesIO] = None,
969+
name: Optional[str] = None,
970+
force_max_length: bool = False,
971+
needs_manual_ocr: Optional[bool] = None,
972+
use_context_separator: bool = True,
973+
**kwargs,
974974
) -> None:
975975
self.instance_parameters = self.get_instance_parameters(locals=locals())
976976
del self.instance_parameters["saved_config"]
@@ -1107,14 +1107,15 @@ def embedding_length(self) -> int:
11071107

11081108
return self.embedding_length_internal
11091109

1110-
1111-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
1112-
missing_keys, unexpected_keys, error_msgs):
1110+
def _load_from_state_dict(
1111+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
1112+
):
11131113
if transformers.__version__ >= Version(4, 31, 0):
11141114
assert isinstance(state_dict, dict)
11151115
state_dict.pop(f"{prefix}model.embeddings.position_ids", None)
1116-
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
1117-
missing_keys, unexpected_keys, error_msgs)
1116+
super()._load_from_state_dict(
1117+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
1118+
)
11181119

11191120
def _has_initial_cls_token(self) -> bool:
11201121
# most models have CLS token as last token (GPT-1, GPT-2, TransfoXL, XLNet, XLM), but BERT is initial
@@ -1248,23 +1249,23 @@ def to_params(self):
12481249
def _can_document_embedding_shortcut(self):
12491250
# cls first pooling can be done without recreating sentence hidden states
12501251
return (
1251-
self.document_embedding
1252-
and not self.token_embedding
1253-
and self.cls_pooling == "cls"
1254-
and self.initial_cls_token
1252+
self.document_embedding
1253+
and not self.token_embedding
1254+
and self.cls_pooling == "cls"
1255+
and self.initial_cls_token
12551256
)
12561257

12571258
def forward(
1258-
self,
1259-
input_ids: torch.Tensor,
1260-
sub_token_lengths: Optional[torch.LongTensor] = None,
1261-
token_lengths: Optional[torch.LongTensor] = None,
1262-
attention_mask: Optional[torch.Tensor] = None,
1263-
overflow_to_sample_mapping: Optional[torch.Tensor] = None,
1264-
word_ids: Optional[torch.Tensor] = None,
1265-
langs: Optional[torch.Tensor] = None,
1266-
bbox: Optional[torch.Tensor] = None,
1267-
pixel_values: Optional[torch.Tensor] = None,
1259+
self,
1260+
input_ids: torch.Tensor,
1261+
sub_token_lengths: Optional[torch.LongTensor] = None,
1262+
token_lengths: Optional[torch.LongTensor] = None,
1263+
attention_mask: Optional[torch.Tensor] = None,
1264+
overflow_to_sample_mapping: Optional[torch.Tensor] = None,
1265+
word_ids: Optional[torch.Tensor] = None,
1266+
langs: Optional[torch.Tensor] = None,
1267+
bbox: Optional[torch.Tensor] = None,
1268+
pixel_values: Optional[torch.Tensor] = None,
12681269
):
12691270
model_kwargs = {}
12701271
if langs is not None:
@@ -1353,8 +1354,8 @@ def forward(
13531354
word_ids,
13541355
token_lengths,
13551356
)
1356-
all_token_embeddings[:, :, sentence_hidden_states.shape[2]:] = fill_masked_elements(
1357-
all_token_embeddings[:, :, sentence_hidden_states.shape[2]:],
1357+
all_token_embeddings[:, :, sentence_hidden_states.shape[2] :] = fill_masked_elements(
1358+
all_token_embeddings[:, :, sentence_hidden_states.shape[2] :],
13581359
sentence_hidden_states,
13591360
last_mask,
13601361
word_ids,
@@ -1374,7 +1375,7 @@ def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
13741375
return self.forward(**tensors)
13751376

13761377
def export_onnx(
1377-
self, path: Union[str, Path], example_sentences: List[Sentence], **kwargs
1378+
self, path: Union[str, Path], example_sentences: List[Sentence], **kwargs
13781379
) -> TransformerOnnxEmbeddings:
13791380
"""Export TransformerEmbeddings to OnnxFormat.
13801381

tests/test_datasets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def test_masakhane_corpus(tasks_base_path):
748748
"bam": {"train": 4462, "dev": 638, "test": 1274},
749749
"bbj": {"train": 3384, "dev": 483, "test": 966},
750750
"ewe": {"train": 3505, "dev": 501, "test": 1001},
751-
"fon": {"train": 4343, "dev": 621, "test": 1240},
751+
"fon": {"train": 4343, "dev": 623, "test": 1240},
752752
"hau": {"train": 5716, "dev": 816, "test": 1633},
753753
"ibo": {"train": 7634, "dev": 1090, "test": 2181},
754754
"kin": {"train": 7825, "dev": 1118, "test": 2235},

0 commit comments

Comments
 (0)