
Commit

verb embeddings
Riccorl committed Oct 16, 2020
1 parent 41b66d9 commit b5c53c0
Showing 3 changed files with 52 additions and 23 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
name="transformer_srl", # Replace with your own username
version="2.4rc5",
version="2.4rc6",
author="Riccardo Orlando",
author_email="orlandoricc@gmail.com",
description="SRL Transformer model",
6 changes: 4 additions & 2 deletions transformer_srl/dataset_readers.py
@@ -6,7 +6,7 @@
from allennlp.common.file_utils import cached_path
from allennlp.data import Vocabulary
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField
from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField, ArrayField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import (
TokenIndexer,
@@ -18,7 +18,7 @@
from conllu import parse_incr
from overrides import overrides
from transformers import AutoTokenizer

import numpy as np
from transformer_srl.utils import load_label_list

logger = logging.getLogger(__name__)
@@ -294,6 +294,7 @@ def text_to_instance( # type: ignore
)
new_verbs = _convert_verb_indices_to_wordpiece_indices(verb_label, offsets)
frame_indicator = _convert_frames_indices_to_wordpiece_indices(verb_label, offsets, True)
sep_index = wordpieces.index(self.tokenizer.sep_token)
metadata_dict["offsets"] = start_offsets
# In order to override the indexing mechanism, we need to set the `text_id`
# attribute directly. This causes the indexing to use this id.
@@ -308,6 +309,7 @@
"tokens": text_field,
"verb_indicator": verb_indicator,
"frame_indicator": frame_indicator,
"sentence_end": ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64),
}

if all(x == 0 for x in verb_label):
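The reader change above records where the first sentence ends: it finds the first separator token in the wordpiece sequence and stores `sep_index + 1` as an int64 `ArrayField` named `sentence_end`. A minimal standalone sketch of that computation, using a made-up wordpiece list and a hard-coded `[SEP]` string in place of `self.tokenizer.sep_token`:

```python
import numpy as np
from allennlp.data.fields import ArrayField

# Hypothetical wordpieces for "The cat sat", with the predicate repeated after
# the first [SEP], mirroring the layout the dataset reader produces.
wordpieces = ["[CLS]", "the", "cat", "sat", "[SEP]", "sat", "[SEP]"]
sep_token = "[SEP]"  # stands in for self.tokenizer.sep_token

sep_index = wordpieces.index(sep_token)  # position of the first separator
# One past the first [SEP]: the length of the sentence segment, stored as a
# 0-dim int64 array so the model receives it as a LongTensor per instance.
sentence_end = ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64)
```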
67 changes: 47 additions & 20 deletions transformer_srl/models.py
@@ -84,10 +84,10 @@ def __init__(
self.span_metric = None
self.f1_frame_metric = FBetaMeasure(average="micro")
self.tag_projection_layer = Linear(
self.transformer.config.hidden_size * 2, self.num_classes
self.transformer.config.hidden_size, self.num_classes
)
self.frame_projection_layer = Linear(
self.transformer.config.hidden_size * 2, self.frame_num_classes
self.transformer.config.hidden_size, self.frame_num_classes
)
self.embedding_dropout = Dropout(p=embedding_dropout)
self._label_smoothing = label_smoothing
@@ -99,6 +99,7 @@ def forward( # type: ignore
tokens: TextFieldTensors,
verb_indicator: torch.Tensor,
frame_indicator: torch.Tensor,
sentence_end: torch.LongTensor,
metadata: List[Any],
tags: torch.LongTensor = None,
frame_tags: torch.LongTensor = None,
@@ -142,31 +143,57 @@ def forward( # type: ignore
loss : `torch.FloatTensor`, optional
A scalar loss to be optimised.
"""
# mask = get_text_field_mask(tokens)
# input_ids = util.get_token_ids_from_text_field_tensors(tokens)
# bert_embeddings, _ = self.transformer(
# input_ids=input_ids,
# # token_type_ids=verb_indicator,
# attention_mask=mask,
# )
# # get sizes
# batch_size, sequence_length, _ = bert_embeddings.size()

# # verb embeddings
# verb_embeddings = bert_embeddings[
# torch.arange(batch_size).to(bert_embeddings.device), verb_indicator.argmax(1), :
# ]
# verb_embeddings = torch.where(
# (verb_indicator.sum(1, keepdim=True) > 0).repeat(1, verb_embeddings.shape[-1]),
# verb_embeddings,
# torch.zeros_like(verb_embeddings),
# )
# bert_embeddings = torch.cat(
# (bert_embeddings, verb_embeddings.unsqueeze(1).repeat(1, bert_embeddings.shape[1], 1)),
# dim=2,
# )
# mask = tokens["tokens"]["mask"]
# index = mask.sum(1).argmax().item()

mask = get_text_field_mask(tokens)
input_ids = util.get_token_ids_from_text_field_tensors(tokens)
bert_embeddings, _ = self.transformer(
input_ids=input_ids,
input_ids=util.get_token_ids_from_text_field_tensors(tokens),
# token_type_ids=verb_indicator,
attention_mask=mask,
)
# get sizes
batch_size, sequence_length, _ = bert_embeddings.size()
# verb embeddings
verb_embeddings = bert_embeddings[
torch.arange(batch_size).to(bert_embeddings.device), verb_indicator.argmax(1), :
]
verb_embeddings = torch.where(
(verb_indicator.sum(1, keepdim=True) > 0).repeat(1, verb_embeddings.shape[-1]),
verb_embeddings,
torch.zeros_like(verb_embeddings),
)
bert_embeddings = torch.cat(
(bert_embeddings, verb_embeddings.unsqueeze(1).repeat(1, bert_embeddings.shape[1], 1)),
dim=2,
)

# extract embeddings
batch_size, _ = mask.size()
embedded_text_input = self.embedding_dropout(bert_embeddings)
# Restrict to sentence part
sentence_mask = (
torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device)
< sentence_end.unsqueeze(1).repeat(1, mask.shape[1])
).long()
cutoff = sentence_end.max().item()

# encoded_text = embedded_text_input
mask = sentence_mask[:, :cutoff].contiguous()
embedded_text_input = embedded_text_input[:, :cutoff, :]
tags = tags[:, :cutoff].contiguous()

sequence_length = embedded_text_input.shape[1]

# extract embeddings
# embedded_text_input = self.embedding_dropout(bert_embeddings)
frame_embeddings = embedded_text_input[frame_indicator == 1]
# outputs
logits = self.tag_projection_layer(embedded_text_input)
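In the model, the new `sentence_end` input is used to mask out everything after the first sentence before the tag projection. A standalone sketch of that masking and truncation, with assumed batch shapes and values rather than the model's actual tensors:

```python
import torch

# Assumed shapes and values, for illustration only.
batch_size, seq_len, hidden = 2, 7, 4
embedded_text_input = torch.randn(batch_size, seq_len, hidden)
mask = torch.ones(batch_size, seq_len, dtype=torch.long)
sentence_end = torch.tensor([5, 6])  # one past the first [SEP] in each example

# Positions before sentence_end belong to the sentence proper; everything
# after it (the appended predicate copy and padding) is masked out.
positions = torch.arange(seq_len, device=mask.device).unsqueeze(0).repeat(batch_size, 1)
sentence_mask = (positions < sentence_end.unsqueeze(1)).long()

# Truncate the batch to the longest sentence so later layers never see the tail.
cutoff = sentence_end.max().item()
mask = sentence_mask[:, :cutoff].contiguous()
embedded_text_input = embedded_text_input[:, :cutoff, :]
print(mask.shape, embedded_text_input.shape)  # torch.Size([2, 6]) torch.Size([2, 6, 4])
```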
