Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text length #116

Merged
merged 16 commits into from
Sep 27, 2024
Next Next commit
add text length metric
  • Loading branch information
Wenshansilvia committed Sep 24, 2024
commit b399a70b6553da851adbb8da2e719edbae900e5a
12 changes: 9 additions & 3 deletions rageval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,22 @@
from .answer_correctness._answer_disambig_f1 import AnswerDisambigF1Correctness
from .answer_correctness._answer_lcs_ratio import AnswerLCSRatio
from .answer_correctness._answer_ter import AnswerTERCorrectness
##from .answer_correctness._answer_relevancy import AnswerRelevancy

# Metrics about the answer groundedness
from .answer_groundedness._answer_citation_precision import AnswerCitationPrecision
from .answer_groundedness._answer_citation_recall import AnswerCitationRecall
from .answer_groundedness._context_reject_rate import ContextRejectRate
##from .answer_groundedness._claim_faithfulness import ClaimFaithfulness

# Metrics about the answer informativeness

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议文件夹的名字写”answer_informativeness“

# Metrics about the context relevancy
from .context_relevancy._context_recall import ContextRecall
##from .answer_informative._claim_num import ClaimNum
from .answer_informative._text_length import TextLength
##from .answer_informative._repetitiveness import Repetitiveness
##from .answer_informative._pairwise_accuracy import PairwiseAccuracy

# Metrics about the context adequacy
from .context_adequacy._context_recall import ContextRecall

# Metrics about the context relevance

Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
95 changes: 95 additions & 0 deletions rageval/metrics/answer_informative/_text_length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from dataclasses import dataclass
from typing import List
import numpy as np

import datasets

from rageval.metrics import Metric, add_attribute


_DESCRIPTION = """\
Textlength is a metric used to evaluate the length of a model-generated response.

It measures the number of tokens in the generated text by first converting the text into tokens and then counting the total number. This metric provides insight into the verbosity or conciseness of the model's output, offering a standardized way to compare text length across different responses.
"""

_KWARGS_DESCRIPTION = """\
Args:
name : str

Optional Args:
None

Functions:
_compute_one: Evaluating the length of answer.

Examples:
>>> from datasets import Dataset
>>> import rageval as rl
>>> sample = {
... "answers": [
... "A",
... "C",
... ]
... }
>>> dataset = Dataset.from_dict(sample)
>>> tokenize_model = rl.models.Tokenizer("Qwen/Qwen2-0.5B-Instruct")
>>> metric = rl.metrics.TextLength(tokenize_model=tokenize_model)
>>> metric.mtype
'answer_informative'
"""


@dataclass
@add_attribute('mtype', 'answer_informative')
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TextLength(Metric):
    """Estimates the text length of answers, measured in tokenizer tokens."""

    name = "text_length"

    ALIAS = ['text_length']

    def __init__(self, tokenize_model="Qwen/Qwen2-0.5B-Instruct"):
        """
        Explicitly initialize TextLength.

        Ensure all parent classes are initialized.

        Args:
            tokenize_model: a ``rageval.models.Tokenizer`` instance, or a
                Hugging Face model name (str) from which one is built.
                Wrapping is required because ``_compute_one`` accesses
                ``.tokenizer`` on this object, so a bare string would fail.
        """
        if isinstance(tokenize_model, str):
            # Local import to avoid a circular import between
            # rageval.metrics and rageval.models at module load time.
            from rageval.models import Tokenizer
            tokenize_model = Tokenizer(tokenize_model)
        self.tokenize_model = tokenize_model
        super().__init__()

    def __repr__(self) -> str:
        """:return: Formatted string representation of the metric."""
        return f"{self.ALIAS[0]}"  # pragma: no cover

    def _info(self):
        """Describe the metric's expected inputs for the ``datasets`` library."""
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            inputs_description=_KWARGS_DESCRIPTION,
            citation="",
            homepage="",
            features=datasets.Features(
                {
                    "answers": datasets.Value("string"),
                }
            ),
            codebase_urls=[],
            reference_urls=[]
        )

    def _compute_one(
        self,
        answer: str,
    ) -> int:
        """Return the number of tokens the tokenizer produces for ``answer``."""
        # tokenizer(...) returns a BatchEncoding; 'input_ids' is a 1-element
        # batch tensor, so index [0] selects the token-id sequence itself.
        return len(self.tokenize_model.tokenizer(answer, return_tensors="pt")['input_ids'][0])

    def _compute_batch(
        self,
        pred_answers,
    ) -> List[int]:
        """Evaluate the text length of a batch of answers."""
        return [self._compute_one(answer) for answer in pred_answers]
43 changes: 29 additions & 14 deletions rageval/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,36 @@ def compute(

Return average scores of all inputs and a score list for each example.
"""
self._validate_data(pred_answers, ref_answers, *args)
scores = []
length = len(pred_answers)
if batch_size:
for start in tqdm(range(0, length, batch_size)):
end = start + batch_size
end = end if end < length else length
score = self._compute_batch(
pred_answers[start:end],
ref_answers[start:end],
*[arg[start:end] for arg in args],
)
scores.extend(score)
if ref_answers:
self._validate_data(pred_answers, ref_answers, *args)
scores = []
length = len(pred_answers)
if batch_size:
for start in tqdm(range(0, length, batch_size)):
end = start + batch_size
end = end if end < length else length
score = self._compute_batch(
pred_answers[start:end],
ref_answers[start:end],
*[arg[start:end] for arg in args],
)
scores.extend(score)
else:
scores = self._compute_batch(pred_answers, ref_answers, *args)
else:
scores = self._compute_batch(pred_answers, ref_answers, *args)
scores = []
length = len(pred_answers)
if batch_size:
for start in tqdm(range(0, length, batch_size)):
end = start + batch_size
end = end if end < length else length
score = self._compute_batch(
pred_answers[start:end],
*[arg[start:end] for arg in args],
)
scores.extend(score)
else:
scores = self._compute_batch(pred_answers, *args)

return np.average(scores), scores

Expand Down
Empty file.
Empty file.
Empty file.
Empty file.
1 change: 1 addition & 0 deletions rageval/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .openai import OpenAILLM
from .nli import NLIModel
from .tokenizer import Tokenizer
15 changes: 15 additions & 0 deletions rageval/models/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import logging
from abc import ABC

import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

class Tokenizer(ABC):
    """Thin wrapper around a Hugging Face tokenizer.

    Loads the pretrained tokenizer for the given model name and exposes it
    through the ``tokenizer`` attribute; the model name is kept for reference.
    """

    def __init__(self, model: str = "Qwen/Qwen2-0.5B-Instruct") -> None:
        """Init the Model."""
        self._model_name = model
        self.tokenizer = AutoTokenizer.from_pretrained(model)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ sentencepiece == 0.2.0
protobuf == 4.25.3
sacrebleu == 2.3.3
bert_score == 0.3.13
transformers
38 changes: 38 additions & 0 deletions test_text_length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
from datasets import Dataset

from rageval.metrics import TextLength
import rageval as rl


def sample():
    """Return a minimal test case: two Chinese answers with known token counts."""
    test_case = {
        "answers": [
            "习近平主席在2017年5月12日于北京人民大会堂会见了埃塞俄比亚总理海尔马里亚姆。",
            "埃塞俄比亚希望与中国在以下领域加强合作:\n1. **共建“一带一路”框架下合作**:埃塞俄比亚表示希望能够积极参与“一带一路”倡议,深化与中国在基础设施建设、产能合作、互联互通等领域的合作。\n2. **提高工业化水平和出口创汇能力**:埃塞俄比亚期待中国在推动其工业化进程以及提升出口创汇能力方面提供帮助和合作。\n3. **安全、有序、有效推进经贸合作**:希望与中方在贸易和投资合作领域取得进展,实现稳定、有序和高效的合作。"
        ]
    }
    return test_case


def testset(sample):
    """Wrap the sample dict in a Hugging Face ``Dataset``."""
    return Dataset.from_dict(sample)


def test_case_on_text_length(testset):
    """End-to-end check: average token length over the sample answers is 75."""
    tokenizer = rl.models.Tokenizer("Qwen/Qwen2-0.5B-Instruct")
    metric = TextLength(tokenize_model=tokenizer)
    assert metric.name == "text_length"
    score, results = metric.compute(testset["answers"], batch_size=1)
    print(score, results)
    assert score == 75.0

if __name__ == "__main__":
    # Guard so importing this module (e.g. during pytest collection) does not
    # trigger a model download and a full metric run as a side effect.
    test_case_on_text_length(testset(sample()))