-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into add_text_length
- Loading branch information
Showing
7 changed files
with
272 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
rageval/metrics/answer_informativeness/_answer_distinct12.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from collections import Counter | ||
from dataclasses import dataclass | ||
from typing import List, Optional, Iterable, Tuple | ||
import datasets | ||
from nltk import ngrams | ||
from rageval.metrics import Metric, add_attribute | ||
|
||
_DESCRIPTION = """\ | ||
Distinct 1/2 measures the diversity of generated text by calculating the ratio of unique n-grams to the total number of n-grams. | ||
""" | ||
|
||
_KWARGS_DESCRIPTION = """\ | ||
Args: | ||
pred_answers (list of str): List of generated texts for which distinct metrics are computed. | ||
n_grams (int): The n-gram order for which distinct metrics are computed. | ||
Returns: | ||
tuple: The corpus-level distinct-n score (float) and a list with the distinct-n score of each individual answer. | ||
Examples: | ||
>>> from datasets import Dataset | ||
>>> import rageval as rl | ||
>>> sample = { | ||
... "answers": [ | ||
... "This is the first sentence.", | ||
... "This is the second sentence." | ||
... ] | ||
... } | ||
>>> dataset = Dataset.from_dict(sample) | ||
>>> metric = rl.metrics.AnswerDistinct(1) | ||
>>> metric.mtype | ||
'AnswerInformativeness' | ||
>>> score, results = metric.compute(dataset['answers']) | ||
>>> score | ||
0.6 | ||
""" | ||
|
||
_CITATION = """\ | ||
@misc{selfmemory2023, | ||
title={Lift Yourself Up: Retrieval-augmented Text Generation with Self Memory}, | ||
author={Xin Cheng and Di Luo and Xiuying Chen and Lemao Liu and Dongyan Zhao and Rui Yan}, | ||
year={2023}, | ||
eprint={2305.02437}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CL} | ||
} | ||
""" | ||
|
||
|
||
def get_distinct_score(pred_answers: List[str], n_grams: int) -> float:
    """Compute the distinct-n score of a collection of answers.

    Each answer is split on whitespace, its n-grams of order ``n_grams`` are
    pooled over all answers, and the score is the number of unique n-grams
    divided by the total number of n-grams.

    Args:
        pred_answers: Generated texts to evaluate.
        n_grams: Order of the n-grams (1 for Distinct-1, 2 for Distinct-2).

    Returns:
        The ratio of unique to total n-grams, or ``0.0`` when the input
        yields no n-grams at all (no answers, empty answers, or every
        answer shorter than ``n_grams`` tokens).
    """
    counts = Counter()
    for answer in pred_answers:
        tokens = answer.split()
        counts.update(ngrams(tokens, n_grams))

    total = sum(counts.values())
    # Without this guard an input that produces no n-grams divides by zero.
    if total == 0:
        return 0.0
    return len(counts) / total
|
||
|
||
@dataclass
@add_attribute('mtype', 'AnswerInformativeness')
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AnswerDistinct(Metric):
    """Distinct-n metric for text generation.

    The n-gram order is fixed per instance through ``n_grams``; scoring is
    delegated to ``get_distinct_score``.
    """

    name = "answer_distinct"

    ALIAS = ['answer_distinct']

    def __init__(self, n_grams: int = 1):
        """Initialize the metric.

        Args:
            n_grams: Order of the n-grams used when scoring (defaults to 1).
        """
        super().__init__()
        self.n_grams = n_grams

    def __repr__(self) -> str:
        """:return: Formatted string representation of the metric."""
        return f"{self.ALIAS[0]}"

    def _info(self):
        """Return the ``datasets.MetricInfo`` describing this metric."""
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            inputs_description=_KWARGS_DESCRIPTION,
            citation=_CITATION,
            features=datasets.Features(
                {
                    "pred_answers": datasets.Value("string"),
                }
            ),
            codebase_urls=["https://github.com/Hannibal046/SelfMemory/blob/main/src/utils/metrics_utils.py"],
            reference_urls=["https://arxiv.org/abs/2305.02437"]
        )

    def _validate_data(
        self,
        pred_answers: Optional[Iterable] = None,
        ref_answers: Optional[Iterable] = None,
    ) -> bool:
        """Validate the input data.

        Args:
            pred_answers: Answers to check; must be a ``str`` or a ``list``.
            ref_answers: Accepted but not inspected by this metric.

        Returns:
            ``True`` when the check passes, matching the declared return type.

        Raises:
            AssertionError: If ``pred_answers`` is neither ``str`` nor ``list``.
        """
        assert isinstance(pred_answers, (str, list))  # pragma: no cover
        return True

    def compute(
        self,
        pred_answers: Optional[Iterable] = None,
    ) -> Tuple[float, List[float]]:
        """Evaluate the dataset.

        Args:
            pred_answers: Generated answers to score.

        Returns:
            A tuple ``(score, per_answer_scores)``: the distinct-n score over
            all answers pooled together, and the distinct-n score of each
            answer taken on its own.
        """
        # Materialize once: the input is traversed twice below, and a one-shot
        # iterable would otherwise be exhausted after the first pass.
        answers = list(pred_answers)
        overall = get_distinct_score(answers, self.n_grams)
        per_answer = [get_distinct_score([answer], self.n_grams) for answer in answers]
        return overall, per_answer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,4 @@ protobuf == 4.25.3 | |
sacrebleu == 2.3.3 | ||
bert_score == 0.3.13 | ||
transformers | ||
jieba >= 0.42.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import pytest | ||
from datasets import Dataset | ||
|
||
from rageval.metrics import AnswerDistinct | ||
|
||
|
||
@pytest.fixture(scope='module')
def sample():
    """Return a dict with an ``"answers"`` list holding two long-form answer strings."""
    test_case = {
        "answers": [
            "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the \
highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world \
international football.",
            "A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, was \
the oldest person ever from the United States and the second-oldest fully documented person ever. Jeanne \
Calment was a French supercentenarian and the oldest human whose age is well-documented, with a lifespan \
of 122 years and 164 days, and was the oldest person in the world as of 1997."
        ]
    }
    return test_case
|
||
|
||
@pytest.fixture(scope='module')
def testset(sample):
    """Build a ``Dataset`` from the ``sample`` fixture's dict."""
    return Dataset.from_dict(sample)
|
||
|
||
@pytest.mark.slow
def test_case_on_answer_distinct(testset):
    """Check metadata and score ranges of AnswerDistinct for n-gram orders 1 and 2."""
    answers = testset['answers']
    for n_grams in (1, 2):
        metric = AnswerDistinct(n_grams=n_grams)
        assert metric.name == "answer_distinct"
        # The comparison was previously evaluated without `assert`, so a wrong
        # repr could never fail the test.
        assert repr(metric) == 'answer_distinct'
        assert metric.mtype == 'AnswerInformativeness'
        score, results = metric.compute(pred_answers=answers)
        assert 0 <= score <= 1
        # compute() also returns one score per input answer.
        assert len(results) == len(answers)
        assert all(0 <= result <= 1 for result in results)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters