Skip to content

Commit

Permalink
added role list to labels
Browse files — browse the repository at this point in the history
  • Loading branch information
Riccorl committed Dec 2, 2020
1 parent c52eaed commit 95ee1cd
Show file tree
Hide file tree
Showing 5 changed files with 21,092 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="transformer_srl", # Replace with your own username
version="2.4.7",
version="2.4.9",
author="Riccardo Orlando",
author_email="orlandoricc@gmail.com",
description="SRL Transformer model",
Expand Down
12 changes: 8 additions & 4 deletions transformer_srl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@

from transformer_srl.utils import load_label_list, load_lemma_frame, load_role_frame

LEMMA_FRAME_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "lemma2frame.csv"
LEMMA_FRAME_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "lemma2va_ml.tsv"
FRAME_ROLE_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "frame2role.csv"
FRAME_LIST_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "framelist.txt"
ROLE_LIST_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "rolelist.txt"


@Model.register("transformer_srl_span")
Expand Down Expand Up @@ -65,14 +66,17 @@ def __init__(
Model.__init__(self, vocab, **kwargs)
self.lemma_frame_dict = load_lemma_frame(LEMMA_FRAME_PATH)
self.frame_role_dict = load_role_frame(FRAME_ROLE_PATH)
self.restrict_frames = restrict_frames
self.restrict_roles = restrict_roles
self.restrict_frames = True
self.restrict_roles = True
self.transformer = AutoModel.from_pretrained(bert_model)
self.frame_criterion = nn.CrossEntropyLoss()
if inventory == "verbatlas":
# add missing labels
# add missing frame labels
frame_list = load_label_list(FRAME_LIST_PATH)
self.vocab.add_tokens_to_namespace(frame_list, "frames_labels")
# add missing role labels
role_list = load_label_list(ROLE_LIST_PATH)
self.vocab.add_tokens_to_namespace(role_list, "labels")
self.num_classes = self.vocab.get_vocab_size("labels")
self.frame_num_classes = self.vocab.get_vocab_size("frames_labels")
if srl_eval_path is not None:
Expand Down
2 changes: 1 addition & 1 deletion transformer_srl/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def make_srl_string(words: List[str], tags: List[str], frame: str) -> str:
@overrides
def _sentence_to_srl_instances(self, json_dict: JsonDict) -> List[Instance]:
sentence = json_dict["sentence"]
if json_dict.get("verbs"):
if "verbs" in json_dict.keys():
text = sentence.split()
pos = ["VERB" if i == json_dict["verbs"] else "NOUN" for i, _ in enumerate(text)]
tokens = [Token(t, i, i + len(text), pos_=p) for i, (t, p) in enumerate(zip(text, pos))]
Expand Down
Loading

0 comments on commit 95ee1cd

Please sign in to comment.