From dd4aedab6b7a25c887ca8e715ec17d80ec71c81e Mon Sep 17 00:00:00 2001
From: Riccardo Orlando
Date: Fri, 16 Oct 2020 17:50:26 +0200
Subject: [PATCH] verb embeddings

---
 setup.py                           |  2 +-
 transformer_srl/dataset_readers.py |  6 +--
 transformer_srl/models.py          | 67 +++++++++---------------------
 3 files changed, 23 insertions(+), 52 deletions(-)

diff --git a/setup.py b/setup.py
index 8419906..9c8133b 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="transformer_srl",  # Replace with your own username
-    version="2.4rc6",
+    version="2.4rc7",
     author="Riccardo Orlando",
     author_email="orlandoricc@gmail.com",
     description="SRL Transformer model",
diff --git a/transformer_srl/dataset_readers.py b/transformer_srl/dataset_readers.py
index 28e5cd5..dc88dc2 100644
--- a/transformer_srl/dataset_readers.py
+++ b/transformer_srl/dataset_readers.py
@@ -6,7 +6,7 @@
 from allennlp.common.file_utils import cached_path
 from allennlp.data import Vocabulary
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
-from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField, ArrayField
+from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField
 from allennlp.data.instance import Instance
 from allennlp.data.token_indexers import (
     TokenIndexer,
@@ -18,7 +18,7 @@
 from conllu import parse_incr
 from overrides import overrides
 from transformers import AutoTokenizer
-import numpy as np
+
 from transformer_srl.utils import load_label_list
 
 logger = logging.getLogger(__name__)
@@ -294,7 +294,6 @@ def text_to_instance(  # type: ignore
         )
         new_verbs = _convert_verb_indices_to_wordpiece_indices(verb_label, offsets)
         frame_indicator = _convert_frames_indices_to_wordpiece_indices(verb_label, offsets, True)
-        sep_index = wordpieces.index(self.tokenizer.sep_token)
         metadata_dict["offsets"] = start_offsets
         # In order to override the indexing mechanism, we need to set the `text_id`
         # attribute directly. This causes the indexing to use this id.
@@ -309,7 +308,6 @@ def text_to_instance(  # type: ignore
             "tokens": text_field,
             "verb_indicator": verb_indicator,
             "frame_indicator": frame_indicator,
-            "sentence_end": ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64),
         }
 
         if all(x == 0 for x in verb_label):
diff --git a/transformer_srl/models.py b/transformer_srl/models.py
index a1fe3f6..496d7cc 100644
--- a/transformer_srl/models.py
+++ b/transformer_srl/models.py
@@ -84,10 +84,10 @@ def __init__(
         self.span_metric = None
         self.f1_frame_metric = FBetaMeasure(average="micro")
         self.tag_projection_layer = Linear(
-            self.transformer.config.hidden_size, self.num_classes
+            self.transformer.config.hidden_size * 2, self.num_classes
         )
         self.frame_projection_layer = Linear(
-            self.transformer.config.hidden_size, self.frame_num_classes
+            self.transformer.config.hidden_size * 2, self.frame_num_classes
         )
         self.embedding_dropout = Dropout(p=embedding_dropout)
         self._label_smoothing = label_smoothing
@@ -99,7 +99,6 @@ def forward(  # type: ignore
         tokens: TextFieldTensors,
         verb_indicator: torch.Tensor,
         frame_indicator: torch.Tensor,
-        sentence_end: torch.LongTensor,
         metadata: List[Any],
         tags: torch.LongTensor = None,
         frame_tags: torch.LongTensor = None,
@@ -143,57 +142,31 @@ def forward(  # type: ignore
         loss : `torch.FloatTensor`, optional
             A scalar loss to be optimised.
""" - # mask = get_text_field_mask(tokens) - # input_ids = util.get_token_ids_from_text_field_tensors(tokens) - # bert_embeddings, _ = self.transformer( - # input_ids=input_ids, - # # token_type_ids=verb_indicator, - # attention_mask=mask, - # ) - # # get sizes - # batch_size, sequence_length, _ = bert_embeddings.size() - - # # verb emeddings - # verb_embeddings = bert_embeddings[ - # torch.arange(batch_size).to(bert_embeddings.device), verb_indicator.argmax(1), : - # ] - # verb_embeddings = torch.where( - # (verb_indicator.sum(1, keepdim=True) > 0).repeat(1, verb_embeddings.shape[-1]), - # verb_embeddings, - # torch.zeros_like(verb_embeddings), - # ) - # bert_embeddings = torch.cat( - # (bert_embeddings, verb_embeddings.unsqueeze(1).repeat(1, bert_embeddings.shape[1], 1)), - # dim=2, - # ) - # mask = tokens["tokens"]["mask"] - # index = mask.sum(1).argmax().item() - mask = get_text_field_mask(tokens) + input_ids = util.get_token_ids_from_text_field_tensors(tokens) bert_embeddings, _ = self.transformer( - input_ids=util.get_token_ids_from_text_field_tensors(tokens), + input_ids=input_ids, # token_type_ids=verb_indicator, attention_mask=mask, ) - - batch_size, _ = mask.size() - embedded_text_input = self.embedding_dropout(bert_embeddings) - # Restrict to sentence part - sentence_mask = ( - torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device) - < sentence_end.unsqueeze(1).repeat(1, mask.shape[1]) - ).long() - cutoff = sentence_end.max().item() - - # encoded_text = embedded_text_input - mask = sentence_mask[:, :cutoff].contiguous() - embedded_text_input = embedded_text_input[:, :cutoff, :] - tags = tags[:, :cutoff].contiguous() - - sequence_length = embedded_text_input.shape[1] + # get sizes + batch_size, sequence_length, _ = bert_embeddings.size() + # verb emeddings + verb_embeddings = bert_embeddings[ + torch.arange(batch_size).to(bert_embeddings.device), verb_indicator.argmax(1), : + ] + verb_embeddings = torch.where( + (verb_indicator.sum(1, keepdim=True) > 0).repeat(1, verb_embeddings.shape[-1]), + verb_embeddings, + torch.zeros_like(verb_embeddings), + ) + bert_embeddings = torch.cat( + (bert_embeddings, verb_embeddings.unsqueeze(1).repeat(1, bert_embeddings.shape[1], 1)), + dim=2, + ) # extract embeddings - # embedded_text_input = self.embedding_dropout(bert_embeddings) + embedded_text_input = self.embedding_dropout(bert_embeddings) frame_embeddings = embedded_text_input[frame_indicator == 1] # outputs logits = self.tag_projection_layer(embedded_text_input)