Skip to content

Commit

Permalink
Deal with subtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Sep 1, 2023
1 parent 923042e commit e98d704
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions supar/utils/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
preds: Optional[Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]] = None,
golds: Optional[Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]] = None,
mask: Optional[torch.BoolTensor] = None,
subtype: Optional[bool] = True,
reverse: bool = False,
eps: float = 1e-12
) -> AttachmentMetric:
Expand All @@ -90,21 +91,25 @@ def __init__(
self.correct_rels = 0.0

if loss is not None:
self(loss, preds, golds, mask)
self(loss, preds, golds, mask, subtype)

def __call__(
self,
loss: float,
preds: Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]],
golds: Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]],
mask: torch.BoolTensor
mask: Optional[torch.BoolTensor] = None,
subtype: Optional[bool] = True
) -> AttachmentMetric:
lens = mask.sum(1)
arc_preds, rel_preds, arc_golds, rel_golds = *preds, *golds
if isinstance(arc_preds, torch.Tensor):
arc_mask = arc_preds.eq(arc_golds)
rel_mask = rel_preds.eq(rel_golds)
else:
if not subtype:
rel_preds = [[i.split(':', 1)[0] for i in rels] for rels in rel_preds]
rel_golds = [[i.split(':', 1)[0] for i in rels] for rels in rel_golds]
arc_mask = pad([mask.new_tensor([i == j for i, j in zip(pred, gold)]) for pred, gold in zip(arc_preds, arc_golds)])
rel_mask = pad([mask.new_tensor([i == j for i, j in zip(pred, gold)]) for pred, gold in zip(rel_preds, rel_golds)])
arc_mask = arc_mask & mask
Expand Down

0 comments on commit e98d704

Please sign in to comment.