Add support for pretraining recurring span selection to Splinter (#17247)

* Add SplinterForSpanSelection for pre-training recurring span selection.

* Formatting.

* Rename SplinterForSpanSelection to SplinterForPreTraining.

* Ensure repo consistency

* Fixup changes

* Address SplinterForPreTraining PR comments

* Incorporate feedback and derive multiple question tokens per example.

* Update src/transformers/models/splinter/modeling_splinter.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/models/splinter/modeling_splinter.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Jean Vancoppenole <jean.vancoppenolle@retresco.de>
Co-authored-by: Tobias Günther <tobias.guenther@retresco.de>
Co-authored-by: Tobias Günther <github@tobigue.de>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
5 people authored May 17, 2022
1 parent 0511305 commit bad3583
Showing 6 changed files with 435 additions and 18 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/splinter.mdx
@@ -72,3 +72,8 @@ This model was contributed by [yuvalkirstain](https://huggingface.co/yuvalkirsta

[[autodoc]] SplinterForQuestionAnswering
    - forward

## SplinterForPreTraining

[[autodoc]] SplinterForPreTraining
    - forward
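The docs addition above only registers the new class with autodoc. As a quick orientation for readers, here is a minimal usage sketch of the new head; it is not part of this commit, and the `tau/splinter-base` checkpoint name and the `[QUESTION]` placeholders in the text are assumptions for illustration only.

```python
from transformers import SplinterForPreTraining, SplinterTokenizer

# Assumed checkpoint; any Splinter checkpoint with a question token works the same way.
tokenizer = SplinterTokenizer.from_pretrained("tau/splinter-base")
model = SplinterForPreTraining.from_pretrained("tau/splinter-base")

# Recurring span selection: occurrences of a recurring span are replaced by [QUESTION] tokens.
text = "[QUESTION] was founded in 1976. Among other products, [QUESTION] designs smartphones."
inputs = tokenizer(text, return_tensors="pt")

# question_positions is optional here; without labels, the model derives it from the [QUESTION] token ids.
outputs = model(**inputs)
print(outputs.start_logits.shape)  # (batch_size, num_questions, sequence_length)
```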
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1532,6 +1532,7 @@
_import_structure["models.splinter"].extend(
[
"SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
"SplinterForPreTraining",
"SplinterForQuestionAnswering",
"SplinterLayer",
"SplinterModel",
@@ -3830,6 +3831,7 @@
from .models.speech_to_text_2 import Speech2Text2ForCausalLM, Speech2Text2PreTrainedModel
from .models.splinter import (
    SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
    SplinterForPreTraining,
    SplinterForQuestionAnswering,
    SplinterLayer,
    SplinterModel,
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -161,6 +161,7 @@
("openai-gpt", "OpenAIGPTLMHeadModel"),
("retribert", "RetriBertModel"),
("roberta", "RobertaForMaskedLM"),
("splinter", "SplinterForPreTraining"),
("squeezebert", "SqueezeBertForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("tapas", "TapasForMaskedLM"),
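Because the new `("splinter", "SplinterForPreTraining")` entry sits in an auto-model mapping, Splinter checkpoints also become reachable through the corresponding auto class. A small sketch, assuming this is the pre-training mapping and again assuming the `tau/splinter-base` checkpoint name:

```python
from transformers import AutoModelForPreTraining, AutoTokenizer

# With the mapping entry in place, the auto class resolves Splinter checkpoints
# to the new pre-training head instead of raising an unsupported-model error.
tokenizer = AutoTokenizer.from_pretrained("tau/splinter-base")  # assumed checkpoint
model = AutoModelForPreTraining.from_pretrained("tau/splinter-base")
print(type(model).__name__)  # expected: SplinterForPreTraining
```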
2 changes: 2 additions & 0 deletions src/transformers/models/splinter/__init__.py
@@ -42,6 +42,7 @@
_import_structure["modeling_splinter"] = [
"SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
"SplinterForQuestionAnswering",
"SplinterForPreTraining",
"SplinterLayer",
"SplinterModel",
"SplinterPreTrainedModel",
@@ -68,6 +69,7 @@
else:
    from .modeling_splinter import (
        SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
        SplinterForPreTraining,
        SplinterForQuestionAnswering,
        SplinterLayer,
        SplinterModel,
171 changes: 170 additions & 1 deletion src/transformers/models/splinter/modeling_splinter.py
@@ -16,6 +16,7 @@


import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
@@ -24,7 +25,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -940,3 +941,171 @@ def forward(
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@dataclass
class SplinterForPreTrainingOutput(ModelOutput):
    """
    Class for outputs of Splinter as a span selection model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
            Total span extraction loss is the sum of the cross-entropy losses for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@add_start_docstrings(
    """
    Splinter Model for the recurring span selection task as done during the pretraining. The difference from the QA
    task is that we do not have a question, but multiple question tokens that replace the occurrences of recurring
    spans instead.
    """,
    SPLINTER_START_DOCSTRING,
)
class SplinterForPreTraining(SplinterPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.splinter = SplinterModel(config)
        self.splinter_qass = QuestionAwareSpanSelectionHead(config)
        self.question_token_id = config.question_token_id

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(
        SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length")
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        question_positions: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, SplinterForPreTrainingOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
            num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
            the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
            sequence_length)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if question_positions is None and start_positions is not None and end_positions is not None:
            raise TypeError("question_positions must be specified in order to calculate the loss")

        elif question_positions is None and input_ids is None:
            raise TypeError("question_positions must be specified when inputs_embeds is used")

        elif question_positions is None:
            question_positions = self._prepare_question_positions(input_ids)

        outputs = self.splinter(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        batch_size, sequence_length, dim = sequence_output.size()
        # [batch_size, num_questions, sequence_length]
        start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)

        num_questions = question_positions.size(1)
        if attention_mask is not None:
            attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
                batch_size, num_questions, sequence_length
            )
            start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0
            end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0

        total_loss = None
        # [batch_size, num_questions, sequence_length]
        if start_positions is not None and end_positions is not None:
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            start_positions.clamp_(0, max(0, sequence_length - 1))
            end_positions.clamp_(0, max(0, sequence_length - 1))

            # Ignore zero positions in the loss. Splinter never predicts zero
            # during pretraining and zero is used for padding question
            # tokens as well as for start and end positions of padded
            # question tokens.
            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            start_loss = loss_fct(
                start_logits.view(batch_size * num_questions, sequence_length),
                start_positions.view(batch_size * num_questions),
            )
            end_loss = loss_fct(
                end_logits.view(batch_size * num_questions, sequence_length),
                end_positions.view(batch_size * num_questions),
            )
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return SplinterForPreTrainingOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor:
        rows, flat_positions = torch.where(input_ids == self.config.question_token_id)
        num_questions = torch.bincount(rows)
        positions = torch.full(
            (input_ids.size(0), num_questions.max()),
            self.config.pad_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        cols = torch.cat([torch.arange(n) for n in num_questions])
        positions[rows, cols] = flat_positions
        return positions
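To make the shapes and the padding convention concrete, here is a small end-to-end sketch of the new forward pass. It is not part of the diff: it uses a randomly initialized model built from a default `SplinterConfig`, and the token and label positions are made up for illustration. Padded question slots use `pad_token_id` (0), which the loss ignores as described in the comments above.

```python
import torch
from transformers import SplinterConfig, SplinterForPreTraining

config = SplinterConfig()  # default config; question_token_id and pad_token_id come from here
model = SplinterForPreTraining(config)  # randomly initialized, for shape checking only

batch_size, sequence_length = 2, 16
input_ids = torch.zeros(batch_size, sequence_length, dtype=torch.long)
# Example 0 contains two [QUESTION] tokens, example 1 contains one.
input_ids[0, 3] = config.question_token_id
input_ids[0, 7] = config.question_token_id
input_ids[1, 5] = config.question_token_id

# (batch_size, num_questions); the second question slot of example 1 is padding.
question_positions = torch.tensor([[3, 7], [5, 0]])
start_positions = torch.tensor([[9, 11], [8, 0]])  # 0 marks a padded question and is ignored in the loss
end_positions = torch.tensor([[10, 12], [9, 0]])

outputs = model(
    input_ids=input_ids,
    question_positions=question_positions,
    start_positions=start_positions,
    end_positions=end_positions,
)
print(outputs.start_logits.shape)  # torch.Size([2, 2, 16])
print(outputs.loss)                # mean of start and end cross-entropy, padded slots excluded
```

Omitting `question_positions` (and the label tensors) makes the model call `_prepare_question_positions`, which recovers the same `[[3, 7], [5, 0]]` tensor from the `[QUESTION]` token ids above.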