From 5dc22abf762d17566e02714d9632c3bd7b044477 Mon Sep 17 00:00:00 2001
From: Riccardo Orlando <orlandoricc@gmail.com>
Date: Thu, 17 Feb 2022 16:29:51 +0100
Subject: [PATCH] Feed predicate tokens to the encoder, add a BiLSTM tag head,
 and add CITATION.cff

---
 .github/workflows/python-publish.yml |  6 ++--
 CITATION.cff                         | 13 +++++++
 scripts/bert_base_span.sh            |  2 +-
 setup.py                             |  2 +-
 transformer_srl/dataset_readers.py   | 52 +++++++++++++++--------------
 transformer_srl/models.py            | 43 ++++++++++++++++++------
 6 files changed, 80 insertions(+), 38 deletions(-)
 create mode 100644 CITATION.cff

diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index a369c2d..1aad0c9 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -4,9 +4,9 @@
 name: Upload Python Package
 
 on:
-  push
-  # release:
-  #   types: [created]
+# push
+ release:
+   types: [created]
 
 jobs:
   deploy:
diff --git a/CITATION.cff b/CITATION.cff
new file mode 100644
index 0000000..7626d16
--- /dev/null
+++ b/CITATION.cff
@@ -0,0 +1,13 @@
+# This CITATION.cff file was generated with cffinit.
+# Visit https://bit.ly/cffinit to generate yours today!
+
+cff-version: 1.2.0
+title: transformer-srl
+message: >-
+  If you use this software, please cite it using the
+  metadata from this file.
+type: software
+authors:
+  - email: orlandoricc@gmail.com
+    given-names: Riccardo
+    family-names: Orlando
diff --git a/scripts/bert_base_span.sh b/scripts/bert_base_span.sh
index d7f652b..99dd693 100755
--- a/scripts/bert_base_span.sh
+++ b/scripts/bert_base_span.sh
@@ -2,7 +2,7 @@
 source /Users/ric/mambaforge/bin/activate srl-mt
 
 #HOME="/home/orlando"
-DATASET="/Users/ric/Documents/ComputerScience/Projects/transformer-srl/data/conll2012_pb"
+DATASET="/Users/ric/Documents/ComputerScience/Projects/transformer-srl/data/conll2012_pb_subset/"
 PROJECT="/Users/ric/Documents/ComputerScience/Projects/transformer-srl"
 # local
 # DATASET="/mnt/d/Datasets/conll2012/conll-formatted-ontonotes-verbatlas-subset"
diff --git a/setup.py b/setup.py
index df1f246..9dc016b 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="transformer_srl",  # Replace with your own username
-    version="2.5",
+    version="2.5.2",
     author="Riccardo Orlando",
     author_email="orlandoricc@gmail.com",
     description="SRL Transformer model",
diff --git a/transformer_srl/dataset_readers.py b/transformer_srl/dataset_readers.py
index 2b79043..930a8f1 100644
--- a/transformer_srl/dataset_readers.py
+++ b/transformer_srl/dataset_readers.py
@@ -2,10 +2,11 @@
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Set, Tuple
 
+import numpy as np
 from allennlp.common.file_utils import cached_path
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
 from allennlp.data.dataset_readers.dataset_utils.span_utils import TypedSpan
-from allennlp.data.fields import Field, MetadataField, SequenceLabelField, TextField
+from allennlp.data.fields import Field, MetadataField, SequenceLabelField, TextField, ArrayField
 from allennlp.data.instance import Instance
 from allennlp.data.token_indexers import PretrainedTransformerIndexer, TokenIndexer
 from allennlp.data.tokenizers import Token
@@ -14,7 +15,7 @@
 from conllu import parse_incr
 from nltk import Tree
 from overrides import overrides
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, XLMRobertaTokenizer
 
 logger = logging.getLogger(__name__)
 
@@ -204,22 +205,19 @@ def _read(self, file_path: str):
         for sentence in self._ontonotes_subset(
             ontonotes_reader, file_path, self._domain_identifier
         ):
-            try:
-                tokens = [Token(t) for t in sentence.words]
-                sentence_id = sentence.sentence_id
-                if sentence.srl_frames:
-                    for (_, tags) in sentence.srl_frames:
-                        verb_indicator = [1 if label[-2:] == "-V" else 0 for label in tags]
-                        frames = self._get_predicate_labels(sentence, verb_indicator)
-                        lemmas = [
-                            f for f, v in zip(sentence.predicate_lemmas, verb_indicator) if v == 1
-                        ]
-                        if not all(v == 0 for v in verb_indicator):
-                            yield self.text_to_instance(
-                                tokens, verb_indicator, frames, lemmas, tags, sentence_id
-                            )
-            except:
-                print(sentence.sentence_id)
+            tokens = [Token(t) for t in sentence.words]
+            sentence_id = sentence.sentence_id
+            if sentence.srl_frames:
+                for (_, tags) in sentence.srl_frames:
+                    verb_indicator = [1 if label[-2:] == "-V" else 0 for label in tags]
+                    frames = self._get_predicate_labels(sentence, verb_indicator)
+                    lemmas = [
+                        f for f, v in zip(sentence.predicate_lemmas, verb_indicator) if v == 1
+                    ]
+                    if not all(v == 0 for v in verb_indicator):
+                        yield self.text_to_instance(
+                            tokens, verb_indicator, frames, lemmas, tags, sentence_id
+                        )
 
     def text_to_instance(  # type: ignore
         self,
@@ -244,13 +242,13 @@ def text_to_instance(  # type: ignore
         frame_indicator = _convert_frames_indices_to_wordpiece_indices(verb_label, offsets, True)
 
         # add verb as information to the model
-        # verb_tokens = [token for token, v in zip(wordpieces, new_verbs) if v == 1]
-        # verb_tokens = verb_tokens + [self.tokenizer.sep_token]
-        # if isinstance(self.tokenizer, XLMRobertaTokenizer):
-        #     verb_tokens = [self.tokenizer.sep_token] + verb_tokens
-        # wordpieces += verb_tokens
-        # new_verbs += [0 for _ in range(len(verb_tokens))]
-        # frame_indicator += [0 for _ in range(len(verb_tokens))]
+        verb_tokens = [token for token, v in zip(wordpieces, new_verbs) if v == 1]
+        verb_tokens = verb_tokens + [self.tokenizer.sep_token]
+        if isinstance(self.tokenizer, XLMRobertaTokenizer):
+            verb_tokens = [self.tokenizer.sep_token] + verb_tokens
+        wordpieces += verb_tokens
+        new_verbs += [0 for _ in range(len(verb_tokens))]
+        frame_indicator += [0 for _ in range(len(verb_tokens))]
         # In order to override the indexing mechanism, we need to set the `text_id`
         # attribute directly. This causes the indexing to use this id.
         text_field = TextField(
@@ -260,12 +258,16 @@ def text_to_instance(  # type: ignore
         verb_indicator = SequenceLabelField(new_verbs, text_field)
         frame_indicator = SequenceLabelField(frame_indicator, text_field)
 
+        # the first [SEP]; sentence_end = sep_index + 1 marks the appended predicate
+        sep_index = wordpieces.index(self.tokenizer.sep_token)
+
         metadata_dict["offsets"] = start_offsets
 
         fields: Dict[str, Field] = {
             "tokens": text_field,
             "verb_indicator": verb_indicator,
             "frame_indicator": frame_indicator,
+            "sentence_end": ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64),
         }
 
         if all(x == 0 for x in verb_label):
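
Note on the dataset_readers.py hunks above: the reader now appends the predicate's wordpieces, followed by a separator, after the sentence, and records in "sentence_end" the position one past the first [SEP], i.e. the start of the appended predicate. A minimal sketch of the resulting input layout, assuming a BERT-style tokenizer; the build_inputs helper is illustrative and not part of the patch:

    from transformers import AutoTokenizer

    def build_inputs(words, verb_indicator, tokenizer):
        # sentence wordpieces wrapped in [CLS] ... [SEP]
        wordpieces = [tokenizer.cls_token]
        for word in words:
            wordpieces += tokenizer.tokenize(word)
        wordpieces += [tokenizer.sep_token]
        # wordpieces of the marked predicate, appended after the sentence
        verb_tokens = []
        for word, v in zip(words, verb_indicator):
            if v == 1:
                verb_tokens += tokenizer.tokenize(word)
        wordpieces += verb_tokens + [tokenizer.sep_token]
        # one past the first [SEP]: the first appended predicate wordpiece
        sentence_end = wordpieces.index(tokenizer.sep_token) + 1
        return wordpieces, sentence_end

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    pieces, sentence_end = build_inputs(["The", "keys", "are", "missing"], [0, 0, 1, 0], tokenizer)
    print(pieces)        # ['[CLS]', 'The', 'keys', 'are', 'missing', '[SEP]', 'are', '[SEP]']
    print(sentence_end)  # 6

For XLM-RoBERTa the reader inserts an extra separator before the predicate tokens, matching that model's two-separator pair encoding.
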
diff --git a/transformer_srl/models.py b/transformer_srl/models.py
index b566f5c..f2508a6 100644
--- a/transformer_srl/models.py
+++ b/transformer_srl/models.py
@@ -15,7 +15,7 @@
 )
 from overrides import overrides
 from torch import nn
-from transformers import AutoModel
+from transformers import AutoModel, AutoConfig
 
 from transformer_srl.utils import load_label_list
 
@@ -49,6 +49,8 @@ def __init__(
         vocab: Vocabulary,
         bert_model: Union[str, AutoModel],
         embedding_dropout: float = 0.0,
+        num_lstms: int = 2,
+        dense_units: int = 300,
         initializer: InitializerApplicator = InitializerApplicator(),
         label_smoothing: float = None,
         ignore_span_metric: bool = False,
@@ -58,7 +60,8 @@ def __init__(
     ) -> None:
         # bypass SrlBert constructor
         Model.__init__(self, vocab, **kwargs)
-        self.transformer = AutoModel.from_pretrained(bert_model)
+        self.tr_config = AutoConfig.from_pretrained(bert_model, output_hidden_states=True)
+        self.transformer = AutoModel.from_pretrained(bert_model, config=self.tr_config)
         self.frame_criterion = nn.CrossEntropyLoss()
         if inventory == "verbatlas":
             # add missing frame labels
@@ -76,9 +79,21 @@ def __init__(
         else:
             self.span_metric = None
         self.f1_frame_metric = FBetaMeasure(average="micro")
-        self.tag_projection_layer = nn.Linear(self.transformer.config.hidden_size, self.num_classes)
+        self.lstms = nn.LSTM(
+            self.tr_config.hidden_size,
+            self.tr_config.hidden_size,
+            num_layers=num_lstms,
+            dropout=0.2 if num_lstms > 1 else 0,
+            bidirectional=True,
+            batch_first=True,  # inputs are (batch, sequence, hidden)
+        )
+        self.tag_projection_layer = torch.nn.Sequential(
+            nn.Linear(self.tr_config.hidden_size * 2, dense_units),
+            nn.ReLU(),
+            nn.Linear(dense_units, self.num_classes),
+        )
         self.frame_projection_layer = nn.Linear(
-            self.transformer.config.hidden_size, self.frame_num_classes
+            self.tr_config.hidden_size, self.frame_num_classes
         )
         self.embedding_dropout = nn.Dropout(p=embedding_dropout)
         self._label_smoothing = label_smoothing
@@ -89,6 +103,7 @@ def forward(  # type: ignore
         self,
         tokens: TextFieldTensors,
         verb_indicator: torch.Tensor,
+        sentence_end: torch.LongTensor,
         frame_indicator: torch.Tensor,
         metadata: List[Any],
         tags: torch.LongTensor = None,
@@ -135,17 +150,28 @@ def forward(  # type: ignore
         """
         mask = get_text_field_mask(tokens)
         input_ids = util.get_token_ids_from_text_field_tensors(tokens)
-        bert_embeddings, _ = self.transformer(
+        # encoders without segment embeddings (e.g. XLM-R) only accept all-zero token type ids
+        if self.tr_config.type_vocab_size == 1:
+            verb_indicator = torch.zeros_like(verb_indicator)
+        embeddings = self.transformer(
             input_ids=input_ids,
             token_type_ids=verb_indicator,
             attention_mask=mask,
             return_dict=False,
         )
+        # sum the last four hidden layers into a single token representation
+        embeddings = embeddings[2][-4:]
+        embeddings = torch.stack(embeddings, dim=0).sum(dim=0)
+        # get sizes
+        batch_size, sequence_length, _ = embeddings.size()
         # extract embeddings
-        embedded_text_input = self.embedding_dropout(bert_embeddings)
+        embedded_text_input = self.embedding_dropout(embeddings)
         frame_embeddings = embedded_text_input[frame_indicator == 1]
-        # get sizes
-        batch_size, sequence_length, _ = embedded_text_input.size()
+        # embedding of the appended predicate, one position past the first [SEP]
+        predicate_embeddings = embedded_text_input[torch.arange(batch_size, device=embedded_text_input.device), sentence_end]
+        # embedded_text_input = embedded_text_input + predicate_embeddings.unsqueeze(1)
+        # lstm pass
+        embedded_text_input = self.lstms(embedded_text_input)[0]
         # outputs
         logits = self.tag_projection_layer(embedded_text_input)
         frame_logits = self.frame_projection_layer(frame_embeddings)
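
The models.py hunk relies on two small PyTorch idioms: summing the last four hidden layers returned when output_hidden_states=True, and a batched gather of the embedding at each sentence_end position. A self-contained sketch with illustrative shapes and values (not part of the patch):

    import torch

    batch_size, seq_len, hidden_size, num_layers = 2, 8, 16, 12

    # stand-in for the per-layer hidden states the encoder returns when
    # output_hidden_states=True (embedding layer + one tensor per layer)
    hidden_states = tuple(
        torch.randn(batch_size, seq_len, hidden_size) for _ in range(num_layers + 1)
    )

    # sum of the last four layers, as in forward()
    token_embeddings = torch.stack(hidden_states[-4:], dim=0).sum(dim=0)  # (batch, seq, hidden)

    # one predicate position per batch element, as carried by the sentence_end field
    sentence_end = torch.tensor([5, 3])
    predicate_embeddings = token_embeddings[torch.arange(batch_size), sentence_end]
    print(predicate_embeddings.shape)  # torch.Size([2, 16])

Summing the last four hidden layers is a common feature-extraction heuristic popularized by the BERT paper; here it replaces the single last hidden state the model consumed before this patch.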