Commit 7daacf0

Merge pull request huggingface#1695 from huggingface/models_inputs_embeds
model forwards can take an inputs_embeds param
2 parents a44f112 + 00337e9 commit 7daacf0

23 files changed: +385 -156 lines changed

templates/adding_a_new_model/modeling_tf_xxx.py

Lines changed: 4 additions & 0 deletions
@@ -255,6 +255,10 @@ class TFXxxPreTrainedModel(TFPreTrainedModel):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
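
A minimal sketch (not part of the diff) of producing the tensor that the new ``inputs_embeds`` docstring describes, by doing the embedding lookup manually with a stand-in matrix; the shapes and values below are made up for illustration.

# Sketch: building a (batch_size, sequence_length, embedding_dim) tensor
# to pass as `inputs_embeds` instead of `input_ids`.
import numpy as np
import tensorflow as tf

vocab_size, embedding_dim = 100, 16
embedding_matrix = tf.random.normal((vocab_size, embedding_dim))  # stand-in for the model's lookup table
input_ids = np.array([[5, 7, 2, 0]])                              # (batch_size=1, sequence_length=4)

# Same operation as the model's internal embedding lookup:
inputs_embeds = tf.gather(embedding_matrix, input_ids)
print(inputs_embeds.shape)                                        # (1, 4, 16)
# This tensor would then be handed to the model via `inputs_embeds` rather than `input_ids`.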

templates/adding_a_new_model/modeling_xxx.py

Lines changed: 16 additions & 9 deletions
@@ -238,6 +238,10 @@ def _init_weights(self, module):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
@@ -295,7 +299,7 @@ def _prune_heads(self, heads_to_prune):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -449,14 +453,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
 
         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
 
         pooled_output = outputs[1]
 
@@ -520,14 +525,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
 
         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
 
         sequence_output = outputs[0]
 
@@ -603,14 +609,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):
 
         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
 
         sequence_output = outputs[0]
 
transformers/modeling_bert.py

Lines changed: 60 additions & 31 deletions
@@ -158,19 +158,26 @@ def __init__(self, config):
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
-        seq_length = input_ids.size(1)
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
         if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
 
-        words_embeddings = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
-        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings
@@ -550,6 +557,10 @@ def _init_weights(self, module):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
         **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
             is configured as a decoder.
@@ -615,8 +626,8 @@ def _prune_heads(self, heads_to_prune):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
-                head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
+                head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
         """ Forward pass on the Model.
 
         The model can behave as an encoder (with only self-attention) as well
@@ -632,12 +643,23 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
             https://arxiv.org/abs/1706.03762
 
         """
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(input_shape)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones_like(input_ids)
+            encoder_attention_mask = torch.ones(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
@@ -649,8 +671,8 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
-                batch_size, seq_length = input_ids.size()
-                seq_ids = torch.arange(seq_length, device=input_ids.device)
+                batch_size, seq_length = input_shape
+                seq_ids = torch.arange(seq_length, device=device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
@@ -689,7 +711,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
         else:
             head_mask = [None] * self.config.num_hidden_layers
 
-        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
         encoder_outputs = self.encoder(embedding_output,
                                        attention_mask=extended_attention_mask,
                                        head_mask=head_mask,
@@ -754,14 +776,15 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None, next_sentence_label=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -829,14 +852,15 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
                             head_mask=head_mask,
+                            inputs_embeds=inputs_embeds,
                             encoder_hidden_states=encoder_hidden_states,
                             encoder_attention_mask=encoder_attention_mask)
 
@@ -908,14 +932,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 next_sentence_label=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         pooled_output = outputs[1]
 
@@ -975,14 +1000,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         pooled_output = outputs[1]
 
@@ -1049,8 +1075,8 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         num_choices = input_ids.shape[1]
 
         input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -1062,7 +1088,8 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         pooled_output = outputs[1]
 
@@ -1123,14 +1150,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         sequence_output = outputs[0]
 
@@ -1207,14 +1235,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
 
         sequence_output = outputs[0]
 