# _context_reject_rate.py

from dataclasses import dataclass
from typing import Callable, List

import datasets
from datasets import Dataset
from langchain.schema import LLMResult

from rageval.metrics import MetricWithLLM, add_attribute
from rageval.utils.prompt import REJECT_RATE_PROMPT

_DESCRIPTION = """\
ContextRejectRate is the metric to measure the unknown robustness of LLM based on the given context.
For details, see the paper: https://arxiv.org/abs/2311.09210.
"""
_KWARGS_DESCRIPTION = """\
Args:
name : str
batch_size : int, Batch size for openai completion.
model : Callable, The LLM model to use.
Optional Args:
None
Functions:
parse_llm_result: parse the results of LLM
_compute_batch: compute the score by measure how many rejected answers in all answers.
Examples:
    >>> from datasets import Dataset
    >>> from langchain.llms.fake import FakeListLLM
    >>> import rageval as rl
    >>> sample = {
    ...     "questions": [
    ...         "Why did Bushnell set himself on fire?",
    ...         "Did Bushnell have a wife?"
    ...     ],
    ...     "contexts": [
    ...         [
    ...             ["An active-duty member of the U.S. Air Force has died after he set himself ablaze outside the "
    ...              "Israeli Embassy in Washington, D.C., while declaring that he “will no longer be complicit in "
    ...              "genocide.”"],
    ...             ["The 25-year-old airman, Aaron Bushnell, of San Antonio, Texas, died from his injuries, the "
    ...              "Metropolitan Police Department said Monday."],
    ...             ["Bushnell had walked up to the embassy shortly before 1 p.m. Sunday and began livestreaming on "
    ...              "the video streaming platform Twitch, a person familiar with the matter told The Associated "
    ...              "Press. Law enforcement officials believe he set his phone down and then doused himself in "
    ...              "accelerant and ignited the flames. At one point, he said he “will no longer be complicit in "
    ...              "genocide,” the person said. The video was later removed from the platform, but law enforcement "
    ...              "officials have obtained and reviewed a copy."]
    ...         ],
    ...         [
    ...             ["An active-duty member of the U.S. Air Force has died after he set himself ablaze outside the "
    ...              "Israeli Embassy in Washington, D.C., while declaring that he “will no longer be complicit in "
    ...              "genocide.”"],
    ...             ["The 25-year-old airman, Aaron Bushnell, of San Antonio, Texas, died from his injuries, the "
    ...              "Metropolitan Police Department said Monday."],
    ...             ["Bushnell had walked up to the embassy shortly before 1 p.m. Sunday and began livestreaming on "
    ...              "the video streaming platform Twitch, a person familiar with the matter told The Associated "
    ...              "Press. Law enforcement officials believe he set his phone down and then doused himself in "
    ...              "accelerant and ignited the flames. At one point, he said he “will no longer be complicit in "
    ...              "genocide,” the person said. The video was later removed from the platform, but law enforcement "
    ...              "officials have obtained and reviewed a copy."]
    ...         ],
    ...     ]
    ... }
    >>> dataset = Dataset.from_dict(sample)
    >>> model = FakeListLLM(
    ...     responses=[
    ...         "Answer: An active-duty member of the U.S. Air Force has died after he set himself ablaze outside the "
    ...         "Israeli Embassy in Washington, D.C., while declaring that he “will no longer be complicit in "
    ...         "genocide.”",
    ...         "Answer: sorry, cannot answer the question"
    ...     ]
    ... )
    >>> metric = rl.metrics.ContextRejectRate(model)
    >>> metric.mtype
    'AnswerGroundedness'
    >>> s, ds = metric.compute(dataset, batch_size=1)
    >>> assert 0 <= s <= 1
    >>> type(ds)
    <class 'datasets.arrow_dataset.Dataset'>
"""

_CITATION = """\
@misc{yu2023chainofnote,
title={Chain-of-Note: Enhancing Robustness in Retrieval-Augmented Language Models},
author={Wenhao Yu and Hongming Zhang and Xiaoman Pan and Kaixin Ma and Hongwei Wang and Dong Yu},
year={2023},
eprint={2311.09210},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
@dataclass
@add_attribute('mtype', 'AnswerGroundedness')
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ContextRejectRate(MetricWithLLM):
"""Estimates context reject rate by measuring how many rejected answers in all answers."""
name = "context_reject_rate"
ALIAS = ['context_reject_rate']
def __init__(self, model: Callable):
"""Explicitly initialize the ContextRejectRate to ensure all parent class initialized."""
super().__init__(model)
self._required_columns = ['questions', 'contexts']
def __repr__(self) -> str:
""":return: Formatted string representation of the metric."""
return f"{self.ALIAS[0]}"
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            inputs_description=_KWARGS_DESCRIPTION,
            citation=_CITATION,
            homepage="",
            features=datasets.Features(
                {
                    "questions": datasets.Value("string"),
                    "contexts": datasets.Sequence(datasets.Value("string"))
                }
            ),
            codebase_urls=[],
            reference_urls=["https://arxiv.org/abs/2311.09210"]
        )

    def parse_llm_result(self, prompts: List[str], result: LLMResult):
        """Parse the LLM results based on whether each answer contains the rejection phrase specified by the prompt."""
        responses = [[i.text for i in r] for r in result.generations]
        scores = []
        # For each question-context pair, count a response containing the rejection
        # phrase as a reject (1.0) and any other response as an answer (0.0).
        for response in responses:
            answer = response[0]
            if "sorry, cannot answer the question" in answer:
                scores.append(1.)
            else:
                scores.append(0.)
        return scores

    def _compute_batch(
        self,
        dataset: Dataset,
    ) -> list:
        """Compute the per-sample scores by checking how many of the generated answers are rejections."""
        questions, contexts = (
            dataset["questions"],
            dataset["contexts"],
        )
        # Build one rejection-probe prompt per question-context pair.
        prompts = []
        for question, context in zip(questions, contexts):
            prompt = REJECT_RATE_PROMPT.format(
                question=question, evidence=context
            )
            prompts.append(prompt)
        results = self.llm.generate(prompts)
        scores = self.parse_llm_result(prompts, results)
        return scores
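

# A minimal end-to-end usage sketch, mirroring the doctest above. It assumes this
# module is installed as part of `rageval` (exposing the metric as
# rl.metrics.ContextRejectRate) and that langchain's FakeListLLM is available;
# the sample data and fake responses below are illustrative, not part of the library.
if __name__ == "__main__":
    from langchain.llms.fake import FakeListLLM

    import rageval as rl

    demo = Dataset.from_dict({
        "questions": [
            "Why did Bushnell set himself on fire?",
            "Did Bushnell have a wife?"
        ],
        "contexts": [
            [["An active-duty member of the U.S. Air Force has died after he set himself ablaze."]],
            [["An active-duty member of the U.S. Air Force has died after he set himself ablaze."]],
        ],
    })
    # One fake response answers the question and the other contains the rejection
    # phrase, so the reject rate should equal the fraction of rejections (1 of 2).
    fake_llm = FakeListLLM(responses=[
        "Answer: An active-duty member of the U.S. Air Force set himself ablaze.",
        "Answer: sorry, cannot answer the question",
    ])
    metric = rl.metrics.ContextRejectRate(fake_llm)
    score, detail = metric.compute(demo, batch_size=1)
    print(f"{metric.name}: {score}")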