Skip to content

Commit

Permalink
fix bug with propbank, fix dep-based
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed Nov 18, 2020
1 parent ccce441 commit 9725154
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
1 change: 1 addition & 0 deletions training_config/bert_ml_conll2009.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"type": "transformer_srl_span",
"embedding_dropout": 0.1,
"bert_model": "bert-base-multilingual-cased",
"inventory": "propbank",
},

"trainer": {
Expand Down
4 changes: 2 additions & 2 deletions training_config/bert_tiny_dep.jsonnet
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"dataset_reader": {
"type": "transformer_srl_span",
"type": "transformer_srl_dependency",
"model_name": "mrm8488/bert-tiny-finetuned-squadv2",
},

Expand All @@ -17,7 +17,7 @@
"model": {
"type": "transformer_srl_dependency",
"embedding_dropout": 0.1,
"bert_model": "mrm8488/bert-tiny-finetuned-squadv2",
"model_name": "mrm8488/bert-tiny-finetuned-squadv2",
},

"trainer": {
Expand Down
4 changes: 2 additions & 2 deletions transformer_srl/dataset_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _read(self, file_path: str):

for annotation in parse_incr(
conll_file,
fields=conllu_fields if self.format == "conllu" else conll2009_fields,
fields=conll2009_fields if self.format == "conll2009" else conllu_fields,
field_parsers={"roles": lambda line, i: line[i:]},
):
# CoNLLU annotations sometimes add back in words that have been elided
Expand All @@ -460,7 +460,7 @@ def _read(self, file_path: str):
continue
frames = [x["frame"] for x in annotation]
# clean C-V tags (not needed yet)
frames = [f if f != "C-V" else "_" for f in frames]
frames = [f if f.isupper() else "_" for f in frames]
roles = [x["roles"] for x in annotation]
# transpose rolses, to have a list of roles per frame
roles = list(map(list, zip(*roles)))
Expand Down
8 changes: 5 additions & 3 deletions transformer_srl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
restrict_frames: bool = False,
restrict_roles: bool = False,
inventory: str = "verbatlas",
**kwargs,
) -> None:
# bypass SrlBert constructor
Expand All @@ -69,9 +70,10 @@ def __init__(
self.restrict_roles = restrict_roles
self.transformer = AutoModel.from_pretrained(bert_model)
self.frame_criterion = nn.CrossEntropyLoss()
# add missing labels
frame_list = load_label_list(FRAME_LIST_PATH)
self.vocab.add_tokens_to_namespace(frame_list, "frames_labels")
if inventory == "verbatlas":
# add missing labels
frame_list = load_label_list(FRAME_LIST_PATH)
self.vocab.add_tokens_to_namespace(frame_list, "frames_labels")
self.num_classes = self.vocab.get_vocab_size("labels")
self.frame_num_classes = self.vocab.get_vocab_size("frames_labels")
if srl_eval_path is not None:
Expand Down

0 comments on commit 9725154

Please sign in to comment.