xmlr model
Riccorl committed Oct 27, 2020
1 parent e0183b5 commit 10099db
Showing 3 changed files with 27 additions and 23 deletions.
setup.py: 2 changes (1 addition & 1 deletion)
@@ -5,7 +5,7 @@

 setuptools.setup(
     name="transformer_srl",  # Replace with your own username
-    version="3.0rc1",
+    version="3.0rc2",
     author="Riccardo Orlando",
     author_email="orlandoricc@gmail.com",
     description="SRL Transformer model",
transformer_srl/dataset_readers.py: 14 changes (7 additions & 7 deletions)
@@ -242,13 +242,13 @@ def text_to_instance( # type: ignore
         frame_indicator = _convert_frames_indices_to_wordpiece_indices(verb_label, offsets, True)

         # add verb as information to the model
-        verb_tokens = [token for token, v in zip(wordpieces, new_verbs) if v == 1]
-        verb_tokens = verb_tokens + [self.tokenizer.sep_token]
-        if isinstance(self.tokenizer, XLMRobertaTokenizer):
-            verb_tokens = [self.tokenizer.sep_token] + verb_tokens
-        wordpieces += verb_tokens
-        new_verbs += [0 for _ in range(len(verb_tokens))]
-        frame_indicator += [0 for _ in range(len(verb_tokens))]
+        # verb_tokens = [token for token, v in zip(wordpieces, new_verbs) if v == 1]
+        # verb_tokens = verb_tokens + [self.tokenizer.sep_token]
+        # if isinstance(self.tokenizer, XLMRobertaTokenizer):
+        #     verb_tokens = [self.tokenizer.sep_token] + verb_tokens
+        # wordpieces += verb_tokens
+        # new_verbs += [0 for _ in range(len(verb_tokens))]
+        # frame_indicator += [0 for _ in range(len(verb_tokens))]
         # In order to override the indexing mechanism, we need to set the `text_id`
         # attribute directly. This causes the indexing to use this id.
         text_field = TextField(
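The removed block used to append the predicate's wordpieces (plus the tokenizer's SEP token) to the input sequence; after this commit the predicate is signalled only through the per-wordpiece indicator arrays. Below is a minimal, self-contained sketch of that indicator-based encoding; the toy sentence, the use of AutoTokenizer, and the loop are illustrative assumptions, not the reader's actual helpers.

from transformers import AutoTokenizer

# Hypothetical example: mark the predicate "eats" in a toy sentence.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
tokens = ["The", "cat", "eats", "fish"]
verb_label = [0, 0, 1, 0]  # 1 on the predicate token

wordpieces = [tokenizer.cls_token]
new_verbs = [0]
for token, is_verb in zip(tokens, verb_label):
    pieces = tokenizer.tokenize(token)
    wordpieces += pieces
    # every wordpiece of the predicate keeps the indicator value
    new_verbs += [is_verb] * len(pieces)
wordpieces.append(tokenizer.sep_token)
new_verbs.append(0)

# No predicate tokens are appended after SEP anymore: the model receives
# only the sentence wordpieces plus this 0/1 indicator per wordpiece.
print(list(zip(wordpieces, new_verbs)))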
transformer_srl/models.py: 34 changes (19 additions & 15 deletions)
@@ -82,9 +82,10 @@ def __init__(
         else:
             self.span_metric = None
         self.f1_frame_metric = FBetaMeasure(average="micro")
+        self.predicate_embedding = nn.Embedding(num_embeddings=2, embedding_dim=10)
         # self.tag_projection_layer = nn.Linear(config.hidden_size, self.num_classes)
         self.tag_projection_layer = torch.nn.Sequential(
-            nn.Linear(config.hidden_size, 300), nn.ReLU(), nn.Linear(300, self.num_classes),
+            nn.Linear(config.hidden_size + 10, 300), nn.ReLU(), nn.Linear(300, self.num_classes),
         )
         self.frame_projection_layer = nn.Linear(config.hidden_size, self.frame_num_classes)
         self.embedding_dropout = nn.Dropout(p=embedding_dropout)
@@ -150,20 +151,23 @@ def forward( # type: ignore
         batch_size, _, _ = embeddings.size()
         # extract embeddings
         embedded_text_input = self.embedding_dropout(embeddings)
-        sentence_mask = (
-            torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device)
-            < sentence_end.unsqueeze(1).repeat(1, mask.shape[1])
-        ).long()
-        cutoff = sentence_end.max().item()
+        # sentence_mask = (
+        #     torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device)
+        #     < sentence_end.unsqueeze(1).repeat(1, mask.shape[1])
+        # ).long()
+        # cutoff = sentence_end.max().item()

-        encoded_text = embedded_text_input
-        mask = sentence_mask[:, :cutoff].contiguous()
-        encoded_text = encoded_text[:, :cutoff, :]
-        tags = tags[:, :cutoff].contiguous()
-        frame_tags = frame_tags[:, :cutoff].contiguous()
-        frame_indicator = frame_indicator[:, :cutoff].contiguous()
+        # encoded_text = embedded_text_input
+        # mask = sentence_mask[:, :cutoff].contiguous()
+        # encoded_text = encoded_text[:, :cutoff, :]
+        # tags = tags[:, :cutoff].contiguous()
+        # frame_tags = frame_tags[:, :cutoff].contiguous()
+        # frame_indicator = frame_indicator[:, :cutoff].contiguous()

-        frame_embeddings = encoded_text[frame_indicator == 1]
+        predicate_embeddings = self.predicate_embedding(verb_indicator)
+        # encoded_text = torch.stack((embedded_text_input, predicate_embeddings), dim=0).sum(dim=0)
+        encoded_text = torch.cat((embedded_text_input, predicate_embeddings), dim=-1)
+        frame_embeddings = embedded_text_input[frame_indicator == 1]
         # outputs
         logits = self.tag_projection_layer(encoded_text)
         frame_logits = self.frame_projection_layer(frame_embeddings)
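In the new forward pass the 0/1 verb indicator is embedded and concatenated onto every token representation before tag projection, which is why the first projection layer grows from config.hidden_size to config.hidden_size + 10. A minimal standalone sketch of that wiring follows; the hidden size, class count, and toy tensors are illustrative assumptions, not values from the actual model config.

import torch
import torch.nn as nn

hidden_size, num_classes = 768, 40  # illustrative sizes, not the real config
predicate_embedding = nn.Embedding(num_embeddings=2, embedding_dim=10)
tag_projection_layer = nn.Sequential(
    nn.Linear(hidden_size + 10, 300), nn.ReLU(), nn.Linear(300, num_classes)
)

batch_size, seq_len = 2, 12
embedded_text_input = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for encoder output
verb_indicator = torch.zeros(batch_size, seq_len, dtype=torch.long)
verb_indicator[:, 3] = 1  # pretend the predicate is the 4th wordpiece

# Embed the indicator, concatenate it to each token's representation,
# then project to per-token tag logits.
predicate_embeddings = predicate_embedding(verb_indicator)                      # (B, T, 10)
encoded_text = torch.cat((embedded_text_input, predicate_embeddings), dim=-1)   # (B, T, hidden+10)
logits = tag_projection_layer(encoded_text)                                     # (B, T, num_classes)
print(logits.shape)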
