Merge branch 'main' into add_text_length
Wenshansilvia authored Sep 26, 2024
2 parents cbdaca3 + e7726c7 commit b49add6
Showing 7 changed files with 272 additions and 20 deletions.
8 changes: 6 additions & 2 deletions rageval/metrics/__init__.py
@@ -24,8 +24,12 @@
# Metrics about the answer informativeness
##from .answer_informative._claim_num import ClaimNum
from .answer_informativeness._text_length import TextLength
##from .answer_informative._repetitiveness import Repetitiveness
##from .answer_informative._pairwise_accuracy import PairwiseAccuracy
##from .answer_informativeness._repetitiveness import Repetitiveness
##from .answer_informativeness._pairwise_accuracy import PairwiseAccuracy
from .answer_informativeness._answer_distinct12 import AnswerDistinct

# Metrics about the context relevancy
from .context_relevancy._context_recall import ContextRecall

# Metrics about the context adequacy
from .context_adequacy._context_recall import ContextRecall
91 changes: 79 additions & 12 deletions rageval/metrics/answer_correctness/_answer_f1.py
@@ -2,31 +2,36 @@
import string
from collections import Counter
from dataclasses import dataclass
from typing import List, Optional, Iterable, Union

import datasets
import numpy as np
from typing import Union, Iterable, List
import jieba

from rageval.metrics import Metric, add_attribute


_DESCRIPTION = """\
F1 score combines precision and recall into a single score using their harmonic mean.
"""

_KWARGS_DESCRIPTION = """\
Args:
name : str
normalize : bool, default is True, whether to normalize the text. If False, the text will be treated as a list of tokens.
language : str, default is 'en', the language of the text. Supported languages are 'en' and 'zh'.
Optional Args:
None
Functions:
_normalize_text: normalize the text by removing articles, punctuation and extra whitespace, and lowercasing.
_validate_data: validate the dataset format.
_f1_score: compute the f1 score between `pred` string and `ref` string.
_f1_score: compute the f1 score between `pred` tokens and `ref` tokens.
_compute_one: evaluate the f1 score between `answer` and `gt_answers`, returning the highest score over all pairs.
Examples:
English:
>>> from datasets import Dataset
>>> import rageval as rl
>>> sample = {
@@ -44,8 +49,50 @@
>>> metric = rl.metrics.AnswerF1Correctness()
>>> metric.mtype
'AnswerCorrectness'
>>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1)
>>> assert 0 <= score <= 1
>>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'])
>>> round(score, 2)
0.18
Chinese:
>>> from datasets import Dataset
>>> import rageval as rl
>>> sample = {
... "answers": [
... "督邮,中国古代职官名,自汉代开始设置。",
... "魏晋",
... "北齐只设于清都郡。",
... "隋代",
... ],
... "gt_answers": [
... ["督邮,中国古代职官名,自汉代开始设置。"],
... ["魏晋", "魏晋时期"],
... ["北齐只设于清都郡。", "清都郡"],
... ["隋代", "隋朝"]
... ]
... }
>>> dataset = Dataset.from_dict(sample)
>>> metric = rl.metrics.AnswerF1Correctness(language='zh')
>>> metric.mtype
'AnswerCorrectness'
>>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'])
>>> round(score, 2)
1.0
Other Iterables:
>>> from datasets import Dataset
>>> import rageval as rl
>>> sample = {
... "answers": [[1,2,3], [4,5,6]],
... "gt_answers": [[2,3,4,5,6], [1,2,3,4,5]]
... }
>>> dataset = Dataset.from_dict(sample)
>>> metric = rl.metrics.AnswerF1Correctness(normalize=False)
>>> metric.mtype
'AnswerCorrectness'
>>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'])
>>> round(score, 2)
0.5
"""

_CITATION = """\
@@ -63,14 +110,15 @@ class AnswerF1Correctness(Metric):

ALIAS = ['answer_f1']

def __init__(self, normalize: bool = True):
def __init__(self, normalize: bool = True, language: Optional[str] = "en"):
"""
Explicitly initialize AnswerF1Correctness.
Ensure all parent classes are initialized.
"""
super().__init__()
self.normalize = normalize
self.language = language

def __repr__(self) -> str:
""":return: Formatted string representation of the metric."""
@@ -103,10 +151,21 @@ def lower(text):
return text.lower()
return remove_articles(remove_punc(lower(s))).split()

def _f1_score(self, preds: Iterable, refs: Iterable) -> float:
def _normalize_text_zh(self, s: str) -> str:
"""Normalize Chinese text."""
def white_space_fix(text):
return ' '.join(text.split())

def remove_punc(text):
exclude = set(string.punctuation) | {',', '。', '?', '!', ':', ';', '“', '”', '‘', '’', '(', ')', '《', '》', '——', '……', '、'}
return ''.join(ch for ch in text if ch not in exclude)

return white_space_fix(remove_punc(s))

def _f1_score(self, pred: Iterable, ref: Iterable) -> float:
"""Compute the f1 score between pred and ref."""
pred_counter = Counter(preds)
ref_counter = Counter(refs)
pred_counter = Counter(pred)
ref_counter = Counter(ref)

tp = sum((pred_counter & ref_counter).values())
fp = sum((pred_counter - ref_counter).values())
@@ -126,9 +185,17 @@ def _compute_one(
) -> float:
"""Evaluate the f1 score of an answer."""
if self.normalize:
pred_answer = self._normalize_text(pred_answer)
ref_answers = [self._normalize_text(ref_answer) for ref_answer in ref_answers]

scores = [self._f1_score(pred_answer, ref_answer) for ref_answer in ref_answers]
# str, List[str] -> List[str], List[List[str]]
if self.language == "en":
preds = self._normalize_text(pred_answer)
refs = [self._normalize_text(ref_answer) for ref_answer in ref_answers]
elif self.language == "zh":
preds = list(jieba.cut(self._normalize_text_zh(pred_answer)))
refs = [list(jieba.cut(self._normalize_text_zh(ref_answer))) for ref_answer in ref_answers]
else:
raise Exception('Unsupported language: {}'.format(self.language)) # pragma: no cover
scores = [self._f1_score(preds, ref) for ref in refs]
else:
scores = self._f1_score(pred_answer, ref_answers)

return np.max(scores)
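For reference, a minimal standalone sketch of the token-level F1 used above, with the English (whitespace) and Chinese (jieba) tokenization paths this commit introduces; the helper name token_f1 and the sample strings are illustrative and not part of the repository:

from collections import Counter

import jieba  # Chinese word segmentation, added to requirements.txt in this commit


def token_f1(pred_tokens, ref_tokens):
    """Harmonic mean of precision and recall over token multisets."""
    pred_counter, ref_counter = Counter(pred_tokens), Counter(ref_tokens)
    tp = sum((pred_counter & ref_counter).values())  # overlapping token counts
    if tp == 0:
        return 0.0
    precision = tp / sum(pred_counter.values())
    recall = tp / sum(ref_counter.values())
    return 2 * precision * recall / (precision + recall)


# English path: normalize, then split on whitespace
print(token_f1("ali daei scored 109 goals".split(), "daei scored 109 goals".split()))  # ~0.89

# Chinese path: segment with jieba, mirroring the language='zh' branch
print(token_f1(list(jieba.cut("北齐只设于清都郡")), list(jieba.cut("清都郡"))))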
114 changes: 114 additions & 0 deletions rageval/metrics/answer_informativeness/_answer_distinct12.py
@@ -0,0 +1,114 @@
from collections import Counter
from dataclasses import dataclass
from typing import List, Optional, Iterable, Tuple
import datasets
from nltk import ngrams
from rageval.metrics import Metric, add_attribute

_DESCRIPTION = """\
Distinct 1/2 measures the diversity of generated text by calculating the ratio of unique n-grams to the total number of n-grams.
"""

_KWARGS_DESCRIPTION = """\
Args:
pred_answers (list of str): List of generated texts for which distinct metrics are computed.
n_grams (int): The n-gram order for which distinct metrics are computed.
Returns:
tuple: The average Distinct-n score over all answers and a list of per-answer scores.
Examples:
>>> from datasets import Dataset
>>> import rageval as rl
>>> sample = {
... "answers": [
... "This is the first sentence.",
... "This is the second sentence."
... ]
... }
>>> dataset = Dataset.from_dict(sample)
>>> metric = rl.metrics.AnswerDistinct(1)
>>> metric.mtype
'AnswerInformativeness'
>>> score, results = metric.compute(dataset['answers'])
>>> score
0.6
"""

_CITATION = """\
@misc{selfmemory2023,
title={Lift Yourself Up: Retrieval-augmented Text Generation with Self Memory},
author={Xin Cheng and Di Luo and Xiuying Chen and Lemao Liu and Dongyan Zhao and Rui Yan},
year={2023},
eprint={2305.02437},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""


def get_distinct_score(pred_answers: List[str], n_grams: int) -> float:
"""Compute the Distinct-n score: the ratio of unique n-grams to total n-grams."""
c = Counter()
for answer in pred_answers:
tokens = answer.split()
c.update(ngrams(tokens, n_grams))

return len(c) / sum(c.values())


@dataclass
@add_attribute('mtype', 'AnswerInformativeness')
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AnswerDistinct(Metric):
"""Distinct 1/2 metric for text generation."""

name = "answer_distinct"

ALIAS = ['answer_distinct']

def __init__(self, n_grams: int = 1):
"""
Explicitly initialize Distinct.
Ensure all parent classes are initialized.
"""
super().__init__()
self.n_grams = n_grams

def __repr__(self) -> str:
""":return: Formatted string representation of the metric."""
return f"{self.ALIAS[0]}"

def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
inputs_description=_KWARGS_DESCRIPTION,
citation=_CITATION,
features=datasets.Features(
{
"pred_answers": datasets.Value("string"),
}
),
codebase_urls=["https://github.com/Hannibal046/SelfMemory/blob/main/src/utils/metrics_utils.py"],
reference_urls=["https://arxiv.org/abs/2305.02437"]
)

def _validate_data(
self,
pred_answers: Optional[Iterable] = None,
ref_answers: Optional[Iterable] = None,
) -> bool:
"""Validate the input data."""
assert isinstance(pred_answers, (str, list))  # pragma: no cover

def compute(
self,
pred_answers: Optional[Iterable] = None,
) -> Tuple[float, List[float]]:
"""
Evaluate the dataset.
Return average scores of all inputs and a score list for each example.
"""
return get_distinct_score(pred_answers, self.n_grams), [get_distinct_score([pred_answer], self.n_grams) for pred_answer in pred_answers]
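As a quick sanity check of the ratio computed by get_distinct_score above, a standalone sketch using the two sentences from the docstring (the variable names here are illustrative):

from collections import Counter

from nltk import ngrams

answers = ["This is the first sentence.", "This is the second sentence."]

counts = Counter()
for answer in answers:
    counts.update(ngrams(answer.split(), 1))  # unigrams; pass 2 for Distinct-2

# ratio of unique n-grams to total n-grams across all answers
print(len(counts) / sum(counts.values()))  # 6 unique unigrams / 10 total = 0.6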
5 changes: 5 additions & 0 deletions rageval/metrics/base.py
@@ -9,6 +9,11 @@
from langchain.schema import LLMResult
from tqdm import tqdm

import sys
import io

sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') # for Chinese language output


def add_attribute(attribute_name, attribute_value):
"""
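If the goal of the change above is only UTF-8 console output for Chinese text, an arguably less invasive alternative (assuming Python 3.7+, where text streams expose reconfigure) would be to adjust the existing stream in place rather than rebinding sys.stdout:

import sys

# Reconfigure the existing stream instead of replacing sys.stdout,
# so earlier references to sys.stdout keep pointing at a valid stream.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")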
1 change: 1 addition & 0 deletions requirements.txt
@@ -25,3 +25,4 @@ protobuf == 4.25.3
sacrebleu == 2.3.3
bert_score == 0.3.13
transformers
jieba >= 0.42.1
43 changes: 43 additions & 0 deletions tests/units/test_answer_distinct.py
@@ -0,0 +1,43 @@
import pytest
from datasets import Dataset

from rageval.metrics import AnswerDistinct


@pytest.fixture(scope='module')
def sample():
test_case = {
"answers": [
"Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the \
highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world \
international football.",
"A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, was \
the oldest person ever from the United States and the second-oldest fully documented person ever. Jeanne \
Calment was a French supercentenarian and the oldest human whose age is well-documented, with a lifespan \
of 122 years and 164 days, and was the oldest person in the world as of 1997."
]
}
return test_case


@pytest.fixture(scope='module')
def testset(sample):
ds = Dataset.from_dict(sample)
return ds


@pytest.mark.slow
def test_case_on_answer_distinct(testset):
metric = AnswerDistinct(n_grams=1)
assert metric.name == "answer_distinct"
assert repr(metric) == 'answer_distinct'
assert metric.mtype == 'AnswerInformativeness'
score, results = metric.compute(pred_answers=testset['answers'])
assert 0 <= score <= 1

metric = AnswerDistinct(n_grams=2)
assert metric.name == "answer_distinct"
assert repr(metric) == 'answer_distinct'
assert metric.mtype == 'AnswerInformativeness'
score, results = metric.compute(pred_answers=testset['answers'])
assert 0 <= score <= 1
30 changes: 24 additions & 6 deletions tests/units/test_answer_f1.py
@@ -19,23 +19,41 @@ def sample():
"gt_answers": [
["Daei", "Ali Daei"],
["Jeanne Calment"]
]
],
"answers_zh": [
"魏晋",
"北齐只设于清都郡。",
],
"gt_answers_zh": [
["魏晋", "魏晋时期"],
["北齐只设于清都郡。", "清都郡"]
],
"answers_num":[[1,2,3], [4,5,6]],
"gt_answers_num":[[2,3,4,5,6], [1,2,3,4,5]]
}
return test_case


@pytest.fixture(scope='module')
def testset(sample):
ds = Dataset.from_dict(sample)
return ds


@pytest.mark.slow
def test_case_on_answer_f1(testset):
metric = AnswerF1Correctness(normalize=True)
metric = AnswerF1Correctness(normalize=True, language='en')
assert metric.name == "answer_f1"
assert metric.mtype == 'AnswerCorrectness'
score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
score, results = metric.compute(testset['answers'], testset['gt_answers'])
assert 0 <= score <= 1
score = metric._compute_one(testset['answers'][0], testset['gt_answers'][0])

metric = AnswerF1Correctness(normalize=True, language='zh')
assert metric.name == "answer_f1"
assert metric.mtype == 'AnswerCorrectness'
score_zh, results_zh = metric.compute(testset['answers_zh'], testset['gt_answers_zh'])
assert 0 <= score_zh <= 1

metric = AnswerF1Correctness(normalize=False)
assert metric.name == "answer_f1"
assert metric.mtype == 'AnswerCorrectness'
score, results = metric.compute(testset['answers_num'], testset['gt_answers_num'])
assert 0 <= score <= 1
