|
10 | 10 | from ragas.dataset_schema import SingleTurnSample
|
11 | 11 | from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
|
12 | 12 | from ragas.llms.prompt import Prompt
|
| 13 | +from ragas.metrics._string import NonLLMStringSimilarity |
13 | 14 | from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric, ensembler
|
| 15 | +from ragas.run_config import RunConfig |
| 16 | +from ragas.utils import deprecated |
14 | 17 |
|
15 | 18 | if t.TYPE_CHECKING:
|
16 | 19 | from langchain_core.callbacks import Callbacks
|
@@ -109,7 +112,7 @@ def dicts(self) -> t.List[t.Dict]:
|
109 | 112 |
|
110 | 113 |
|
111 | 114 | @dataclass
|
112 |
| -class ContextRecall(MetricWithLLM, SingleTurnMetric): |
| 115 | +class LLMContextRecall(MetricWithLLM, SingleTurnMetric): |
113 | 116 | """
|
114 | 117 | Estimates context recall by estimating TP and FN using annotated answer and
|
115 | 118 | retrieved context.
|
@@ -213,4 +216,78 @@ def save(self, cache_dir: str | None = None) -> None:
|
213 | 216 | self.context_recall_prompt.save(cache_dir)
|
214 | 217 |
|
215 | 218 |
|
| 219 | +@dataclass |
| 220 | +class ContextRecall(LLMContextRecall): |
| 221 | + name: str = "context_recall" |
| 222 | + |
| 223 | + @deprecated(since="0.2", removal="0.3", alternative="LLMContextRecall") |
| 224 | + async def _single_turn_ascore( |
| 225 | + self, sample: SingleTurnSample, callbacks: Callbacks |
| 226 | + ) -> float: |
| 227 | + row = sample.dict() |
| 228 | + return await self._ascore(row, callbacks) |
| 229 | + |
| 230 | + @deprecated(since="0.2", removal="0.3", alternative="LLMContextRecall") |
| 231 | + async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: |
| 232 | + return await super()._ascore(row, callbacks) |
| 233 | + |
| 234 | + |
| 235 | +@dataclass |
| 236 | +class NonLLMContextRecall(SingleTurnMetric): |
| 237 | + name: str = "non_llm_context_recall" # type: ignore |
| 238 | + _required_columns: t.Dict[MetricType, t.Set[str]] = field( |
| 239 | + default_factory=lambda: { |
| 240 | + MetricType.SINGLE_TURN: { |
| 241 | + "retrieved_contexts", |
| 242 | + "reference_contexts", |
| 243 | + } |
| 244 | + } |
| 245 | + ) |
| 246 | + distance_measure: SingleTurnMetric = field( |
| 247 | + default_factory=lambda: NonLLMStringSimilarity() |
| 248 | + ) |
| 249 | + threshold: float = 0.5 |
| 250 | + |
| 251 | + def __post_init__(self): |
| 252 | + if isinstance(self.distance_measure, MetricWithLLM): |
| 253 | + raise ValueError( |
| 254 | + "distance_measure must not be an instance of MetricWithLLM for NonLLMContextPrecisionWithReference" |
| 255 | + ) |
| 256 | + |
| 257 | + def init(self, run_config: RunConfig) -> None: |
| 258 | + ... |
| 259 | + |
| 260 | + async def _single_turn_ascore( |
| 261 | + self, sample: SingleTurnSample, callbacks: Callbacks |
| 262 | + ) -> float: |
| 263 | + retrieved_contexts = sample.retrieved_contexts |
| 264 | + reference_contexts = sample.reference_contexts |
| 265 | + assert retrieved_contexts is not None, "retrieved_contexts is empty" |
| 266 | + assert reference_contexts is not None, "reference_contexts is empty" |
| 267 | + |
| 268 | + scores = [] |
| 269 | + for ref in reference_contexts: |
| 270 | + scores.append( |
| 271 | + max( |
| 272 | + [ |
| 273 | + await self.distance_measure.single_turn_ascore( |
| 274 | + SingleTurnSample(reference=rc, response=ref), callbacks |
| 275 | + ) |
| 276 | + for rc in retrieved_contexts |
| 277 | + ] |
| 278 | + ) |
| 279 | + ) |
| 280 | + return self._compute_score(scores) |
| 281 | + |
| 282 | + async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: |
| 283 | + return await self._single_turn_ascore(SingleTurnSample(**row), callbacks) |
| 284 | + |
| 285 | + def _compute_score(self, verdict_list: t.List[float]) -> float: |
| 286 | + response = [1 if score > self.threshold else 0 for score in verdict_list] |
| 287 | + denom = len(response) |
| 288 | + numerator = sum(response) |
| 289 | + score = numerator / denom if denom > 0 else np.nan |
| 290 | + return score |
| 291 | + |
| 292 | + |
216 | 293 | context_recall = ContextRecall()
|
0 commit comments