
feat: add non llm based context recall #1266

Merged
merged 7 commits into from Sep 12, 2024
64 changes: 64 additions & 0 deletions docs/concepts/metrics/context_recall_v2.md
@@ -0,0 +1,64 @@
# Context Recall

Context Recall measures how many of the relevant documents (or pieces of information) were successfully retrieved. It focuses on not missing important results. Higher recall means fewer relevant documents were left out.
In short, recall is about not missing anything important, which is why calculating context recall always requires a reference to compare against.



## LLM Based Context Recall

Computed using `user_input`, `reference` and the `retrieved_contexts`, with values ranging between 0 and 1; higher values indicate better performance. This metric uses `reference` as a proxy for `reference_contexts`, which also makes it easier to use, since annotating reference contexts can be very time-consuming. To estimate context recall from the `reference`, the reference is broken down into claims, and each claim in the `reference` answer is analyzed to determine whether it can be attributed to the retrieved context or not. In an ideal scenario, all claims in the reference answer should be attributable to the retrieved context.


The formula for calculating context recall is as follows:

```{math}
\text{context recall} = {|\text{GT claims that can be attributed to context}| \over |\text{Number of claims in GT}|}
```
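
For example (an illustrative case, not taken from the library), if the `reference` breaks down into four claims and three of them can be attributed to the retrieved context, the score would be:

```{math}
\text{context recall} = \frac{3}{4} = 0.75
```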

## Example

```{code-block} python
from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import LLMContextRecall

sample = SingleTurnSample(
user_input="Where is the Eiffel Tower located?",
response="The Eiffel Tower is located in Paris.",
reference="The Eiffel Tower is located in Paris.",
retrieved_contexts=["Paris is the capital of France."],
)

context_recall = LLMContextRecall()
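# NOTE: the metric above relies on an evaluator LLM to judge each claim; in a real run
# you would typically pass one explicitly, e.g. LLMContextRecall(llm=evaluator_llm),
# where `evaluator_llm` is a ragas LLM wrapper you have configured (illustrative name).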
await context_recall.single_turn_ascore(sample)

```

## Non LLM Based Context Recall

Computed using `retrieved_contexts` and `reference_contexts`, with values ranging between 0 and 1; higher values indicate better performance. This metric uses non-LLM string comparison to identify whether a retrieved context is relevant or not, and you can use any non-LLM based metric as the distance measure.

The formula for calculating context recall is as follows:

```{math}
\text{context recall} = {|\text{Number of relevant contexts retrieved}| \over |\text{Total number of reference contexts}|}
```
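
As an illustration (numbers chosen only for the example), if three reference contexts are annotated and two of them are matched by at least one retrieved context, the score would be:

```{math}
\text{context recall} = \frac{2}{3} \approx 0.67
```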

## Example

```{code-block} python
from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import NonLLMContextRecall

sample = SingleTurnSample(
retrieved_contexts=["Paris is the capital of France."],
reference_contexts=["Paris is the capital of France.", "The Eiffel Tower is one of the most famous landmarks in Paris."]
)

context_recall = NonLLMContextRecall()
await context_recall.single_turn_ascore(sample)


```
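
The matching behaviour can be tuned. The snippet below is a sketch based on the two fields this PR adds to `NonLLMContextRecall` (`distance_measure` and `threshold`); the specific values are illustrative, not recommendations.

```{code-block} python
from ragas.metrics import NonLLMContextRecall
from ragas.metrics._string import NonLLMStringSimilarity

# A reference context counts as retrieved if its best string similarity against
# the retrieved contexts exceeds the threshold.
context_recall = NonLLMContextRecall(
    distance_measure=NonLLMStringSimilarity(),
    threshold=0.7,
)
```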
79 changes: 78 additions & 1 deletion src/ragas/metrics/_context_recall.py
@@ -10,7 +10,10 @@
from ragas.dataset_schema import SingleTurnSample
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
from ragas.llms.prompt import Prompt
from ragas.metrics._string import NonLLMStringSimilarity
from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric, ensembler
from ragas.run_config import RunConfig
from ragas.utils import deprecated

if t.TYPE_CHECKING:
from langchain_core.callbacks import Callbacks
@@ -109,7 +112,7 @@ def dicts(self) -> t.List[t.Dict]:


@dataclass
class ContextRecall(MetricWithLLM, SingleTurnMetric):
class LLMContextRecall(MetricWithLLM, SingleTurnMetric):
"""
Estimates context recall by estimating TP and FN using annotated answer and
retrieved context.
@@ -213,4 +216,78 @@ def save(self, cache_dir: str | None = None) -> None:
self.context_recall_prompt.save(cache_dir)


@dataclass
class ContextRecall(LLMContextRecall):
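# Deprecated alias kept for backwards compatibility; use LLMContextRecall instead.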
name: str = "context_recall"

@deprecated(since="0.2", removal="0.3", alternative="LLMContextRecall")
async def _single_turn_ascore(
self, sample: SingleTurnSample, callbacks: Callbacks
) -> float:
row = sample.dict()
return await self._ascore(row, callbacks)

@deprecated(since="0.2", removal="0.3", alternative="LLMContextRecall")
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return await super()._ascore(row, callbacks)


@dataclass
class NonLLMContextRecall(SingleTurnMetric):
name: str = "non_llm_context_recall" # type: ignore
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
default_factory=lambda: {
MetricType.SINGLE_TURN: {
"retrieved_contexts",
"reference_contexts",
}
}
)
distance_measure: SingleTurnMetric = field(
default_factory=lambda: NonLLMStringSimilarity()
)
threshold: float = 0.5

def __post_init__(self):
if isinstance(self.distance_measure, MetricWithLLM):
raise ValueError(
"distance_measure must not be an instance of MetricWithLLM for NonLLMContextPrecisionWithReference"
)

def init(self, run_config: RunConfig) -> None:
...

async def _single_turn_ascore(
self, sample: SingleTurnSample, callbacks: Callbacks
) -> float:
retrieved_contexts = sample.retrieved_contexts
reference_contexts = sample.reference_contexts
assert retrieved_contexts is not None, "retrieved_contexts is empty"
assert reference_contexts is not None, "reference_contexts is empty"

scores = []
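# For each reference context, keep the best (max) similarity obtained against any
# retrieved context, as judged by the non-LLM distance_measure.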
for ref in reference_contexts:
scores.append(
max(
[
await self.distance_measure.single_turn_ascore(
SingleTurnSample(reference=rc, response=ref), callbacks
)
for rc in retrieved_contexts
]
)
)
return self._compute_score(scores)

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)

def _compute_score(self, verdict_list: t.List[float]) -> float:
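# Binarize each best-match similarity with the threshold, then return the fraction
# of reference contexts covered by at least one retrieved context.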
response = [1 if score > self.threshold else 0 for score in verdict_list]
denom = len(response)
numerator = sum(response)
score = numerator / denom if denom > 0 else np.nan
return score


context_recall = ContextRecall()