diff --git a/training_config/bert_ml_conll2009.jsonnet b/training_config/bert_ml_conll2009.jsonnet index 4ca0a8e..06543da 100644 --- a/training_config/bert_ml_conll2009.jsonnet +++ b/training_config/bert_ml_conll2009.jsonnet @@ -19,6 +19,7 @@ "type": "transformer_srl_span", "embedding_dropout": 0.1, "bert_model": "bert-base-multilingual-cased", + "inventory": "propbank", }, "trainer": { diff --git a/training_config/bert_tiny_dep.jsonnet b/training_config/bert_tiny_dep.jsonnet index 85c2e99..6854b1f 100644 --- a/training_config/bert_tiny_dep.jsonnet +++ b/training_config/bert_tiny_dep.jsonnet @@ -1,6 +1,6 @@ { "dataset_reader": { - "type": "transformer_srl_span", + "type": "transformer_srl_dependency", "model_name": "mrm8488/bert-tiny-finetuned-squadv2", }, @@ -17,7 +17,7 @@ "model": { "type": "transformer_srl_dependency", "embedding_dropout": 0.1, - "bert_model": "mrm8488/bert-tiny-finetuned-squadv2", + "model_name": "mrm8488/bert-tiny-finetuned-squadv2", }, "trainer": { diff --git a/transformer_srl/dataset_readers.py b/transformer_srl/dataset_readers.py index b99ccc5..6ef0bf0 100644 --- a/transformer_srl/dataset_readers.py +++ b/transformer_srl/dataset_readers.py @@ -442,7 +442,7 @@ def _read(self, file_path: str): for annotation in parse_incr( conll_file, - fields=conllu_fields if self.format == "conllu" else conll2009_fields, + fields=conll2009_fields if self.format == "conll2009" else conllu_fields, field_parsers={"roles": lambda line, i: line[i:]}, ): # CoNLLU annotations sometimes add back in words that have been elided @@ -460,7 +460,7 @@ def _read(self, file_path: str): continue frames = [x["frame"] for x in annotation] # clean C-V tags (not needed yet) - frames = [f if f != "C-V" else "_" for f in frames] + frames = [f if f.isupper() else "_" for f in frames] roles = [x["roles"] for x in annotation] # transpose rolses, to have a list of roles per frame roles = list(map(list, zip(*roles))) diff --git a/transformer_srl/models.py b/transformer_srl/models.py index 6a08798..69141ff 100644 --- a/transformer_srl/models.py +++ b/transformer_srl/models.py @@ -59,6 +59,7 @@ def __init__( srl_eval_path: str = DEFAULT_SRL_EVAL_PATH, restrict_frames: bool = False, restrict_roles: bool = False, + inventory: str = "verbatlas", **kwargs, ) -> None: # bypass SrlBert constructor @@ -69,9 +70,10 @@ def __init__( self.restrict_roles = restrict_roles self.transformer = AutoModel.from_pretrained(bert_model) self.frame_criterion = nn.CrossEntropyLoss() - # add missing labels - frame_list = load_label_list(FRAME_LIST_PATH) - self.vocab.add_tokens_to_namespace(frame_list, "frames_labels") + if inventory == "verbatlas": + # add missing labels + frame_list = load_label_list(FRAME_LIST_PATH) + self.vocab.add_tokens_to_namespace(frame_list, "frames_labels") self.num_classes = self.vocab.get_vocab_size("labels") self.frame_num_classes = self.vocab.get_vocab_size("frames_labels") if srl_eval_path is not None: