Skip to content

Commit

Permalink
Implement metric for discontinuous trees
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Aug 14, 2023
1 parent 266b04b commit f33fb25
Showing 1 changed file with 117 additions and 0 deletions.
117 changes: 117 additions & 0 deletions supar/utils/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import os
import tempfile
from collections import Counter
from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -258,6 +260,121 @@ def values(self) -> Dict:
'LF': self.lf}


class DiscontinuousSpanMetric(Metric):
    r"""
    Bracket-level precision/recall/F-score for discontinuous constituency trees.

    Scoring is delegated to the ``discodop`` evaluator: predicted and gold trees are
    dumped to temporary bracket files, re-read with ``DiscBracketCorpusReader`` and
    compared by ``discodop.eval.Evaluator``. Two groups of counts are accumulated:

    * ``tp`` / ``pred`` / ``gold``: all brackets collected by the evaluator;
    * ``dtp`` / ``dpred`` / ``dgold``: only brackets whose ``bitfanout`` is larger than 1
      (presumably the discontinuous ones -- confirm against the disco-dop docs).

    Args:
        loss (Optional[float]):
            Loss of the batch. If given, the metric is updated right away via :meth:`__call__`.
        preds (Optional[List]):
            Predicted trees; each item must provide ``pformat(margin)``.
        golds (Optional[List]):
            Gold trees, aligned with ``preds``; each item must provide ``pformat(margin)``.
        param (Optional[str]):
            Passed unchanged to ``discodop.eval.readparam``.
        reverse (bool):
            Forwarded to the ``Metric`` base class.
        eps (float):
            Small constant added to denominators to avoid division by zero. Default: 1e-12.
    """

    def __init__(
        self,
        loss: Optional[float] = None,
        preds: Optional[List[List[Tuple]]] = None,
        golds: Optional[List[List[Tuple]]] = None,
        param: Optional[str] = None,
        reverse: bool = False,
        eps: float = 1e-12
    ) -> DiscontinuousSpanMetric:
        super().__init__(reverse=reverse, eps=eps)

        # counts over all brackets
        self.tp = 0.0
        self.pred = 0.0
        self.gold = 0.0
        # counts restricted to brackets with fan-out > 1
        self.dtp = 0.0
        self.dpred = 0.0
        self.dgold = 0.0

        if loss is not None:
            self(loss, preds, golds, param)

    def __call__(
        self,
        loss: float,
        preds: List[List[Tuple]],
        golds: List[List[Tuple]],
        param: Optional[str] = None
    ) -> DiscontinuousSpanMetric:
        r"""
        Accumulates the statistics of one batch and returns ``self``.

        Args:
            loss (float):
                Loss of the batch, added to ``total_loss``.
            preds (List):
                Predicted trees; each item must provide ``pformat(margin)``.
            golds (List):
                Gold trees, aligned with ``preds`` by position.
            param (Optional[str]):
                Passed unchanged to ``discodop.eval.readparam``.

        Returns:
            The metric itself, so that calls can be chained.
        """
        self.n += len(preds)
        self.count += 1
        self.total_loss += float(loss)
        # Nothing to score for an empty batch; without this guard the `max(...)`
        # used to size the evaluator below would raise on an empty sequence.
        if not preds:
            return self

        # imported lazily so that `discodop` is only required when this metric is used
        from discodop.eval import Evaluator, readparam
        from discodop.tree import bitfanout
        from discodop.treebank import DiscBracketCorpusReader

        with tempfile.TemporaryDirectory() as ftemp:
            fpred, fgold = os.path.join(ftemp, 'pred'), os.path.join(ftemp, 'gold')
            # The files are read back as UTF-8 below, so they must be written as UTF-8
            # too instead of relying on the platform's default encoding.
            # NOTE(review): the huge `pformat` argument presumably keeps every tree on
            # a single line, as the reader is expected to be line-oriented -- confirm.
            with open(fpred, 'w', encoding='utf8') as f:
                f.writelines(tree.pformat(1000000) + '\n' for tree in preds)
            with open(fgold, 'w', encoding='utf8') as f:
                f.writelines(tree.pformat(1000000) + '\n' for tree in golds)

            pred_corpus = DiscBracketCorpusReader(fpred, encoding='utf8', functions='remove')
            gold_corpus = DiscBracketCorpusReader(fgold, encoding='utf8', functions='remove')
            goldtrees, goldsents = gold_corpus.trees(), gold_corpus.sents()
            candtrees, candsents = pred_corpus.trees(), pred_corpus.sents()

            # the second argument is the length of the longest sentence key
            evaluator = Evaluator(readparam(param), max(len(str(key)) for key in candtrees))
            # everything that touches the files stays inside the `with`, in case the
            # corpus readers load their content lazily
            for n, ctree in candtrees.items():
                evaluator.add(n, goldtrees[n], goldsents[n], ctree, candsents[n])

        # bracket multisets accumulated by the evaluator over the whole batch
        cpreds, cgolds = evaluator.acc.candb, evaluator.acc.goldb

        def fanout_above_one(brackets):
            # keep only brackets whose span component `i[1][1]` has fan-out > 1
            return Counter([i for i in brackets.elements() if bitfanout(i[1][1]) > 1])

        dpreds, dgolds = fanout_above_one(cpreds), fanout_above_one(cgolds)
        # `&` is the multiset intersection, i.e., the brackets shared by both sides
        self.tp += sum((cpreds & cgolds).values())
        self.pred += sum(cpreds.values())
        self.gold += sum(cgolds.values())
        self.dtp += sum((dpreds & dgolds).values())
        self.dpred += sum(dpreds.values())
        self.dgold += sum(dgolds.values())
        return self

    def __add__(self, other: DiscontinuousSpanMetric) -> DiscontinuousSpanMetric:
        r"""
        Returns a new metric holding the summed counts of ``self`` and ``other``.

        ``eps`` is taken from ``self``; ``reverse`` is set if either operand has it set.
        """
        metric = DiscontinuousSpanMetric(eps=self.eps)
        metric.n = self.n + other.n
        metric.count = self.count + other.count
        metric.total_loss = self.total_loss + other.total_loss
        metric.tp = self.tp + other.tp
        metric.pred = self.pred + other.pred
        metric.gold = self.gold + other.gold
        metric.dtp = self.dtp + other.dtp
        metric.dpred = self.dpred + other.dpred
        metric.dgold = self.dgold + other.dgold
        metric.reverse = self.reverse or other.reverse
        return metric

    @property
    def score(self):
        r"""The headline score of the metric, i.e., the overall F-score :attr:`f`."""
        return self.f

    @property
    def p(self):
        r"""Precision over all brackets: ``tp / (pred + eps)``."""
        return self.tp / (self.pred + self.eps)

    @property
    def r(self):
        r"""Recall over all brackets: ``tp / (gold + eps)``."""
        return self.tp / (self.gold + self.eps)

    @property
    def f(self):
        r"""F-score over all brackets: ``2 * tp / (pred + gold + eps)``."""
        return 2 * self.tp / (self.pred + self.gold + self.eps)

    @property
    def dp(self):
        r"""Precision over brackets with fan-out > 1: ``dtp / (dpred + eps)``."""
        return self.dtp / (self.dpred + self.eps)

    @property
    def dr(self):
        r"""Recall over brackets with fan-out > 1: ``dtp / (dgold + eps)``."""
        return self.dtp / (self.dgold + self.eps)

    @property
    def df(self):
        r"""F-score over brackets with fan-out > 1: ``2 * dtp / (dpred + dgold + eps)``."""
        return 2 * self.dtp / (self.dpred + self.dgold + self.eps)

    @property
    def values(self) -> Dict:
        r"""All six scores keyed by their short names."""
        return {'P': self.p,
                'R': self.r,
                'F': self.f,
                'DP': self.dp,
                'DR': self.dr,
                'DF': self.df}


class ChartMetric(Metric):

def __init__(
Expand Down

0 comments on commit f33fb25

Please sign in to comment.