
model forwards can take an inputs_embeds param #1695

Merged · 3 commits · Nov 5, 2019
4 changes: 4 additions & 0 deletions templates/adding_a_new_model/modeling_tf_xxx.py
@@ -255,6 +255,10 @@ class TFXxxPreTrainedModel(TFPreTrainedModel):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert ``input_ids`` indices into associated vectors
than the model's internal embedding lookup matrix provides.
"""

@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
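The TF template change above only documents the new inputs_embeds argument and its shape contract. A minimal sketch of that contract follows; the vocabulary size, embedding dimension and token ids are made up for illustration, and the final model call is only indicated in a comment because this diff does not yet wire inputs_embeds into the TF forward.

import numpy as np
import tensorflow as tf

# Hypothetical sizes, for illustration only.
vocab_size, embedding_dim = 100, 16
batch_size, sequence_length = 2, 8

embedding_matrix = tf.random.uniform((vocab_size, embedding_dim))
input_ids = np.random.randint(0, vocab_size, size=(batch_size, sequence_length))

# Instead of handing input_ids to the model, look the vectors up yourself.
inputs_embeds = tf.nn.embedding_lookup(embedding_matrix, input_ids)
print(inputs_embeds.shape)  # (batch_size, sequence_length, embedding_dim)

# A model following this template would then receive the precomputed
# embeddings instead of input_ids, e.g. (hypothetical call):
# outputs = model(None, inputs_embeds=inputs_embeds)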
25 changes: 16 additions & 9 deletions templates/adding_a_new_model/modeling_xxx.py
@@ -238,6 +238,10 @@ def _init_weights(self, module):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert ``input_ids`` indices into associated vectors
than the model's internal embedding lookup matrix provides.
"""

@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
@@ -295,7 +299,7 @@ def _prune_heads(self, heads_to_prune):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@@ -449,14 +453,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

pooled_output = outputs[1]

@@ -520,14 +525,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

sequence_output = outputs[0]

@@ -603,14 +609,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):

outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

sequence_output = outputs[0]

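With the PyTorch template, every forward now accepts either input_ids or a precomputed inputs_embeds tensor. A minimal sketch of the intended call pattern, assuming a model generated from this template; XxxConfig and XxxModel are hypothetical stand-ins, and the embeddings attribute is assumed to follow the BERT-like layout used elsewhere in this diff.

import torch

config = XxxConfig()                       # hypothetical config class
model = XxxModel(config)                   # hypothetical template-generated model
model.eval()

input_ids = torch.tensor([[7, 1, 42, 3]])  # made-up token ids

# Usual path: the model embeds input_ids internally.
outputs = model(input_ids=input_ids)

# New path: embed the ids yourself and pass the vectors directly, e.g. to
# perturb them in embedding space before the forward pass.
inputs_embeds = model.embeddings.word_embeddings(input_ids)
outputs = model(inputs_embeds=inputs_embeds)

sequence_output = outputs[0]               # (batch_size, seq_len, hidden_size)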
91 changes: 60 additions & 31 deletions transformers/modeling_bert.py
@@ -158,19 +158,26 @@ def __init__(self, config):
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]

seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

words_embeddings = self.word_embeddings(input_ids)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
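The BertEmbeddings change above skips the word-embedding lookup when inputs_embeds is supplied, while position and token-type embeddings, LayerNorm and dropout are still applied. A quick equivalence check of the two paths; this is a sketch, not part of the diff, and assumes the bert-base-uncased weights can be loaded.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()  # disable dropout so both paths are deterministic

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # made-up token ids

with torch.no_grad():
    # Path 1: let BertEmbeddings look the ids up itself.
    from_ids = model.embeddings(input_ids=input_ids)
    # Path 2: do the lookup manually and feed the vectors back in.
    word_vectors = model.embeddings.word_embeddings(input_ids)
    from_embeds = model.embeddings(inputs_embeds=word_vectors)

print(torch.allclose(from_ids, from_embeds))  # expected: True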
@@ -550,6 +557,10 @@ def _init_weights(self, module):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert ``input_ids`` indices into associated vectors
than the model's internal embedding lookup matrix provides.
**encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
is configured as a decoder.
@@ -615,8 +626,8 @@ def _prune_heads(self, heads_to_prune):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
""" Forward pass on the Model.

The model can behave as an encoder (with only self-attention) as well
@@ -632,12 +643,23 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
https://arxiv.org/abs/1706.03762

"""
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

device = input_ids.device if input_ids is not None else inputs_embeds.device

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = torch.ones(input_shape, device=device)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones_like(input_ids)
encoder_attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
Expand All @@ -649,8 +671,8 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if attention_mask.dim() == 2:
if self.config.is_decoder:
batch_size, seq_length = input_ids.size()
seq_ids = torch.arange(seq_length, device=input_ids.device)
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
@@ -689,7 +711,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
else:
head_mask = [None] * self.config.num_hidden_layers

embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
encoder_outputs = self.encoder(embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
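BertModel.forward now infers the input shape and device from whichever of input_ids or inputs_embeds is given, and rejects the ambiguous cases. A short sketch of the resulting contract (not part of the diff; again assumes the bert-base-uncased checkpoint):

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # made-up token ids
inputs_embeds = model.embeddings.word_embeddings(input_ids)

outputs = model(input_ids=input_ids)          # fine: ids are embedded internally
outputs = model(inputs_embeds=inputs_embeds)  # fine: embeddings passed directly

try:
    model(input_ids=input_ids, inputs_embeds=inputs_embeds)
except ValueError as err:
    print(err)  # You cannot specify both input_ids and inputs_embeds at the same time

try:
    model()
except ValueError as err:
    print(err)  # You have to specify either input_ids or inputs_embeds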
@@ -754,14 +776,15 @@ def __init__(self, config):
def get_output_embeddings(self):
return self.cls.predictions.decoder

def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
masked_lm_labels=None, next_sentence_label=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -829,14 +852,15 @@ def __init__(self, config):
def get_output_embeddings(self):
return self.cls.predictions.decoder

def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask)

@@ -908,14 +932,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
next_sentence_label=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

pooled_output = outputs[1]

@@ -975,14 +1000,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

pooled_output = outputs[1]

@@ -1049,8 +1075,8 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
num_choices = input_ids.shape[1]

input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -1062,7 +1088,8 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

pooled_output = outputs[1]

@@ -1123,14 +1150,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

sequence_output = outputs[0]

@@ -1207,14 +1235,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)

sequence_output = outputs[0]
