Skip to content

Commit dfcd35f

Browse files
authored
Merge pull request #2718 from flairNLP/GH-2717-ignore-labels
GH-2717: Add option to ignore labels in dataset
2 parents 1fe18be + 78b17f2 commit dfcd35f

File tree

2 files changed, with 17 additions and 6 deletions.

flair/datasets/relation_extraction.py

+8 -2 (8 additions, 2 deletions)
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
base_path: Union[str, Path] = None,
4040
in_memory: bool = True,
4141
augment_train: bool = False,
42+
**corpusargs,
4243
):
4344
"""
4445
SemEval-2010 Task 8 on Multi-Way Classification of Semantic Relations Between Pairs of
@@ -83,6 +84,7 @@ def __init__(
8384
column_format={1: "text", 2: "ner"},
8485
comment_symbol="# ",
8586
in_memory=in_memory,
87+
**corpusargs,
8688
)
8789

8890
def extract_and_convert_to_conllu(self, data_file, data_folder, augment_train):
@@ -227,7 +229,7 @@ def _semeval_lines_to_token_list(self, raw_lines, augment_relations):
227229

228230

229231
class RE_ENGLISH_TACRED(ColumnCorpus):
230-
def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True):
232+
def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True, **corpusargs):
231233
"""
232234
TAC Relation Extraction Dataset with 41 relations from https://nlp.stanford.edu/projects/tacred/.
233235
Manual download is required for this dataset.
@@ -260,6 +262,7 @@ def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True):
260262
column_format={1: "text", 2: "ner"},
261263
comment_symbol="# ",
262264
in_memory=in_memory,
265+
**corpusargs,
263266
)
264267

265268
def extract_and_convert_to_conllu(self, data_file, data_folder):
@@ -351,7 +354,7 @@ def _tacred_example_to_token_list(self, example: Dict[str, Any]) -> conllu.Token
351354

352355

353356
class RE_ENGLISH_CONLL04(ColumnCorpus):
354-
def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True):
357+
def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True, **corpusargs):
355358
if not base_path:
356359
base_path = flair.cache_root / "datasets"
357360
else:
@@ -385,6 +388,7 @@ def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True):
385388
in_memory=in_memory,
386389
column_format={1: "text", 2: "ner"},
387390
comment_symbol="# ",
391+
**corpusargs,
388392
)
389393

390394
def _parse_incr(self, source_file) -> Iterable[conllu.TokenList]:
@@ -536,6 +540,7 @@ def __init__(
536540
base_path: Union[str, Path] = None,
537541
in_memory: bool = True,
538542
sentence_splitter: SentenceSplitter = SegtokSentenceSplitter(),
543+
**corpusargs,
539544
):
540545
"""
541546
DrugProt corpus: Biocreative VII Track 1 from https://zenodo.org/record/5119892#.YSdSaVuxU5k/ on
@@ -570,6 +575,7 @@ def __init__(
570575
sample_missing_splits=False,
571576
column_format={1: "text", 2: "ner", 3: "ner"},
572577
comment_symbol="# ",
578+
**corpusargs,
573579
)
574580

575581
def extract_and_convert_to_conllu(self, data_file, data_folder):

flair/datasets/sequence_labeling.py

+9 -4 (9 additions, 4 deletions)
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,8 @@ def _convert_lines_to_sentence(
662662
for span_indices, score, label in predicted_spans:
663663
span = sentence[span_indices[0] : span_indices[-1] + 1]
664664
value = self._remap_label(label)
665-
span.add_label(span_level_tag_columns[span_column], value=value, score=score)
665+
if value != "O":
666+
span.add_label(span_level_tag_columns[span_column], value=value, score=score)
666667
except Exception:
667668
pass
668669

@@ -681,7 +682,9 @@ def _convert_lines_to_sentence(
681682
relation = Relation(
682683
first=sentence[head_start - 1 : head_end], second=sentence[tail_start - 1 : tail_end]
683684
)
684-
relation.add_label(typename="relation", value=self._remap_label(label))
685+
remapped = self._remap_label(label)
686+
if remapped != "O":
687+
relation.add_label(typename="relation", value=remapped)
685688

686689
if len(sentence) > 0:
687690
return sentence
@@ -719,15 +722,17 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O
719722
# add each other feature as label-value pair
720723
label_name = feature.split("=")[0]
721724
label_value = self._remap_label(feature.split("=")[1])
722-
token.add_label(label_name, label_value)
725+
if label_value != "O":
726+
token.add_label(label_name, label_value)
723727

724728
else:
725729
# get the task name (e.g. 'ner')
726730
label_name = column_name_map[column]
727731
# get the label value
728732
label_value = self._remap_label(fields[column])
729733
# add label
730-
token.add_label(label_name, label_value)
734+
if label_value != "O":
735+
token.add_label(label_name, label_value)
731736

732737
if column_name_map[column] == self.SPACE_AFTER_KEY and fields[column] == "-":
733738
token.whitespace_after = False

0 commit comments

Comments
 (0)