@@ -70,11 +70,11 @@ def truncate_hidden_states(hidden_states: torch.Tensor, input_ids: torch.Tensor)
70
70
71
71
@torch .jit .script_if_tracing
72
72
def combine_strided_tensors (
73
- hidden_states : torch .Tensor ,
74
- overflow_to_sample_mapping : torch .Tensor ,
75
- half_stride : int ,
76
- max_length : int ,
77
- default_value : int ,
73
+ hidden_states : torch .Tensor ,
74
+ overflow_to_sample_mapping : torch .Tensor ,
75
+ half_stride : int ,
76
+ max_length : int ,
77
+ default_value : int ,
78
78
) -> torch .Tensor :
79
79
_ , counts = torch .unique (overflow_to_sample_mapping , sorted = True , return_counts = True )
80
80
sentence_count = int (overflow_to_sample_mapping .max ().item () + 1 )
@@ -94,9 +94,9 @@ def combine_strided_tensors(
94
94
selected_sentences = hidden_states [overflow_to_sample_mapping == sentence_id ]
95
95
if selected_sentences .size (0 ) > 1 :
96
96
start_part = selected_sentences [0 , : half_stride + 1 ]
97
- mid_part = selected_sentences [:, half_stride + 1 : max_length - 1 - half_stride ]
97
+ mid_part = selected_sentences [:, half_stride + 1 : max_length - 1 - half_stride ]
98
98
mid_part = torch .reshape (mid_part , (mid_part .shape [0 ] * mid_part .shape [1 ],) + mid_part .shape [2 :])
99
- end_part = selected_sentences [selected_sentences .shape [0 ] - 1 , max_length - half_stride - 1 :]
99
+ end_part = selected_sentences [selected_sentences .shape [0 ] - 1 , max_length - half_stride - 1 :]
100
100
sentence_hidden_state = torch .cat ((start_part , mid_part , end_part ), dim = 0 )
101
101
sentence_hidden_states [sentence_id , : sentence_hidden_state .shape [0 ]] = torch .cat (
102
102
(start_part , mid_part , end_part ), dim = 0
@@ -109,11 +109,11 @@ def combine_strided_tensors(
109
109
110
110
@torch .jit .script_if_tracing
111
111
def fill_masked_elements (
112
- all_token_embeddings : torch .Tensor ,
113
- sentence_hidden_states : torch .Tensor ,
114
- mask : torch .Tensor ,
115
- word_ids : torch .Tensor ,
116
- lengths : torch .LongTensor ,
112
+ all_token_embeddings : torch .Tensor ,
113
+ sentence_hidden_states : torch .Tensor ,
114
+ mask : torch .Tensor ,
115
+ word_ids : torch .Tensor ,
116
+ lengths : torch .LongTensor ,
117
117
):
118
118
for i in torch .arange (int (all_token_embeddings .shape [0 ])):
119
119
r = insert_missing_embeddings (sentence_hidden_states [i ][mask [i ] & (word_ids [i ] >= 0 )], word_ids [i ], lengths [i ])
@@ -123,7 +123,7 @@ def fill_masked_elements(
123
123
124
124
@torch .jit .script_if_tracing
125
125
def insert_missing_embeddings (
126
- token_embeddings : torch .Tensor , word_id : torch .Tensor , length : torch .LongTensor
126
+ token_embeddings : torch .Tensor , word_id : torch .Tensor , length : torch .LongTensor
127
127
) -> torch .Tensor :
128
128
# in some cases we need to insert zero vectors for tokens without embedding.
129
129
if token_embeddings .shape [0 ] == 0 :
@@ -166,10 +166,10 @@ def insert_missing_embeddings(
166
166
167
167
@torch .jit .script_if_tracing
168
168
def fill_mean_token_embeddings (
169
- all_token_embeddings : torch .Tensor ,
170
- sentence_hidden_states : torch .Tensor ,
171
- word_ids : torch .Tensor ,
172
- token_lengths : torch .Tensor ,
169
+ all_token_embeddings : torch .Tensor ,
170
+ sentence_hidden_states : torch .Tensor ,
171
+ word_ids : torch .Tensor ,
172
+ token_lengths : torch .Tensor ,
173
173
):
174
174
for i in torch .arange (all_token_embeddings .shape [0 ]):
175
175
for _id in torch .arange (token_lengths [i ]): # type: ignore[call-overload]
@@ -196,7 +196,7 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths:
196
196
197
197
198
198
def _legacy_reconstruct_word_ids (
199
- embedding : "TransformerBaseEmbeddings" , flair_tokens : List [List [str ]]
199
+ embedding : "TransformerBaseEmbeddings" , flair_tokens : List [List [str ]]
200
200
) -> List [List [Optional [int ]]]:
201
201
word_ids_list = []
202
202
max_len = 0
@@ -307,25 +307,25 @@ class TransformerBaseEmbeddings(Embeddings[Sentence]):
307
307
"""
308
308
309
309
def __init__ (
310
- self ,
311
- name : str ,
312
- tokenizer : PreTrainedTokenizer ,
313
- embedding_length : int ,
314
- context_length : int ,
315
- context_dropout : float ,
316
- respect_document_boundaries : bool ,
317
- stride : int ,
318
- allow_long_sentences : bool ,
319
- fine_tune : bool ,
320
- truncate : bool ,
321
- use_lang_emb : bool ,
322
- is_document_embedding : bool = False ,
323
- is_token_embedding : bool = False ,
324
- force_device : Optional [torch .device ] = None ,
325
- force_max_length : bool = False ,
326
- feature_extractor : Optional [FeatureExtractionMixin ] = None ,
327
- needs_manual_ocr : Optional [bool ] = None ,
328
- use_context_separator : bool = True ,
310
+ self ,
311
+ name : str ,
312
+ tokenizer : PreTrainedTokenizer ,
313
+ embedding_length : int ,
314
+ context_length : int ,
315
+ context_dropout : float ,
316
+ respect_document_boundaries : bool ,
317
+ stride : int ,
318
+ allow_long_sentences : bool ,
319
+ fine_tune : bool ,
320
+ truncate : bool ,
321
+ use_lang_emb : bool ,
322
+ is_document_embedding : bool = False ,
323
+ is_token_embedding : bool = False ,
324
+ force_device : Optional [torch .device ] = None ,
325
+ force_max_length : bool = False ,
326
+ feature_extractor : Optional [FeatureExtractionMixin ] = None ,
327
+ needs_manual_ocr : Optional [bool ] = None ,
328
+ use_context_separator : bool = True ,
329
329
) -> None :
330
330
self .name = name
331
331
super ().__init__ ()
@@ -473,32 +473,32 @@ def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.devi
473
473
474
474
# random check some tokens to save performance.
475
475
if (self .needs_manual_ocr or self .tokenizer_needs_ocr_boxes ) and not all (
476
- [
477
- flair_tokens [0 ][0 ].has_metadata ("bbox" ),
478
- flair_tokens [0 ][- 1 ].has_metadata ("bbox" ),
479
- flair_tokens [- 1 ][0 ].has_metadata ("bbox" ),
480
- flair_tokens [- 1 ][- 1 ].has_metadata ("bbox" ),
481
- ]
476
+ [
477
+ flair_tokens [0 ][0 ].has_metadata ("bbox" ),
478
+ flair_tokens [0 ][- 1 ].has_metadata ("bbox" ),
479
+ flair_tokens [- 1 ][0 ].has_metadata ("bbox" ),
480
+ flair_tokens [- 1 ][- 1 ].has_metadata ("bbox" ),
481
+ ]
482
482
):
483
483
raise ValueError (f"The embedding '{ self .name } ' requires the ocr 'bbox' set as metadata on all tokens." )
484
484
485
485
if self .feature_extractor is not None and not all (
486
- [
487
- sentences [0 ].has_metadata ("image" ),
488
- sentences [- 1 ].has_metadata ("image" ),
489
- ]
486
+ [
487
+ sentences [0 ].has_metadata ("image" ),
488
+ sentences [- 1 ].has_metadata ("image" ),
489
+ ]
490
490
):
491
491
raise ValueError (f"The embedding '{ self .name } ' requires the 'image' set as metadata for all sentences." )
492
492
493
493
return self .__build_transformer_model_inputs (sentences , offsets , lengths , flair_tokens , device )
494
494
495
495
def __build_transformer_model_inputs (
496
- self ,
497
- sentences : List [Sentence ],
498
- offsets : List [int ],
499
- sentence_lengths : List [int ],
500
- flair_tokens : List [List [Token ]],
501
- device : torch .device ,
496
+ self ,
497
+ sentences : List [Sentence ],
498
+ offsets : List [int ],
499
+ sentence_lengths : List [int ],
500
+ flair_tokens : List [List [Token ]],
501
+ device : torch .device ,
502
502
):
503
503
tokenizer_kwargs : Dict [str , Any ] = {}
504
504
if self .tokenizer_needs_ocr_boxes :
@@ -559,7 +559,7 @@ def __build_transformer_model_inputs(
559
559
sentence_idx = 0
560
560
for sentence , part_length in zip (sentences , sentence_part_lengths ):
561
561
lang_id = lang2id .get (sentence .get_language_code (), 0 )
562
- model_kwargs ["langs" ][sentence_idx : sentence_idx + part_length ] = lang_id
562
+ model_kwargs ["langs" ][sentence_idx : sentence_idx + part_length ] = lang_id
563
563
sentence_idx += part_length
564
564
565
565
if "bbox" in batch_encoding :
@@ -801,12 +801,12 @@ def collect_dynamic_axes(cls, embedding: "TransformerEmbeddings", tensors):
801
801
802
802
@classmethod
803
803
def export_from_embedding (
804
- cls ,
805
- path : Union [str , Path ],
806
- embedding : "TransformerEmbeddings" ,
807
- example_sentences : List [Sentence ],
808
- opset_version : int = 14 ,
809
- providers : Optional [List ] = None ,
804
+ cls ,
805
+ path : Union [str , Path ],
806
+ embedding : "TransformerEmbeddings" ,
807
+ example_sentences : List [Sentence ],
808
+ opset_version : int = 14 ,
809
+ providers : Optional [List ] = None ,
810
810
):
811
811
path = str (path )
812
812
example_tensors = embedding .prepare_tensors (example_sentences )
@@ -899,7 +899,7 @@ def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbe
899
899
900
900
@classmethod
901
901
def parameter_to_list (
902
- cls , embedding : "TransformerEmbeddings" , wrapper : torch .nn .Module , sentences : List [Sentence ]
902
+ cls , embedding : "TransformerEmbeddings" , wrapper : torch .nn .Module , sentences : List [Sentence ]
903
903
) -> Tuple [List [str ], List [torch .Tensor ]]:
904
904
tensors = embedding .prepare_tensors (sentences )
905
905
param_names = list (inspect .signature (wrapper .forward ).parameters .keys ())
@@ -912,35 +912,35 @@ def parameter_to_list(
912
912
@register_embeddings
913
913
class TransformerJitWordEmbeddings (TokenEmbeddings , TransformerJitEmbeddings ):
914
914
def __init__ (
915
- self ,
916
- ** kwargs ,
915
+ self ,
916
+ ** kwargs ,
917
917
) -> None :
918
918
TransformerJitEmbeddings .__init__ (self , ** kwargs )
919
919
920
920
921
921
@register_embeddings
922
922
class TransformerJitDocumentEmbeddings (DocumentEmbeddings , TransformerJitEmbeddings ):
923
923
def __init__ (
924
- self ,
925
- ** kwargs ,
924
+ self ,
925
+ ** kwargs ,
926
926
) -> None :
927
927
TransformerJitEmbeddings .__init__ (self , ** kwargs )
928
928
929
929
930
930
@register_embeddings
931
931
class TransformerOnnxWordEmbeddings (TokenEmbeddings , TransformerOnnxEmbeddings ):
932
932
def __init__ (
933
- self ,
934
- ** kwargs ,
933
+ self ,
934
+ ** kwargs ,
935
935
) -> None :
936
936
TransformerOnnxEmbeddings .__init__ (self , ** kwargs )
937
937
938
938
939
939
@register_embeddings
940
940
class TransformerOnnxDocumentEmbeddings (DocumentEmbeddings , TransformerOnnxEmbeddings ):
941
941
def __init__ (
942
- self ,
943
- ** kwargs ,
942
+ self ,
943
+ ** kwargs ,
944
944
) -> None :
945
945
TransformerOnnxEmbeddings .__init__ (self , ** kwargs )
946
946
@@ -950,27 +950,27 @@ class TransformerEmbeddings(TransformerBaseEmbeddings):
950
950
onnx_cls : Type [TransformerOnnxEmbeddings ] = TransformerOnnxEmbeddings
951
951
952
952
def __init__ (
953
- self ,
954
- model : str = "bert-base-uncased" ,
955
- fine_tune : bool = True ,
956
- layers : str = "-1" ,
957
- layer_mean : bool = True ,
958
- subtoken_pooling : str = "first" ,
959
- cls_pooling : str = "cls" ,
960
- is_token_embedding : bool = True ,
961
- is_document_embedding : bool = True ,
962
- allow_long_sentences : bool = False ,
963
- use_context : Union [bool , int ] = False ,
964
- respect_document_boundaries : bool = True ,
965
- context_dropout : float = 0.5 ,
966
- saved_config : Optional [PretrainedConfig ] = None ,
967
- tokenizer_data : Optional [BytesIO ] = None ,
968
- feature_extractor_data : Optional [BytesIO ] = None ,
969
- name : Optional [str ] = None ,
970
- force_max_length : bool = False ,
971
- needs_manual_ocr : Optional [bool ] = None ,
972
- use_context_separator : bool = True ,
973
- ** kwargs ,
953
+ self ,
954
+ model : str = "bert-base-uncased" ,
955
+ fine_tune : bool = True ,
956
+ layers : str = "-1" ,
957
+ layer_mean : bool = True ,
958
+ subtoken_pooling : str = "first" ,
959
+ cls_pooling : str = "cls" ,
960
+ is_token_embedding : bool = True ,
961
+ is_document_embedding : bool = True ,
962
+ allow_long_sentences : bool = False ,
963
+ use_context : Union [bool , int ] = False ,
964
+ respect_document_boundaries : bool = True ,
965
+ context_dropout : float = 0.5 ,
966
+ saved_config : Optional [PretrainedConfig ] = None ,
967
+ tokenizer_data : Optional [BytesIO ] = None ,
968
+ feature_extractor_data : Optional [BytesIO ] = None ,
969
+ name : Optional [str ] = None ,
970
+ force_max_length : bool = False ,
971
+ needs_manual_ocr : Optional [bool ] = None ,
972
+ use_context_separator : bool = True ,
973
+ ** kwargs ,
974
974
) -> None :
975
975
self .instance_parameters = self .get_instance_parameters (locals = locals ())
976
976
del self .instance_parameters ["saved_config" ]
@@ -1107,14 +1107,15 @@ def embedding_length(self) -> int:
1107
1107
1108
1108
return self .embedding_length_internal
1109
1109
1110
-
1111
- def _load_from_state_dict ( self , state_dict , prefix , local_metadata , strict ,
1112
- missing_keys , unexpected_keys , error_msgs ):
1110
+ def _load_from_state_dict (
1111
+ self , state_dict , prefix , local_metadata , strict , missing_keys , unexpected_keys , error_msgs
1112
+ ):
1113
1113
if transformers .__version__ >= Version (4 , 31 , 0 ):
1114
1114
assert isinstance (state_dict , dict )
1115
1115
state_dict .pop (f"{ prefix } model.embeddings.position_ids" , None )
1116
- super ()._load_from_state_dict (state_dict , prefix , local_metadata , strict ,
1117
- missing_keys , unexpected_keys , error_msgs )
1116
+ super ()._load_from_state_dict (
1117
+ state_dict , prefix , local_metadata , strict , missing_keys , unexpected_keys , error_msgs
1118
+ )
1118
1119
1119
1120
def _has_initial_cls_token (self ) -> bool :
1120
1121
# most models have CLS token as last token (GPT-1, GPT-2, TransfoXL, XLNet, XLM), but BERT is initial
@@ -1248,23 +1249,23 @@ def to_params(self):
1248
1249
def _can_document_embedding_shortcut (self ):
1249
1250
# cls first pooling can be done without recreating sentence hidden states
1250
1251
return (
1251
- self .document_embedding
1252
- and not self .token_embedding
1253
- and self .cls_pooling == "cls"
1254
- and self .initial_cls_token
1252
+ self .document_embedding
1253
+ and not self .token_embedding
1254
+ and self .cls_pooling == "cls"
1255
+ and self .initial_cls_token
1255
1256
)
1256
1257
1257
1258
def forward (
1258
- self ,
1259
- input_ids : torch .Tensor ,
1260
- sub_token_lengths : Optional [torch .LongTensor ] = None ,
1261
- token_lengths : Optional [torch .LongTensor ] = None ,
1262
- attention_mask : Optional [torch .Tensor ] = None ,
1263
- overflow_to_sample_mapping : Optional [torch .Tensor ] = None ,
1264
- word_ids : Optional [torch .Tensor ] = None ,
1265
- langs : Optional [torch .Tensor ] = None ,
1266
- bbox : Optional [torch .Tensor ] = None ,
1267
- pixel_values : Optional [torch .Tensor ] = None ,
1259
+ self ,
1260
+ input_ids : torch .Tensor ,
1261
+ sub_token_lengths : Optional [torch .LongTensor ] = None ,
1262
+ token_lengths : Optional [torch .LongTensor ] = None ,
1263
+ attention_mask : Optional [torch .Tensor ] = None ,
1264
+ overflow_to_sample_mapping : Optional [torch .Tensor ] = None ,
1265
+ word_ids : Optional [torch .Tensor ] = None ,
1266
+ langs : Optional [torch .Tensor ] = None ,
1267
+ bbox : Optional [torch .Tensor ] = None ,
1268
+ pixel_values : Optional [torch .Tensor ] = None ,
1268
1269
):
1269
1270
model_kwargs = {}
1270
1271
if langs is not None :
@@ -1353,8 +1354,8 @@ def forward(
1353
1354
word_ids ,
1354
1355
token_lengths ,
1355
1356
)
1356
- all_token_embeddings [:, :, sentence_hidden_states .shape [2 ]:] = fill_masked_elements (
1357
- all_token_embeddings [:, :, sentence_hidden_states .shape [2 ]:],
1357
+ all_token_embeddings [:, :, sentence_hidden_states .shape [2 ] :] = fill_masked_elements (
1358
+ all_token_embeddings [:, :, sentence_hidden_states .shape [2 ] :],
1358
1359
sentence_hidden_states ,
1359
1360
last_mask ,
1360
1361
word_ids ,
@@ -1374,7 +1375,7 @@ def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
1374
1375
return self .forward (** tensors )
1375
1376
1376
1377
def export_onnx (
1377
- self , path : Union [str , Path ], example_sentences : List [Sentence ], ** kwargs
1378
+ self , path : Union [str , Path ], example_sentences : List [Sentence ], ** kwargs
1378
1379
) -> TransformerOnnxEmbeddings :
1379
1380
"""Export TransformerEmbeddings to OnnxFormat.
1380
1381
0 commit comments