# _simple_criteria.py (forked from explodinggradients/ragas)
from __future__ import annotations

import logging
import typing as t
from collections import Counter

from pydantic import BaseModel, Field

from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
from ragas.metrics.base import (
    MetricType,
    MetricWithLLM,
    MultiTurnMetric,
    SingleTurnMetric,
)
from ragas.prompt import PydanticPrompt

if t.TYPE_CHECKING:
    from langchain_core.callbacks.base import Callbacks

    from ragas.llms import BaseRagasLLM

logger = logging.getLogger(__name__)

class SimpleCriteriaOutput(BaseModel):
    reason: str = Field(description="Reason for the scoring")
    score: int = Field(description="The score for the submission")


class SingleTurnSimpleCriteriaInput(BaseModel):
    user_input: t.Optional[str] = Field(
        description="The input to the llm system", default=None
    )
    response: t.Optional[str] = Field(
        description="The response from the llm system", default=None
    )
    retrieved_contexts: t.Optional[t.List[str]] = Field(
        description="The retrieved contexts from the llm system", default=None
    )
    reference_contexts: t.Optional[t.List[str]] = Field(
        description="The reference contexts for the evaluation", default=None
    )
    reference: t.Optional[str] = Field(
        description="The reference answer for evaluation", default=None
    )


class MultiTurnSimpleCriteriaInput(BaseModel):
    user_input: t.Optional[str] = Field(
        description="The input to the model", default=None
    )
    reference: t.Optional[str] = Field(
        description="The reference response", default=None
    )


class SingleTurnSimpleCriteriaPrompt(
    PydanticPrompt[SingleTurnSimpleCriteriaInput, SimpleCriteriaOutput]
):
    instruction = ""  # this will be set in the constructor
    input_model = SingleTurnSimpleCriteriaInput
    output_model = SimpleCriteriaOutput


class MultiTurnSimpleCriteriaPrompt(
    PydanticPrompt[MultiTurnSimpleCriteriaInput, SimpleCriteriaOutput]
):
    instruction = ""  # this will be set in the constructor
    input_model = MultiTurnSimpleCriteriaInput
    output_model = SimpleCriteriaOutput

class SimpleCriteriaScore(MetricWithLLM, SingleTurnMetric, MultiTurnMetric):
    """
    Judges the submission against the criteria specified in the metric
    definition and returns an integer score.

    Attributes
    ----------
    name: str
        name of the metric
    definition: str
        criteria used to score the submission
    strictness: int
        The number of times self-consistency checks are made. The final
        judgement is made using a majority vote.
    """

    def __init__(
        self,
        name: str,
        definition: str,
        llm: t.Optional[BaseRagasLLM] = None,
        required_columns: t.Optional[t.Dict[MetricType, t.Set[str]]] = None,
        single_turn_prompt: t.Optional[PydanticPrompt] = None,
        multi_turn_prompt: t.Optional[PydanticPrompt] = None,
        strictness: int = 1,
    ):
        if required_columns is None:
            required_columns = {
                MetricType.SINGLE_TURN: {
                    "user_input:optional",
                    "response:optional",
                    "retrieved_contexts:optional",
                    "reference:optional",
                    "reference_contexts:optional",
                },
                MetricType.MULTI_TURN: {
                    "user_input:optional",
                    "reference:optional",
                },
            }
        super().__init__(
            name=name,
            llm=llm,
            _required_columns=required_columns,
        )

        self._definition = definition
        self.single_turn_prompt = single_turn_prompt or SingleTurnSimpleCriteriaPrompt()
        self.multi_turn_prompt = multi_turn_prompt or MultiTurnSimpleCriteriaPrompt()

        # Update the instruction for the prompts with the definition
        instruction = f"Evaluate the Input based on the criteria defined. Give a score between 0 and 5.\nCriteria Definition: {self._definition}"
        self.single_turn_prompt.instruction = instruction
        self.multi_turn_prompt.instruction = instruction

        # Ensure an odd number of checks to avoid ties in the majority vote
        self.strictness = strictness
        self.strictness = (
            self.strictness if self.strictness % 2 != 0 else self.strictness + 1
        )
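        # For example, strictness=1 keeps a single judgement, while strictness=2
        # and strictness=3 both result in 3 judgements combined by majority vote.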
    def __repr__(self) -> str:
        return f"{self.name}(required_columns={self.required_columns}, llm={self.llm}, definition={self._definition})"

    @property
    def definition(self) -> str:
        return self._definition

    @definition.setter
    def definition(self, value: str) -> None:
        self._definition = value
        # Update the instruction for both prompts with the new definition
        instruction = f"Evaluate the Input based on the criteria defined. Give a score between 0 and 5.\nCriteria Definition: {self._definition}"
        self.single_turn_prompt.instruction = instruction
        self.multi_turn_prompt.instruction = instruction
    def _compute_score(
        self, safe_loaded_responses: t.List[SimpleCriteriaOutput]
    ) -> float:
        if self.strictness > 1:
            score = Counter([item.score for item in safe_loaded_responses]).most_common(
                1
            )[0][0]
        else:
            score = safe_loaded_responses[0].score

        return score
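    # A sketch of the majority vote above: with strictness=3 and judge outputs
    # scoring [4, 4, 2], Counter([4, 4, 2]).most_common(1) returns [(4, 2)],
    # so _compute_score returns 4.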
    async def _single_turn_ascore(
        self, sample: SingleTurnSample, callbacks: Callbacks
    ) -> float:
        row = sample.to_dict()
        return await self._ascore(row, callbacks)

    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
        assert self.llm is not None, "set LLM before use"

        user_input, context, response = (
            row["user_input"],
            row.get("retrieved_contexts"),
            row["response"],
        )

        # Fold any retrieved contexts into the user input so the judge sees them
        if context is not None:
            if isinstance(context, list):
                context = "\n".join(context)
            user_input = f"Question: {user_input} Answer using context: {context}"

        prompt_input = SingleTurnSimpleCriteriaInput(
            user_input=user_input,
            response=response,
        )

        response = await self.single_turn_prompt.generate(
            data=prompt_input,
            llm=self.llm,
            callbacks=callbacks,
        )

        return self._compute_score([response])
    async def _multi_turn_ascore(
        self, sample: MultiTurnSample, callbacks: Callbacks
    ) -> float:
        assert self.llm is not None, "LLM is not set"
        assert sample.reference is not None, "Reference is not set"

        interaction = sample.pretty_repr()
        prompt_input = MultiTurnSimpleCriteriaInput(
            user_input=interaction,
        )

        response = await self.multi_turn_prompt.generate(
            data=prompt_input,
            llm=self.llm,
            callbacks=callbacks,
        )

        return self._compute_score([response])
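

# Usage sketch (not part of the module): one way to apply SimpleCriteriaScore to
# a single-turn sample. `evaluator_llm` stands in for any BaseRagasLLM-compatible
# judge you have already constructed (for example via a ragas LLM wrapper); the
# criterion text and sample values below are illustrative only.
#
#   import asyncio
#
#   scorer = SimpleCriteriaScore(
#       name="clarity",
#       definition="Score 0 to 5 for how clearly the response answers the question.",
#       llm=evaluator_llm,  # assumed: an already-initialised BaseRagasLLM
#       strictness=3,       # odd, so the majority vote cannot tie
#   )
#   sample = SingleTurnSample(
#       user_input="Where is the Eiffel Tower located?",
#       response="The Eiffel Tower is located in Paris.",
#   )
#   score = asyncio.run(scorer.single_turn_ascore(sample))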