diff --git a/setup.py b/setup.py
index 6df5c91..8419906 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="transformer_srl",  # Replace with your own username
-    version="2.4rc5",
+    version="2.4rc6",
     author="Riccardo Orlando",
     author_email="orlandoricc@gmail.com",
     description="SRL Transformer model",
diff --git a/transformer_srl/dataset_readers.py b/transformer_srl/dataset_readers.py
index dc88dc2..28e5cd5 100644
--- a/transformer_srl/dataset_readers.py
+++ b/transformer_srl/dataset_readers.py
@@ -6,7 +6,7 @@
 from allennlp.common.file_utils import cached_path
 from allennlp.data import Vocabulary
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
-from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField
+from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField, ArrayField
 from allennlp.data.instance import Instance
 from allennlp.data.token_indexers import (
     TokenIndexer,
@@ -18,7 +18,7 @@
 from conllu import parse_incr
 from overrides import overrides
 from transformers import AutoTokenizer
-
+import numpy as np
 from transformer_srl.utils import load_label_list
 
 logger = logging.getLogger(__name__)
@@ -294,6 +294,7 @@ def text_to_instance(  # type: ignore
         )
         new_verbs = _convert_verb_indices_to_wordpiece_indices(verb_label, offsets)
         frame_indicator = _convert_frames_indices_to_wordpiece_indices(verb_label, offsets, True)
+        sep_index = wordpieces.index(self.tokenizer.sep_token)
         metadata_dict["offsets"] = start_offsets
         # In order to override the indexing mechanism, we need to set the `text_id`
         # attribute directly. This causes the indexing to use this id.
@@ -308,6 +309,7 @@ def text_to_instance(  # type: ignore
             "tokens": text_field,
             "verb_indicator": verb_indicator,
             "frame_indicator": frame_indicator,
+            "sentence_end": ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64),
         }
 
         if all(x == 0 for x in verb_label):
diff --git a/transformer_srl/models.py b/transformer_srl/models.py
index 496d7cc..a1fe3f6 100644
--- a/transformer_srl/models.py
+++ b/transformer_srl/models.py
@@ -84,10 +84,10 @@ def __init__(
         self.span_metric = None
         self.f1_frame_metric = FBetaMeasure(average="micro")
         self.tag_projection_layer = Linear(
-            self.transformer.config.hidden_size * 2, self.num_classes
+            self.transformer.config.hidden_size, self.num_classes
         )
         self.frame_projection_layer = Linear(
-            self.transformer.config.hidden_size * 2, self.frame_num_classes
+            self.transformer.config.hidden_size, self.frame_num_classes
         )
         self.embedding_dropout = Dropout(p=embedding_dropout)
         self._label_smoothing = label_smoothing
@@ -99,6 +99,7 @@ def forward(  # type: ignore
         tokens: TextFieldTensors,
         verb_indicator: torch.Tensor,
         frame_indicator: torch.Tensor,
+        sentence_end: torch.LongTensor,
         metadata: List[Any],
         tags: torch.LongTensor = None,
         frame_tags: torch.LongTensor = None,
@@ -142,31 +143,57 @@ def forward(  # type: ignore
         loss : `torch.FloatTensor`, optional
             A scalar loss to be optimised.
         """
+        # mask = get_text_field_mask(tokens)
+        # input_ids = util.get_token_ids_from_text_field_tensors(tokens)
+        # bert_embeddings, _ = self.transformer(
+        #     input_ids=input_ids,
+        #     # token_type_ids=verb_indicator,
+        #     attention_mask=mask,
+        # )
+        # # get sizes
+        # batch_size, sequence_length, _ = bert_embeddings.size()
+
+        # # verb emeddings
+        # verb_embeddings = bert_embeddings[
+        #     torch.arange(batch_size).to(bert_embeddings.device), verb_indicator.argmax(1), :
+        # ]
+        # verb_embeddings = torch.where(
+        #     (verb_indicator.sum(1, keepdim=True) > 0).repeat(1, verb_embeddings.shape[-1]),
+        #     verb_embeddings,
+        #     torch.zeros_like(verb_embeddings),
+        # )
+        # bert_embeddings = torch.cat(
+        #     (bert_embeddings, verb_embeddings.unsqueeze(1).repeat(1, bert_embeddings.shape[1], 1)),
+        #     dim=2,
+        # )
+        # mask = tokens["tokens"]["mask"]
+        # index = mask.sum(1).argmax().item()
+
+
         mask = get_text_field_mask(tokens)
-        input_ids = util.get_token_ids_from_text_field_tensors(tokens)
         bert_embeddings, _ = self.transformer(
-            input_ids=input_ids,
+            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
             # token_type_ids=verb_indicator,
             attention_mask=mask,
         )
-        # get sizes
-        batch_size, sequence_length, _ = bert_embeddings.size()
-
-        # verb emeddings
-        verb_embeddings = bert_embeddings[
-            torch.arange(batch_size).to(bert_embeddings.device), verb_indicator.argmax(1), :
-        ]
-        verb_embeddings = torch.where(
-            (verb_indicator.sum(1, keepdim=True) > 0).repeat(1, verb_embeddings.shape[-1]),
-            verb_embeddings,
-            torch.zeros_like(verb_embeddings),
-        )
-        bert_embeddings = torch.cat(
-            (bert_embeddings, verb_embeddings.unsqueeze(1).repeat(1, bert_embeddings.shape[1], 1)),
-            dim=2,
-        )
-        # extract embeddings
+        batch_size, _ = mask.size()
         embedded_text_input = self.embedding_dropout(bert_embeddings)
+        # Restrict to sentence part
+        sentence_mask = (
+            torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device)
+            < sentence_end.unsqueeze(1).repeat(1, mask.shape[1])
+        ).long()
+        cutoff = sentence_end.max().item()
+
+        # encoded_text = embedded_text_input
+        mask = sentence_mask[:, :cutoff].contiguous()
+        embedded_text_input = embedded_text_input[:, :cutoff, :]
+        tags = tags[:, :cutoff].contiguous()
+
+        sequence_length = embedded_text_input.shape[1]
+
+        # extract embeddings
+        # embedded_text_input = self.embedding_dropout(bert_embeddings)
         frame_embeddings = embedded_text_input[frame_indicator == 1]
         # outputs
         logits = self.tag_projection_layer(embedded_text_input)