Skip to content

Commit

Permalink
add frame list
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed Sep 11, 2020
1 parent 9ccdd19 commit bae1b7c
Show file tree
Hide file tree
Showing 5 changed files with 537 additions and 36 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="transformer_srl", # Replace with your own username
version="2.3",
version="2.4rc1",
author="Riccardo Orlando",
author_email="orlandoricc@gmail.com",
description="SRL Transformer model",
Expand Down
67 changes: 43 additions & 24 deletions transformer_srl/dataset_readers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import pathlib
from typing import Any
from typing import Dict, Tuple, List

from allennlp.common.file_utils import cached_path
from allennlp.data import Vocabulary
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField
from allennlp.data.instance import Instance
Expand All @@ -17,8 +19,13 @@
from overrides import overrides
from transformers import AutoTokenizer

from transformer_srl.utils import load_label_list

logger = logging.getLogger(__name__)

FRAME_LIST_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "framelist.txt"


"""
ID: Word index, integer starting at 1 for each new sentence; may be a range for tokens with multiple words.
FORM: Word form or punctuation symbol.
Expand Down Expand Up @@ -164,7 +171,6 @@ def __init__(
"tokens": PretrainedTransformerIndexer(model_name)
}
self._domain_identifier = domain_identifier

self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.lowercase_input = "uncased" in model_name

Expand Down Expand Up @@ -232,19 +238,26 @@ def _read(self, file_path: str):
self._domain_identifier,
)

for sentence in self._ontonotes_subset(
ontonotes_reader, file_path, self._domain_identifier
):
tokens = [Token(t) for t in sentence.words]
if sentence.srl_frames:
for (_, tags) in sentence.srl_frames:
verb_indicator = [1 if label[-2:] == "-V" else 0 for label in tags]
frames = self._get_predicate_labels(sentence, verb_indicator)
lemmas = [
f for f, v in zip(sentence.predicate_lemmas, verb_indicator) if v == 1
]
if not all(v == 0 for v in verb_indicator):
yield self.text_to_instance(tokens, verb_indicator, frames, lemmas, tags)
counter = 0
try:
for sentence in self._ontonotes_subset(
ontonotes_reader, file_path, self._domain_identifier
):
tokens = [Token(t) for t in sentence.words]
if sentence.srl_frames:
for (_, tags) in sentence.srl_frames:
verb_indicator = [1 if label[-2:] == "-V" else 0 for label in tags]
frames = self._get_predicate_labels(sentence, verb_indicator)
lemmas = [
f for f, v in zip(sentence.predicate_lemmas, verb_indicator) if v == 1
]
if not all(v == 0 for v in verb_indicator):
yield self.text_to_instance(
tokens, verb_indicator, frames, lemmas, tags
)
counter += 1
except:
print("COUNTER", counter)

def text_to_instance( # type: ignore
self,
Expand Down Expand Up @@ -353,16 +366,22 @@ def _convert_tags_to_wordpiece_tags(self, tags: List[str], offsets: List[int]) -

def _get_predicate_labels(self, sentence, verb_indicator):
labels = []
for i, v in enumerate(verb_indicator):
if v == 1:
label = (
"{}.{}".format(sentence.predicate_lemmas[i], sentence.predicate_framenet_ids[i])
if sentence.predicate_framenet_ids[i].isdigit()
else sentence.predicate_framenet_ids[i]
)
labels.append(label)
else:
labels.append("O")
try:
for i, v in enumerate(verb_indicator):
if v == 1:
label = (
"{}.{}".format(
sentence.predicate_lemmas[i], sentence.predicate_framenet_ids[i]
)
if sentence.predicate_framenet_ids[i].isdigit()
else sentence.predicate_framenet_ids[i]
)
labels.append(label)
else:
labels.append("O")
except:
print(sentence.words)
print(sentence.predicate_framenet_ids)
return labels


Expand Down
18 changes: 7 additions & 11 deletions transformer_srl/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import pathlib
from collections import defaultdict
from typing import Dict, List, Any, Union

import numpy as np
Expand All @@ -9,27 +7,23 @@
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, util
from allennlp.nn.util import (
get_lengths_from_binary_sequence_mask,
viterbi_decode,
get_device_of,
)
from allennlp.nn.util import get_device_of
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics.fbeta_measure import FBetaMeasure
from allennlp_models.structured_prediction import SrlBert
from allennlp_models.structured_prediction.metrics.srl_eval_scorer import (
DEFAULT_SRL_EVAL_PATH,
SrlEvalScorer,
)
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder
from overrides import overrides
from torch.nn.modules import Linear, Dropout
from transformers import AutoModel

from transformer_srl.utils import load_lemma_frame, load_role_frame
from transformer_srl.utils import load_lemma_frame, load_role_frame, load_label_list

LEMMA_FRAME_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "lemma2frame.csv"
FRAME_ROLE_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "frame2role.csv"
FRAME_LIST_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "framelist.txt"


@Model.register("transformer_srl_dependency")
Expand Down Expand Up @@ -341,6 +335,9 @@ def __init__(
else:
self.transformer = bert_model
self.frame_criterion = torch.nn.CrossEntropyLoss()
# add missing labels
frame_list = load_label_list(FRAME_LIST_PATH)
self.vocab.add_tokens_to_namespace(frame_list, "frames_labels")
self.num_classes = self.vocab.get_vocab_size("labels")
self.frame_num_classes = self.vocab.get_vocab_size("frames_labels")
if srl_eval_path is not None:
Expand Down Expand Up @@ -569,8 +566,7 @@ def get_metrics(self, reset: bool = False):
metric_dict_filtered = {
x.split("-")[0] + "_role": y
for x, y in metric_dict.items()
if "overall" in x
and "f1" in x
if "overall" in x and "f1" in x
}
frame_metric_dict = {
x + "_frame": y for x, y in frame_metric_dict.items() if "fscore" in x
Expand Down
Loading

0 comments on commit bae1b7c

Please sign in to comment.