Commit f123d3d

feat: add non llm based context recall (#1266)
Parent: ae50d45

2 files changed: +142 −1 lines
New file: +64 lines

# Context Recall

Context Recall measures how many of the relevant documents (or pieces of information) were successfully retrieved. It focuses on not missing important results: higher recall means fewer relevant documents were left out.

In short, recall is about not missing anything important. Since it is about what must not be missed, calculating context recall always requires a reference to compare against.
## LLM-Based Context Recall

Computed using `user_input`, `reference`, and `retrieved_contexts`; the values range between 0 and 1, with higher values indicating better performance. This metric uses `reference` as a proxy for `reference_contexts`, which also makes it easier to use, since annotating reference contexts can be very time-consuming. To estimate context recall from the `reference`, the reference is broken down into claims, and each claim is analyzed to determine whether it can be attributed to the retrieved context. In an ideal scenario, all claims in the reference answer should be attributable to the retrieved context.
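For instance, a reference such as "The Eiffel Tower is in Paris and was completed in 1889." would be split into two claims, one about the location and one about the completion date, and each claim would be checked against the retrieved passages.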
The formula for calculating context recall is as follows:
```{math}
\text{context recall} = \frac{|\text{GT claims that can be attributed to context}|}{|\text{Number of claims in GT}|}
```
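For instance, if the reference answer decomposes into four claims and three of them can be attributed to the retrieved context, context recall is 3/4 = 0.75.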
## Example
```{code-block} python
from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import LLMContextRecall

sample = SingleTurnSample(
    user_input="Where is the Eiffel Tower located?",
    response="The Eiffel Tower is located in Paris.",
    reference="The Eiffel Tower is located in Paris.",
    retrieved_contexts=["Paris is the capital of France."],
)

context_recall = LLMContextRecall()
await context_recall.single_turn_ascore(sample)
```
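Note that `LLMContextRecall`, like other LLM-based metrics, needs an evaluator LLM to judge whether each claim is attributable. A minimal sketch, assuming a LangChain chat model wrapped with ragas' `LangchainLLMWrapper` (the model name is only an example):

```{code-block} python
from langchain_openai import ChatOpenAI
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import LLMContextRecall

# Any LangChain-compatible chat model should work here; gpt-4o-mini is illustrative.
evaluator_llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"))
context_recall = LLMContextRecall(llm=evaluator_llm)
```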
## Non-LLM-Based Context Recall

Computed using `retrieved_contexts` and `reference_contexts`; the values range between 0 and 1, with higher values indicating better performance. This metric uses non-LLM string comparison to decide whether each reference context was successfully retrieved. You can plug in any non-LLM-based metric as the distance measure for judging whether a retrieved context matches a reference context.
The formula for calculating context recall is as follows:
```{math}
\text{context recall} = \frac{|\text{Number of relevant contexts retrieved}|}{|\text{Total number of reference contexts}|}
```
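For example, if two out of three reference contexts have a close enough match among the retrieved contexts, context recall is 2/3 ≈ 0.67.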
## Example
```{code-block} python
from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import NonLLMContextRecall

sample = SingleTurnSample(
    retrieved_contexts=["Paris is the capital of France."],
    reference_contexts=[
        "Paris is the capital of France.",
        "The Eiffel Tower is one of the most famous landmarks in Paris.",
    ],
)

context_recall = NonLLMContextRecall()
await context_recall.single_turn_ascore(sample)
```
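A retrieved context counts as a match for a reference context when the best string-similarity score exceeds the metric's `threshold` (0.5 by default, per the implementation below); the `distance_measure` field likewise accepts any non-LLM `SingleTurnMetric`. Tightening the match criterion is a one-liner:

```{code-block} python
# Require a stronger string match before a reference context counts as retrieved.
strict_recall = NonLLMContextRecall(threshold=0.75)
```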

src/ragas/metrics/_context_recall.py

+78 −1

```diff
@@ -10,7 +10,10 @@
 from ragas.dataset_schema import SingleTurnSample
 from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
 from ragas.llms.prompt import Prompt
+from ragas.metrics._string import NonLLMStringSimilarity
 from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric, ensembler
+from ragas.run_config import RunConfig
+from ragas.utils import deprecated

 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
@@ -109,7 +112,7 @@ def dicts(self) -> t.List[t.Dict]:


 @dataclass
-class ContextRecall(MetricWithLLM, SingleTurnMetric):
+class LLMContextRecall(MetricWithLLM, SingleTurnMetric):
     """
     Estimates context recall by estimating TP and FN using annotated answer and
     retrieved context.
@@ -213,4 +216,78 @@ def save(self, cache_dir: str | None = None) -> None:
         self.context_recall_prompt.save(cache_dir)


+@dataclass
+class ContextRecall(LLMContextRecall):
+    name: str = "context_recall"
+
+    @deprecated(since="0.2", removal="0.3", alternative="LLMContextRecall")
+    async def _single_turn_ascore(
+        self, sample: SingleTurnSample, callbacks: Callbacks
+    ) -> float:
+        row = sample.dict()
+        return await self._ascore(row, callbacks)
+
+    @deprecated(since="0.2", removal="0.3", alternative="LLMContextRecall")
+    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
+        return await super()._ascore(row, callbacks)
+
+
+@dataclass
+class NonLLMContextRecall(SingleTurnMetric):
+    name: str = "non_llm_context_recall"  # type: ignore
+    _required_columns: t.Dict[MetricType, t.Set[str]] = field(
+        default_factory=lambda: {
+            MetricType.SINGLE_TURN: {
+                "retrieved_contexts",
+                "reference_contexts",
+            }
+        }
+    )
+    distance_measure: SingleTurnMetric = field(
+        default_factory=lambda: NonLLMStringSimilarity()
+    )
+    threshold: float = 0.5
+
+    def __post_init__(self):
+        if isinstance(self.distance_measure, MetricWithLLM):
+            raise ValueError(
+                "distance_measure must not be an instance of MetricWithLLM for NonLLMContextRecall"
+            )
+
+    def init(self, run_config: RunConfig) -> None:
+        ...
+
+    async def _single_turn_ascore(
+        self, sample: SingleTurnSample, callbacks: Callbacks
+    ) -> float:
+        retrieved_contexts = sample.retrieved_contexts
+        reference_contexts = sample.reference_contexts
+        assert retrieved_contexts is not None, "retrieved_contexts is empty"
+        assert reference_contexts is not None, "reference_contexts is empty"
+
+        scores = []
+        for ref in reference_contexts:
+            scores.append(
+                max(
+                    [
+                        await self.distance_measure.single_turn_ascore(
+                            SingleTurnSample(reference=rc, response=ref), callbacks
+                        )
+                        for rc in retrieved_contexts
+                    ]
+                )
+            )
+        return self._compute_score(scores)
+
+    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
+        return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)
+
+    def _compute_score(self, verdict_list: t.List[float]) -> float:
+        response = [1 if score > self.threshold else 0 for score in verdict_list]
+        denom = len(response)
+        numerator = sum(response)
+        score = numerator / denom if denom > 0 else np.nan
+        return score
+
+
 context_recall = ContextRecall()
```
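The scoring logic above is compact: `_single_turn_ascore` takes each reference context's best similarity against all retrieved contexts, and `_compute_score` turns each best score into a 0/1 verdict against `threshold` and averages the verdicts. A standalone sketch of the same arithmetic, using hypothetical best-match similarities:

```python
# Sketch of NonLLMContextRecall scoring with the default threshold of 0.5.
best_match_scores = [0.92, 0.31]  # one best-match similarity per reference context
verdicts = [1 if s > 0.5 else 0 for s in best_match_scores]
recall = sum(verdicts) / len(verdicts)  # -> 0.5
print(recall)
```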
