Skip to content

Commit

Permalink
Support list inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Aug 28, 2023
1 parent 9bdf967 commit 8a9fd99
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions supar/utils/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import os
import tempfile
from collections import Counter
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch

from supar.utils.fn import pad


class Metric(object):

Expand Down Expand Up @@ -73,8 +75,8 @@ class AttachmentMetric(Metric):
def __init__(
self,
loss: Optional[float] = None,
preds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
golds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
preds: Optional[Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]] = None,
golds: Optional[Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]] = None,
mask: Optional[torch.BoolTensor] = None,
reverse: bool = False,
eps: float = 1e-12
Expand All @@ -93,14 +95,20 @@ def __init__(
def __call__(
self,
loss: float,
preds: Tuple[torch.Tensor, torch.Tensor],
golds: Tuple[torch.Tensor, torch.Tensor],
preds: Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]],
golds: Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]],
mask: torch.BoolTensor
) -> AttachmentMetric:
lens = mask.sum(1)
arc_preds, rel_preds, arc_golds, rel_golds = *preds, *golds
arc_mask = arc_preds.eq(arc_golds) & mask
rel_mask = rel_preds.eq(rel_golds) & arc_mask
if isinstance(arc_preds, torch.Tensor):
arc_mask = arc_preds.eq(arc_golds)
rel_mask = rel_preds.eq(rel_golds)
else:
arc_mask = pad([mask.new_tensor([i == j for i, j in zip(pred, gold)]) for pred, gold in zip(arc_preds, arc_golds)])
rel_mask = pad([mask.new_tensor([i == j for i, j in zip(pred, gold)]) for pred, gold in zip(rel_preds, rel_golds)])
arc_mask = arc_mask & mask
rel_mask = rel_mask & arc_mask
arc_mask_seq, rel_mask_seq = arc_mask[mask], rel_mask[mask]

self.n += len(mask)
Expand Down

0 comments on commit 8a9fd99

Please sign in to comment.