Commit 5dc22ab: Update

Riccorl committed Feb 17, 2022
1 parent 0d2ec62
Showing 6 changed files with 77 additions and 38 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python-publish.yml
@@ -4,9 +4,9 @@
name: Upload Python Package

on:
push
# release:
# types: [created]
# push
release:
types: [created]

jobs:
deploy:
13 changes: 13 additions & 0 deletions CITATION.cff
@@ -0,0 +1,13 @@
# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!

cff-version: 1.2.0
title: transformer-srl
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- email: orlandoricc@gmail.com
given-names: Riccardo
family-names: Orlando
2 changes: 1 addition & 1 deletion scripts/bert_base_span.sh
@@ -2,7 +2,7 @@
source /Users/ric/mambaforge/bin/activate srl-mt

#HOME="/home/orlando"
DATASET="/Users/ric/Documents/ComputerScience/Projects/transformer-srl/data/conll2012_pb"
DATASET="/Users/ric/Documents/ComputerScience/Projects/transformer-srl/data/conll2012_pb_subset/"
PROJECT="/Users/ric/Documents/ComputerScience/Projects/transformer-srl"
# local
# DATASET="/mnt/d/Datasets/conll2012/conll-formatted-ontonotes-verbatlas-subset"
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
name="transformer_srl", # Replace with your own username
version="2.5",
version="2.5.2",
author="Riccardo Orlando",
author_email="orlandoricc@gmail.com",
description="SRL Transformer model",
51 changes: 26 additions & 25 deletions transformer_srl/dataset_readers.py
@@ -2,10 +2,11 @@
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple

import numpy as np
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_utils.span_utils import TypedSpan
from allennlp.data.fields import Field, MetadataField, SequenceLabelField, TextField
from allennlp.data.fields import Field, MetadataField, SequenceLabelField, TextField, ArrayField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import PretrainedTransformerIndexer, TokenIndexer
from allennlp.data.tokenizers import Token
@@ -14,7 +15,7 @@
from conllu import parse_incr
from nltk import Tree
from overrides import overrides
from transformers import AutoTokenizer
from transformers import AutoTokenizer, XLMRobertaTokenizer

logger = logging.getLogger(__name__)

@@ -204,22 +205,19 @@ def _read(self, file_path: str):
for sentence in self._ontonotes_subset(
ontonotes_reader, file_path, self._domain_identifier
):
try:
tokens = [Token(t) for t in sentence.words]
sentence_id = sentence.sentence_id
if sentence.srl_frames:
for (_, tags) in sentence.srl_frames:
verb_indicator = [1 if label[-2:] == "-V" else 0 for label in tags]
frames = self._get_predicate_labels(sentence, verb_indicator)
lemmas = [
f for f, v in zip(sentence.predicate_lemmas, verb_indicator) if v == 1
]
if not all(v == 0 for v in verb_indicator):
yield self.text_to_instance(
tokens, verb_indicator, frames, lemmas, tags, sentence_id
)
except:
print(sentence.sentence_id)
tokens = [Token(t) for t in sentence.words]
sentence_id = sentence.sentence_id
if sentence.srl_frames:
for (_, tags) in sentence.srl_frames:
verb_indicator = [1 if label[-2:] == "-V" else 0 for label in tags]
frames = self._get_predicate_labels(sentence, verb_indicator)
lemmas = [
f for f, v in zip(sentence.predicate_lemmas, verb_indicator) if v == 1
]
if not all(v == 0 for v in verb_indicator):
yield self.text_to_instance(
tokens, verb_indicator, frames, lemmas, tags, sentence_id
)

def text_to_instance( # type: ignore
self,
@@ -244,13 +242,13 @@ def text_to_instance( # type: ignore
frame_indicator = _convert_frames_indices_to_wordpiece_indices(verb_label, offsets, True)

# add verb as information to the model
# verb_tokens = [token for token, v in zip(wordpieces, new_verbs) if v == 1]
# verb_tokens = verb_tokens + [self.tokenizer.sep_token]
# if isinstance(self.tokenizer, XLMRobertaTokenizer):
# verb_tokens = [self.tokenizer.sep_token] + verb_tokens
# wordpieces += verb_tokens
# new_verbs += [0 for _ in range(len(verb_tokens))]
# frame_indicator += [0 for _ in range(len(verb_tokens))]
verb_tokens = [token for token, v in zip(wordpieces, new_verbs) if v == 1]
verb_tokens = verb_tokens + [self.tokenizer.sep_token]
if isinstance(self.tokenizer, XLMRobertaTokenizer):
verb_tokens = [self.tokenizer.sep_token] + verb_tokens
wordpieces += verb_tokens
new_verbs += [0 for _ in range(len(verb_tokens))]
frame_indicator += [0 for _ in range(len(verb_tokens))]
# In order to override the indexing mechanism, we need to set the `text_id`
# attribute directly. This causes the indexing to use this id.
text_field = TextField(
@@ -260,12 +258,15 @@
verb_indicator = SequenceLabelField(new_verbs, text_field)
frame_indicator = SequenceLabelField(frame_indicator, text_field)

sep_index = wordpieces.index(self.tokenizer.sep_token)

metadata_dict["offsets"] = start_offsets

fields: Dict[str, Field] = {
"tokens": text_field,
"verb_indicator": verb_indicator,
"frame_indicator": frame_indicator,
"sentence_end": ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64),
}

if all(x == 0 for x in verb_label):
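For context on the reactivated block and the new sentence_end field above: the predicate's wordpieces are appended after the sentence (with an extra leading separator for XLM-RoBERTa-style tokenizers), and the index one past the first separator token is stored as an int64 array for the model. Below is a minimal, self-contained sketch of that behaviour; the toy wordpiece list and the literal "[SEP]" string are illustrative stand-ins for real tokenizer output, not the project's code.

# Illustrative sketch only: toy wordpieces and a literal "[SEP]" stand in
# for the real Hugging Face tokenizer used by the dataset reader.
import numpy as np

SEP = "[SEP]"

def append_predicate(wordpieces, verb_mask, sep_token=SEP, roberta_style=False):
    # Collect the predicate's wordpieces and append them after the sentence.
    verb_tokens = [tok for tok, v in zip(wordpieces, verb_mask) if v == 1]
    verb_tokens = verb_tokens + [sep_token]
    if roberta_style:
        # XLM-RoBERTa separates segments with an extra leading SEP.
        verb_tokens = [sep_token] + verb_tokens
    wordpieces = wordpieces + verb_tokens
    verb_mask = verb_mask + [0] * len(verb_tokens)
    return wordpieces, verb_mask

wordpieces = ["[CLS]", "the", "cat", "sat", SEP]
verb_mask = [0, 0, 0, 1, 0]
wordpieces, verb_mask = append_predicate(wordpieces, verb_mask)

# sentence_end points one past the first SEP, stored as int64 so it can be
# batched alongside the other fields (cf. the ArrayField added above).
sep_index = wordpieces.index(SEP)
sentence_end = np.array(sep_index + 1, dtype=np.int64)
print(wordpieces)         # ['[CLS]', 'the', 'cat', 'sat', '[SEP]', 'sat', '[SEP]']
print(int(sentence_end))  # 5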
41 changes: 33 additions & 8 deletions transformer_srl/models.py
@@ -15,7 +15,7 @@
)
from overrides import overrides
from torch import nn
from transformers import AutoModel
from transformers import AutoModel, AutoConfig

from transformer_srl.utils import load_label_list

@@ -49,6 +49,8 @@ def __init__(
vocab: Vocabulary,
bert_model: Union[str, AutoModel],
embedding_dropout: float = 0.0,
num_lstms: int = 2,
dense_units: int = 300,
initializer: InitializerApplicator = InitializerApplicator(),
label_smoothing: float = None,
ignore_span_metric: bool = False,
@@ -58,7 +60,8 @@
) -> None:
# bypass SrlBert constructor
Model.__init__(self, vocab, **kwargs)
self.transformer = AutoModel.from_pretrained(bert_model)
self.tr_config = AutoConfig.from_pretrained(bert_model, output_hidden_states=True)
self.transformer = AutoModel.from_pretrained(bert_model, config=self.tr_config)
self.frame_criterion = nn.CrossEntropyLoss()
if inventory == "verbatlas":
# add missing frame labels
@@ -76,9 +79,20 @@ def __init__(
else:
self.span_metric = None
self.f1_frame_metric = FBetaMeasure(average="micro")
self.tag_projection_layer = nn.Linear(self.transformer.config.hidden_size, self.num_classes)
self.lstms = nn.LSTM(
self.tr_config.hidden_size,
self.tr_config.hidden_size,
num_layers=num_lstms,
dropout=0.2 if num_lstms > 1 else 0,
bidirectional=True,
)
self.tag_projection_layer = torch.nn.Sequential(
nn.Linear(self.tr_config.hidden_size * 2, dense_units),
nn.ReLU(),
nn.Linear(dense_units, self.num_classes),
)
self.frame_projection_layer = nn.Linear(
self.transformer.config.hidden_size, self.frame_num_classes
self.tr_config.hidden_size, self.frame_num_classes
)
self.embedding_dropout = nn.Dropout(p=embedding_dropout)
self._label_smoothing = label_smoothing
@@ -89,6 +103,7 @@ def forward( # type: ignore
self,
tokens: TextFieldTensors,
verb_indicator: torch.Tensor,
sentence_end: torch.LongTensor,
frame_indicator: torch.Tensor,
metadata: List[Any],
tags: torch.LongTensor = None,
@@ -135,17 +150,27 @@
"""
mask = get_text_field_mask(tokens)
input_ids = util.get_token_ids_from_text_field_tensors(tokens)
bert_embeddings, _ = self.transformer(
if self.tr_config.type_vocab_size != 1:
verb_indicator = torch.zeros_like(verb_indicator)
embeddings = self.transformer(
input_ids=input_ids,
token_type_ids=verb_indicator,
attention_mask=mask,
return_dict=False,
)
embeddings = embeddings[2][-4:]
embeddings = torch.stack(embeddings, dim=0).sum(dim=0)
# get sizes
batch_size, sequence_length, _ = embeddings.size()
# extract embeddings
embedded_text_input = self.embedding_dropout(bert_embeddings)
embedded_text_input = self.embedding_dropout(embeddings)
frame_embeddings = embedded_text_input[frame_indicator == 1]
# get sizes
batch_size, sequence_length, _ = embedded_text_input.size()
predicate_embeddings = torch.index_select(embedded_text_input, 1, sentence_end)
print(predicate_embeddings.shape)
print(embedded_text_input.shape)
# embedded_text_input += predicate_embeddings
# lstm pass
embedded_text_input = self.lstms(embedded_text_input)[0]
# outputs
logits = self.tag_projection_layer(embedded_text_input)
frame_logits = self.frame_projection_layer(frame_embeddings)
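As a reading aid for the shape bookkeeping in the new head: the transformer is configured with output_hidden_states=True, the last four hidden states are summed, the result goes through the bidirectional LSTM stack, and a two-layer projection produces the tag logits. The sketch below wires those pieces up with dummy tensors instead of a pretrained checkpoint; hidden_size=16, num_classes=7, the 13-layer tuple, and batch_first=True are placeholder choices for the sketch, not the project's configuration.

# Illustrative sketch only: dummy tensors replace the pretrained transformer,
# and all sizes are made-up placeholders.
import torch
from torch import nn

hidden_size, num_classes, dense_units, num_lstms = 16, 7, 300, 2
batch_size, sequence_length = 2, 10

# Stand-in for the tuple of per-layer hidden states returned when
# output_hidden_states=True; the change above sums the last four of them.
hidden_states = tuple(
    torch.randn(batch_size, sequence_length, hidden_size) for _ in range(13)
)
embeddings = torch.stack(hidden_states[-4:], dim=0).sum(dim=0)

# BiLSTM over the summed embeddings (batch_first=True here for readability),
# then a small MLP down to per-token tag logits, as in the new layers above.
lstms = nn.LSTM(
    hidden_size,
    hidden_size,
    num_layers=num_lstms,
    dropout=0.2 if num_lstms > 1 else 0,
    bidirectional=True,
    batch_first=True,
)
tag_projection_layer = nn.Sequential(
    nn.Linear(hidden_size * 2, dense_units),
    nn.ReLU(),
    nn.Linear(dense_units, num_classes),
)

encoded = lstms(embeddings)[0]          # (batch, seq, 2 * hidden_size)
logits = tag_projection_layer(encoded)  # (batch, seq, num_classes)
print(logits.shape)                     # torch.Size([2, 10, 7])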
