release 2.4
Riccorl committed Oct 30, 2020
1 parent 00d407d commit fb71b24
Showing 3 changed files with 13 additions and 46 deletions.
setup.py (1 addition, 1 deletion)
@@ -5,7 +5,7 @@
 setuptools.setup(
     name="transformer_srl", # Replace with your own username
-    version="3.0rc3",
+    version="2.4",
     author="Riccardo Orlando",
     author_email="orlandoricc@gmail.com",
     description="SRL Transformer model",
transformer_srl/dataset_readers.py (0 additions, 2 deletions)
@@ -258,15 +258,13 @@ def text_to_instance( # type: ignore
         verb_indicator = SequenceLabelField(new_verbs, text_field)
         frame_indicator = SequenceLabelField(frame_indicator, text_field)
 
-        sep_index = wordpieces.index(self.tokenizer.sep_token)
 
         metadata_dict["offsets"] = start_offsets
 
         fields: Dict[str, Field] = {
             "tokens": text_field,
             "verb_indicator": verb_indicator,
             "frame_indicator": frame_indicator,
-            "sentence_end": ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64),
         }
 
         if all(x == 0 for x in verb_label):
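The two deletions remove the [SEP]-derived sentence_end field that the model no longer consumes. For orientation, a hedged pure-Python sketch of the indicator sequences the reader still produces, and of how the dropped value used to be computed (the sentence, the indices, and the single-wordpiece predicate are made-up assumptions, not repository code):

# Illustrative only: made-up sentence; real wordpieces come from the tokenizer.
wordpieces = ["[CLS]", "the", "cat", "sat", "on", "the", "mat", "[SEP]"]
verb_index = 3  # assume "sat" is the predicate and is a single wordpiece

# One 0/1 flag per wordpiece marking the predicate.
verb_indicator = [1 if i == verb_index else 0 for i in range(len(wordpieces))]
# The frame is classified at the predicate position (assumption in this sketch).
frame_indicator = list(verb_indicator)

# The dropped "sentence_end" value was derived from the [SEP] position,
# exactly as in the removed line: sep_index + 1.
sep_index = wordpieces.index("[SEP]")
sentence_end = sep_index + 1  # no longer shipped to the model

assert verb_indicator == [0, 0, 0, 1, 0, 0, 0, 0]
assert sentence_end == 8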
transformer_srl/models.py (12 additions, 43 deletions)
@@ -53,7 +53,6 @@ def __init__(
         vocab: Vocabulary,
         bert_model: Union[str, AutoModel],
         embedding_dropout: float = 0.0,
-        num_lstms: int = 2,
         initializer: InitializerApplicator = InitializerApplicator(),
         label_smoothing: float = None,
         ignore_span_metric: bool = False,
@@ -68,8 +67,7 @@
         self.frame_role_dict = load_role_frame(FRAME_ROLE_PATH)
         self.restrict_frames = restrict_frames
         self.restrict_roles = restrict_roles
-        config = AutoConfig.from_pretrained(bert_model, output_hidden_states=True)
-        self.transformer = AutoModel.from_pretrained(bert_model, config=config)
+        self.transformer = AutoModel.from_pretrained(bert_model)
        self.frame_criterion = nn.CrossEntropyLoss()
         # add missing labels
         frame_list = load_label_list(FRAME_LIST_PATH)
@@ -83,20 +81,10 @@
         else:
             self.span_metric = None
         self.f1_frame_metric = FBetaMeasure(average="micro")
-        self.predicate_embedding = nn.Embedding(num_embeddings=2, embedding_dim=10)
-        self.lstms = nn.LSTM(
-            config.hidden_size + 10,
-            config.hidden_size,
-            num_layers=num_lstms,
-            dropout=0.2 if num_lstms > 1 else 0,
-            bidirectional=True,
-        )
-        # self.dropout = nn.Dropout(0.4)
-        # self.tag_projection_layer = nn.Linear(config.hidden_size, self.num_classes)
-        self.tag_projection_layer = torch.nn.Sequential(
-            nn.Linear(config.hidden_size * 2, 300), nn.ReLU(), nn.Linear(300, self.num_classes),
-        )
-        self.frame_projection_layer = nn.Linear(config.hidden_size * 2, self.frame_num_classes)
+        self.tag_projection_layer = nn.Linear(self.transformer.config.hidden_size, self.num_classes)
+        self.frame_projection_layer = nn.Linear(
+            self.transformer.config.hidden_size, self.frame_num_classes
+        )
         self.embedding_dropout = nn.Dropout(p=embedding_dropout)
         self._label_smoothing = label_smoothing
         self.ignore_span_metric = ignore_span_metric
@@ -106,7 +94,6 @@ def forward( # type: ignore
         self,
         tokens: TextFieldTensors,
         verb_indicator: torch.Tensor,
-        sentence_end: torch.LongTensor,
         frame_indicator: torch.Tensor,
         metadata: List[Any],
         tags: torch.LongTensor = None,
@@ -153,36 +140,18 @@
         """
         mask = get_text_field_mask(tokens)
         input_ids = util.get_token_ids_from_text_field_tensors(tokens)
-        embeddings = self.transformer(input_ids=input_ids, attention_mask=mask)
-        embeddings = embeddings[2][-4:]
-        embeddings = torch.stack(embeddings, dim=0).sum(dim=0)
-        # get sizes
-        batch_size, _, _ = embeddings.size()
+        bert_embeddings, _ = self.transformer(
+            input_ids=input_ids, token_type_ids=verb_indicator, attention_mask=mask,
+        )
         # extract embeddings
-        embedded_text_input = self.embedding_dropout(embeddings)
-        # sentence_mask = (
-        #     torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device)
-        #     < sentence_end.unsqueeze(1).repeat(1, mask.shape[1])
-        # ).long()
-        # cutoff = sentence_end.max().item()
-
-        # encoded_text = embedded_text_input
-        # mask = sentence_mask[:, :cutoff].contiguous()
-        # encoded_text = encoded_text[:, :cutoff, :]
-        # tags = tags[:, :cutoff].contiguous()
-        # frame_tags = frame_tags[:, :cutoff].contiguous()
-        # frame_indicator = frame_indicator[:, :cutoff].contiguous()
-
-        predicate_embeddings = self.predicate_embedding(verb_indicator)
-        # encoded_text = torch.stack((embedded_text_input, predicate_embeddings), dim=0).sum(dim=0)
-        embedded_text_input = torch.cat((embedded_text_input, predicate_embeddings), dim=-1)
-        encoded_text, _ = self.lstms(embedded_text_input)
-        frame_embeddings = encoded_text[frame_indicator == 1]
+        embedded_text_input = self.embedding_dropout(bert_embeddings)
+        frame_embeddings = embedded_text_input[frame_indicator == 1]
+        # get sizes
+        batch_size, sequence_length, _ = embedded_text_input.size()
         # outputs
-        logits = self.tag_projection_layer(encoded_text)
+        logits = self.tag_projection_layer(embedded_text_input)
         frame_logits = self.frame_projection_layer(frame_embeddings)
 
-        sequence_length = encoded_text.shape[1]
         reshaped_log_probs = logits.view(-1, self.num_classes)
         class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
             [batch_size, sequence_length, self.num_classes]
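Taken together, the models.py hunks reduce the architecture to an encoder with two linear heads: the last-four-layer sum, the 10-dimensional predicate embedding, and the stacked BiLSTM are gone, and the 0/1 predicate indicator is now fed to the transformer as token_type_ids. Below is a minimal self-contained sketch of that post-commit forward path; the class name, default arguments, and the index-based [0] output access are illustrative assumptions, not the repository's code (the commit itself tuple-unpacks the transformer output, which assumes a transformers 3.x return convention).

import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


class MinimalSrlSketch(nn.Module):
    """Post-commit shape of the model: encoder plus two linear projections."""

    def __init__(
        self,
        bert_model: str = "bert-base-cased",  # assumed default, for illustration
        num_classes: int = 10,                # role labels (placeholder size)
        frame_num_classes: int = 5,           # frame labels (placeholder size)
        embedding_dropout: float = 0.0,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.transformer = AutoModel.from_pretrained(bert_model)
        hidden_size = self.transformer.config.hidden_size
        self.tag_projection_layer = nn.Linear(hidden_size, num_classes)
        self.frame_projection_layer = nn.Linear(hidden_size, frame_num_classes)
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

    def forward(self, input_ids, verb_indicator, frame_indicator, mask):
        # The predicate indicator rides along as token_type_ids, so the
        # encoder itself conditions on the predicate position.
        bert_embeddings = self.transformer(
            input_ids=input_ids, token_type_ids=verb_indicator, attention_mask=mask
        )[0]
        embedded_text_input = self.embedding_dropout(bert_embeddings)
        # Role logits for every token; frame logits only at the predicate token.
        logits = self.tag_projection_layer(embedded_text_input)
        frame_logits = self.frame_projection_layer(
            embedded_text_input[frame_indicator == 1]
        )
        batch_size, sequence_length, _ = embedded_text_input.size()
        class_probabilities = F.softmax(
            logits.view(-1, self.num_classes), dim=-1
        ).view(batch_size, sequence_length, self.num_classes)
        return class_probabilities, frame_logits

The design choice visible in the diff is to inject the predicate signal through the segment embeddings rather than concatenating a learned predicate embedding, which lets both heads read directly off the encoder's hidden states.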
