Skip to content

Commit

Permalink
Add projective order
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Jul 5, 2023
1 parent a719a06 commit 266b04b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
11 changes: 11 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ @inproceedings{smith-eisner-2008-dependency
pages = {145--156}
}

@inproceedings{nivre-2009-non,
title = {Non-Projective Dependency Parsing in Expected Linear Time},
author = {Nivre, Joakim},
booktitle = {Proceedings of ACL},
year = {2009},
url = {https://aclanthology.org/P09-1040},
address = {Suntec, Singapore},
publisher = {Association for Computational Linguistics},
pages = {351--359}
}

@inproceedings{yarin-etal-2016-dropout,
title = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
author = {Gal, Yarin and Ghahramani, Zoubin},
Expand Down
54 changes: 45 additions & 9 deletions supar/models/dep/biaffine/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from io import StringIO
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Tuple, Union

from supar.utils.logging import get_logger
from supar.utils.tokenizer import Tokenizer
Expand Down Expand Up @@ -132,12 +132,12 @@ def build_relations(cls, chart):
return sequence

@classmethod
def toconll(cls, tokens: List[Union[str, Tuple]]) -> str:
def toconll(cls, tokens: Sequence[Union[str, Tuple]]) -> str:
r"""
Converts a list of tokens to a string in CoNLL-X format with missing fields filled with underscores.
Args:
tokens (List[Union[str, Tuple]]):
tokens (Sequence[Union[str, Tuple]]):
This can be either a list of words, word/pos pairs or word/lemma/pos triples.
Returns:
Expand Down Expand Up @@ -178,7 +178,7 @@ def toconll(cls, tokens: List[Union[str, Tuple]]) -> str:
return s + '\n'

@classmethod
def isprojective(cls, sequence: List[int]) -> bool:
def isprojective(cls, sequence: Sequence[int]) -> bool:
r"""
Checks if a dependency tree is projective.
This also works for partial annotation.
Expand All @@ -187,7 +187,7 @@ def isprojective(cls, sequence: List[int]) -> bool:
which are hard to detect in the scenario of partial annotation.
Args:
sequence (List[int]):
sequence (Sequence[int]):
A list of head indices.
Returns:
Expand All @@ -213,12 +213,12 @@ def isprojective(cls, sequence: List[int]) -> bool:
return True

@classmethod
def istree(cls, sequence: List[int], proj: bool = False, multiroot: bool = False) -> bool:
def istree(cls, sequence: Sequence[int], proj: bool = False, multiroot: bool = False) -> bool:
r"""
Checks if the arcs form a valid dependency tree.
Args:
sequence (List[int]):
sequence (Sequence[int]):
A list of head indices.
proj (bool):
If ``True``, requires the tree to be projective. Default: ``False``.
Expand Down Expand Up @@ -247,6 +247,42 @@ def istree(cls, sequence: List[int], proj: bool = False, multiroot: bool = False
return False
return next(tarjan(sequence), None) is None

@classmethod
def projective_order(cls, sequence: Sequence[int]) -> Sequence[int]:
    r"""
    Returns the projective order corresponding to the tree :cite:`nivre-2009-non`.

    The projective order is the in-order traversal of the dependency tree: for every head,
    the subtrees of its left dependents come first, then the head itself,
    then the subtrees of its right dependents.

    Args:
        sequence (Sequence[int]):
            A list of head indices. Token positions are 1-based and ``0`` denotes the root.
            Tokens with a negative head are treated as unattached, so they (and everything
            below them) are left out of the returned order.

    Returns:
        The projective order of the tree, as a list of 1-based token positions.

    Examples:
        >>> CoNLL.projective_order([2, 0, 2, 3])
        [1, 2, 3, 4]
        >>> CoNLL.projective_order([3, 0, 0, 3])
        [2, 1, 3, 4]
        >>> CoNLL.projective_order([2, 3, 0, 3, 2, 7, 5, 4, 3])
        [1, 2, 5, 6, 7, 3, 4, 8, 9]
    """

    # children[h] lists the dependents of h in increasing position; slot 0 is the root.
    children = [[] for _ in range(len(sequence) + 1)]
    for dep, head in enumerate(sequence, 1):
        # A negative index would silently wrap around and attach `dep`
        # to a token counted from the end of the sentence, so skip it.
        if head >= 0:
            children[head].append(dep)

    result = []
    # An explicit stack replaces recursion so that deep trees cannot hit the recursion limit.
    # Each entry is (node, emit); emit=True means the left subtrees were already scheduled
    # and the node itself is due to be written out.
    stack = [(root, False) for root in reversed(children[0])]
    while stack:
        node, emit = stack.pop()
        if emit:
            result.append(node)
            continue
        deps = children[node]
        # Push in reverse of the visiting order: right dependents, the node, left dependents.
        stack.extend((dep, False) for dep in reversed(deps) if dep > node)
        stack.append((node, True))
        stack.extend((dep, False) for dep in reversed(deps) if dep < node)
    return result

def load(
self,
data: Union[str, Iterable],
Expand Down Expand Up @@ -313,7 +349,7 @@ class CoNLLSentence(Sentence):
Args:
transform (CoNLL):
A :class:`~supar.utils.transform.CoNLL` object.
lines (List[str]):
lines (Sequence[str]):
A list of strings composing a sentence in CoNLL-X format.
Comments and non-integer IDs are permitted.
index (Optional[int]):
Expand Down Expand Up @@ -355,7 +391,7 @@ class CoNLLSentence(Sentence):
12 . _ _ _ _ 3 punct _ _
"""

def __init__(self, transform: CoNLL, lines: List[str], index: Optional[int] = None) -> CoNLLSentence:
def __init__(self, transform: CoNLL, lines: Sequence[str], index: Optional[int] = None) -> CoNLLSentence:
super().__init__(transform, index)

self.values = []
Expand Down

0 comments on commit 266b04b

Please sign in to comment.