diff --git a/setup.py b/setup.py
index 2111cfe..e955eeb 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 setuptools.setup(
     name="transformer_srl", # Replace with your own username
-    version="3.0rc3",
+    version="2.4",
     author="Riccardo Orlando",
     author_email="orlandoricc@gmail.com",
     description="SRL Transformer model",
diff --git a/transformer_srl/dataset_readers.py b/transformer_srl/dataset_readers.py
index 361e853..0e80dcf 100644
--- a/transformer_srl/dataset_readers.py
+++ b/transformer_srl/dataset_readers.py
@@ -258,7 +258,6 @@ def text_to_instance(  # type: ignore
         verb_indicator = SequenceLabelField(new_verbs, text_field)
         frame_indicator = SequenceLabelField(frame_indicator, text_field)
-        sep_index = wordpieces.index(self.tokenizer.sep_token)

         metadata_dict["offsets"] = start_offsets

@@ -266,7 +265,6 @@ def text_to_instance(  # type: ignore
             "tokens": text_field,
             "verb_indicator": verb_indicator,
             "frame_indicator": frame_indicator,
-            "sentence_end": ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64),
         }

         if all(x == 0 for x in verb_label):
diff --git a/transformer_srl/models.py b/transformer_srl/models.py
index 62ba84e..6a08798 100644
--- a/transformer_srl/models.py
+++ b/transformer_srl/models.py
@@ -53,7 +53,6 @@ def __init__(
         vocab: Vocabulary,
         bert_model: Union[str, AutoModel],
         embedding_dropout: float = 0.0,
-        num_lstms: int = 2,
         initializer: InitializerApplicator = InitializerApplicator(),
         label_smoothing: float = None,
         ignore_span_metric: bool = False,
@@ -68,8 +67,7 @@ def __init__(
         self.frame_role_dict = load_role_frame(FRAME_ROLE_PATH)
         self.restrict_frames = restrict_frames
         self.restrict_roles = restrict_roles
-        config = AutoConfig.from_pretrained(bert_model, output_hidden_states=True)
-        self.transformer = AutoModel.from_pretrained(bert_model, config=config)
+        self.transformer = AutoModel.from_pretrained(bert_model)
         self.frame_criterion = nn.CrossEntropyLoss()
         # add missing labels
         frame_list = load_label_list(FRAME_LIST_PATH)
@@ -83,20 +81,10 @@
         else:
             self.span_metric = None
         self.f1_frame_metric = FBetaMeasure(average="micro")
-        self.predicate_embedding = nn.Embedding(num_embeddings=2, embedding_dim=10)
-        self.lstms = nn.LSTM(
-            config.hidden_size + 10,
-            config.hidden_size,
-            num_layers=num_lstms,
-            dropout=0.2 if num_lstms > 1 else 0,
-            bidirectional=True,
-        )
-        # self.dropout = nn.Dropout(0.4)
-        # self.tag_projection_layer = nn.Linear(config.hidden_size, self.num_classes)
-        self.tag_projection_layer = torch.nn.Sequential(
-            nn.Linear(config.hidden_size * 2, 300), nn.ReLU(), nn.Linear(300, self.num_classes),
+        self.tag_projection_layer = nn.Linear(self.transformer.config.hidden_size, self.num_classes)
+        self.frame_projection_layer = nn.Linear(
+            self.transformer.config.hidden_size, self.frame_num_classes
         )
-        self.frame_projection_layer = nn.Linear(config.hidden_size * 2, self.frame_num_classes)
         self.embedding_dropout = nn.Dropout(p=embedding_dropout)
         self._label_smoothing = label_smoothing
         self.ignore_span_metric = ignore_span_metric
@@ -106,7 +94,6 @@ def forward(  # type: ignore
         self,
         tokens: TextFieldTensors,
         verb_indicator: torch.Tensor,
-        sentence_end: torch.LongTensor,
         frame_indicator: torch.Tensor,
         metadata: List[Any],
         tags: torch.LongTensor = None,
@@ -153,36 +140,18 @@
         """
         mask = get_text_field_mask(tokens)
         input_ids = util.get_token_ids_from_text_field_tensors(tokens)
-        embeddings = self.transformer(input_ids=input_ids, attention_mask=mask)
-        embeddings = embeddings[2][-4:]
-        embeddings = torch.stack(embeddings, dim=0).sum(dim=0)
-        # get sizes
-        batch_size, _, _ = embeddings.size()
+        bert_embeddings, _ = self.transformer(
+            input_ids=input_ids, token_type_ids=verb_indicator, attention_mask=mask,
+        )
         # extract embeddings
-        embedded_text_input = self.embedding_dropout(embeddings)
-        # sentence_mask = (
-        #     torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device)
-        #     < sentence_end.unsqueeze(1).repeat(1, mask.shape[1])
-        # ).long()
-        # cutoff = sentence_end.max().item()
-
-        # encoded_text = embedded_text_input
-        # mask = sentence_mask[:, :cutoff].contiguous()
-        # encoded_text = encoded_text[:, :cutoff, :]
-        # tags = tags[:, :cutoff].contiguous()
-        # frame_tags = frame_tags[:, :cutoff].contiguous()
-        # frame_indicator = frame_indicator[:, :cutoff].contiguous()
-
-        predicate_embeddings = self.predicate_embedding(verb_indicator)
-        # encoded_text = torch.stack((embedded_text_input, predicate_embeddings), dim=0).sum(dim=0)
-        embedded_text_input = torch.cat((embedded_text_input, predicate_embeddings), dim=-1)
-        encoded_text, _ = self.lstms(embedded_text_input)
-        frame_embeddings = encoded_text[frame_indicator == 1]
+        embedded_text_input = self.embedding_dropout(bert_embeddings)
+        frame_embeddings = embedded_text_input[frame_indicator == 1]
+        # get sizes
+        batch_size, sequence_length, _ = embedded_text_input.size()
         # outputs
-        logits = self.tag_projection_layer(encoded_text)
+        logits = self.tag_projection_layer(embedded_text_input)
         frame_logits = self.frame_projection_layer(frame_embeddings)
-        sequence_length = encoded_text.shape[1]
         reshaped_log_probs = logits.view(-1, self.num_classes)
         class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
             [batch_size, sequence_length, self.num_classes]
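
For reference, here is a minimal, self-contained sketch of the data flow this patch settles on: the verb indicator is fed to the encoder as `token_type_ids`, the last hidden states are passed through dropout, and two plain linear heads produce per-wordpiece role logits plus a frame logit vector at the predicate position. The checkpoint name and label counts are placeholders for illustration, not the values used by `transformer_srl`, and the snippet is not the model's actual code.

```python
# Illustrative sketch only (placeholder checkpoint and label sizes),
# showing the shape flow of the simplified head in the diff above.
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

model_name = "bert-base-cased"   # placeholder checkpoint
num_role_labels = 10             # placeholder for self.num_classes
num_frame_labels = 5             # placeholder for self.frame_num_classes

tokenizer = AutoTokenizer.from_pretrained(model_name)
transformer = AutoModel.from_pretrained(model_name)

tag_projection = nn.Linear(transformer.config.hidden_size, num_role_labels)
frame_projection = nn.Linear(transformer.config.hidden_size, num_frame_labels)

encoded = tokenizer("The cat sat on the mat .", return_tensors="pt")
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]

# The verb/frame indicators mark the predicate wordpiece(s); here "sat".
verb_indicator = torch.zeros_like(input_ids)
verb_indicator[0, 3] = 1
frame_indicator = verb_indicator.clone()

# As in the diff, the indicator is passed as token_type_ids, so the predicate
# position is encoded through the segment embeddings instead of an extra
# predicate embedding + BiLSTM stack.
outputs = transformer(
    input_ids=input_ids,
    token_type_ids=verb_indicator,
    attention_mask=attention_mask,
)
embeddings = outputs[0]                              # (batch, seq_len, hidden)

role_logits = tag_projection(embeddings)             # one label per wordpiece
frame_embeddings = embeddings[frame_indicator == 1]  # predicate embedding(s)
frame_logits = frame_projection(frame_embeddings)    # one frame per predicate

print(role_logits.shape, frame_logits.shape)
```

Because the predicate signal now travels through the segment embeddings, the projection layers can read `self.transformer.config.hidden_size` directly, which is why the patch also drops the `output_hidden_states` config, the last-four-layer sum, and the `sentence_end` plumbing in the dataset reader.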