@@ -158,19 +158,26 @@ def __init__(self, config):
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
-        seq_length = input_ids.size(1)
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
         if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
 
-        words_embeddings = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
-        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings
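The rewritten BertEmbeddings.forward above derives the (batch_size, sequence_length) shape and the target device from whichever of input_ids or inputs_embeds was passed, and only falls back to the word-embedding lookup when no precomputed embeddings are given. A minimal sketch of that shape/device resolution outside the model (tensor sizes are illustrative assumptions, not taken from the diff):

import torch

batch_size, seq_len, embedding_dim = 2, 8, 768      # illustrative sizes
inputs_embeds = torch.randn(batch_size, seq_len, embedding_dim)

# Dropping the trailing embedding dimension recovers the shape that
# input_ids.size() would have had: (batch_size, sequence_length).
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = inputs_embeds.device

position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
token_type_ids = torch.zeros(input_shape, dtype=torch.long)

print(input_shape, position_ids.shape, token_type_ids.shape)
# torch.Size([2, 8]) torch.Size([2, 8]) torch.Size([2, 8])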
@@ -550,6 +557,10 @@ def _init_weights(self, module):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
        **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
            is configured as a decoder.
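As a usage illustration of the new argument (a hedged sketch: the checkpoint name and input text are assumptions, not part of this diff), the word embeddings can be looked up manually and passed in place of input_ids, e.g. to perturb or mix them before the encoder runs:

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("Hello, world")])

# Same lookup the model performs internally (self.word_embeddings in BertEmbeddings).
inputs_embeds = model.embeddings.word_embeddings(input_ids)

outputs_from_ids = model(input_ids=input_ids)
outputs_from_embeds = model(inputs_embeds=inputs_embeds)

# Both calls route through the same BertEmbeddings.forward; passing input_ids and
# inputs_embeds together raises a ValueError, as enforced in BertModel.forward below.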
@@ -615,8 +626,8 @@ def _prune_heads(self, heads_to_prune):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
-                head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
+                head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
         """ Forward pass on the Model.
 
         The model can behave as an encoder (with only self-attention) as well
@@ -632,12 +643,23 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
         https://arxiv.org/abs/1706.03762
 
         """
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(input_shape)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones_like(input_ids)
+            encoder_attention_mask = torch.ones(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
@@ -649,8 +671,8 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
-                batch_size, seq_length = input_ids.size()
-                seq_ids = torch.arange(seq_length, device=input_ids.device)
+                batch_size, seq_length = input_shape
+                seq_ids = torch.arange(seq_length, device=device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
@@ -689,7 +711,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
         else:
             head_mask = [None] * self.config.num_hidden_layers
 
-        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
         encoder_outputs = self.encoder(embedding_output,
                                        attention_mask=extended_attention_mask,
                                        head_mask=head_mask,
@@ -754,14 +776,15 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None, next_sentence_label=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -829,14 +852,15 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
                             head_mask=head_mask,
+                            inputs_embeds=inputs_embeds,
                             encoder_hidden_states=encoder_hidden_states,
                             encoder_attention_mask=encoder_attention_mask)
 
@@ -908,14 +932,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 next_sentence_label=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         pooled_output = outputs[1]
 
@@ -975,14 +1000,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         pooled_output = outputs[1]
 
@@ -1049,8 +1075,8 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         num_choices = input_ids.shape[1]
 
         input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -1062,7 +1088,8 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         pooled_output = outputs[1]
 
@@ -1123,14 +1150,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         sequence_output = outputs[0]
 
@@ -1207,14 +1235,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         sequence_output = outputs[0]
 