Added functionality for adding metadata using the validate API #72


Merged · 9 commits · Apr 17, 2025
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [1.0.12] - 2025-04-17

- Support adding metadata in the `validate()` method of the `Validator` API.

## [1.0.11] - 2025-04-16

- Update default thresholds for custom evals to 0.0 in `Validator` API.
@@ -59,7 +63,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Initial release of the `cleanlab-codex` client library.

[Unreleased]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.11...HEAD
[Unreleased]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.12...HEAD
[1.0.12]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.11...v1.0.12
[1.0.11]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.10...v1.0.11
[1.0.10]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.9...v1.0.10
[1.0.9]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.8...v1.0.9
2 changes: 1 addition & 1 deletion src/cleanlab_codex/__about__.py
@@ -1,2 +1,2 @@
# SPDX-License-Identifier: MIT
__version__ = "1.0.11"
__version__ = "1.0.12"
57 changes: 57 additions & 0 deletions src/cleanlab_codex/internal/validator.py
@@ -13,6 +13,14 @@
"""Evaluation metrics (excluding trustworthiness) that are used to determine if a response is bad."""
DEFAULT_EVAL_METRICS = ["response_helpfulness"]

# Simple mappings for is_bad keys
_SCORE_TO_IS_BAD_KEY = {
"trustworthiness": "is_not_trustworthy",
"query_ease": "is_not_query_easy",
"response_helpfulness": "is_not_response_helpful",
"context_sufficiency": "is_not_context_sufficient",
}


def get_default_evaluations() -> list[Eval]:
"""Get the default evaluations for the TrustworthyRAG.
@@ -51,3 +59,52 @@ def is_bad(score: Optional[float], threshold: float) -> bool:
"is_bad": is_bad(score_dict["score"], thresholds.get_threshold(eval_name)),
}
return cast(ThresholdedTrustworthyRAGScore, thresholded_scores)


def process_score_metadata(scores: ThresholdedTrustworthyRAGScore, thresholds: BadResponseThresholds) -> dict[str, Any]:
"""Process scores into metadata format with standardized keys.

Args:
scores: The ThresholdedTrustworthyRAGScore containing evaluation results
thresholds: The BadResponseThresholds configuration

Returns:
dict: A dictionary containing evaluation scores and their corresponding metadata
"""
metadata: dict[str, Any] = {}

# Process scores and add to metadata
for metric, score_data in scores.items():
metadata[metric] = score_data["score"]

# Add is_bad flags with standardized naming
is_bad_key = _SCORE_TO_IS_BAD_KEY.get(metric, f"is_not_{metric}")
metadata[is_bad_key] = score_data["is_bad"]

# Special case for trustworthiness explanation
if metric == "trustworthiness" and "log" in score_data and "explanation" in score_data["log"]:
metadata["explanation"] = score_data["log"]["explanation"]

# Add thresholds to metadata
thresholds_dict = thresholds.model_dump()
for metric in {k for k in scores if k not in thresholds_dict}:
thresholds_dict[metric] = thresholds.get_threshold(metric)
metadata["thresholds"] = thresholds_dict

# TODO: Remove this as the backend can infer this from the is_bad flags
metadata["label"] = _get_label(metadata)

return metadata
Contributor: Do you plan to add `label` to the metadata that's passed to `project.query`?

Member: Not long term, no. But I've added it in e0141b3 for now.



def _get_label(metadata: dict[str, Any]) -> str:
def is_bad(metric: str) -> bool:
return bool(metadata.get(_SCORE_TO_IS_BAD_KEY[metric], False))

if is_bad("context_sufficiency"):
return "search_failure"
if is_bad("response_helpfulness") or is_bad("query_ease"):
return "unhelpful"
if is_bad("trustworthiness"):
return "hallucination"
return "other_issues"
32 changes: 27 additions & 5 deletions src/cleanlab_codex/validator.py
@@ -4,6 +4,7 @@

from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Optional, cast

from cleanlab_tlm import TrustworthyRAG
@@ -13,6 +14,9 @@
get_default_evaluations,
get_default_trustworthyrag_config,
)
from cleanlab_codex.internal.validator import (
process_score_metadata as _process_score_metadata,
)
from cleanlab_codex.internal.validator import (
update_scores_based_on_thresholds as _update_scores_based_on_thresholds,
)
@@ -100,11 +104,14 @@ def __init__(

def validate(
self,
*,
query: str,
context: str,
response: str,
prompt: Optional[str] = None,
form_prompt: Optional[Callable[[str, str], str]] = None,
metadata: Optional[dict[str, Any]] = None,
log_results: bool = True,
) -> dict[str, Any]:
"""Evaluate whether the AI-generated response is bad, and if so, request an alternate expert answer.
If no expert answer is available, this query is still logged for SMEs to answer.
@@ -122,10 +129,16 @@ def validate(
- 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
- Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
"""
scores, is_bad_response = self.detect(query, context, response, prompt, form_prompt)
scores, is_bad_response = self.detect(
query=query, context=context, response=response, prompt=prompt, form_prompt=form_prompt
)
expert_answer = None
if is_bad_response:
expert_answer = self._remediate(query)
final_metadata = deepcopy(metadata) if metadata else {}
if log_results:
processed_metadata = _process_score_metadata(scores, self._bad_response_thresholds)
final_metadata.update(processed_metadata)
expert_answer = self._remediate(query=query, metadata=final_metadata)

return {
"expert_answer": expert_answer,
@@ -135,11 +148,14 @@ async def validate_async(

async def validate_async(
self,
*,
query: str,
context: str,
response: str,
prompt: Optional[str] = None,
form_prompt: Optional[Callable[[str, str], str]] = None,
metadata: Optional[dict[str, Any]] = None,
log_results: bool = True,
) -> dict[str, Any]:
"""Evaluate whether the AI-generated response is bad, and if so, request an alternate expert answer.
If no expert answer is available, this query is still logged for SMEs to answer.
@@ -158,9 +174,14 @@ async def validate_async(
- Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
"""
scores, is_bad_response = await self.detect_async(query, context, response, prompt, form_prompt)
final_metadata = metadata.copy() if metadata else {}
if log_results:
processed_metadata = _process_score_metadata(scores, self._bad_response_thresholds)
final_metadata.update(processed_metadata)

expert_answer = None
if is_bad_response:
expert_answer = self._remediate(query)
expert_answer = self._remediate(query=query, metadata=final_metadata)

return {
"expert_answer": expert_answer,
@@ -170,6 +191,7 @@

def detect(
self,
*,
query: str,
context: str,
response: str,
@@ -258,7 +280,7 @@ async def detect_async(
is_bad_response = any(score_dict["is_bad"] for score_dict in thresholded_scores.values())
return thresholded_scores, is_bad_response

def _remediate(self, query: str) -> str | None:
def _remediate(self, *, query: str, metadata: dict[str, Any] | None = None) -> str | None:
"""Request a SME-provided answer for this query, if one is available in Codex.

Args:
@@ -267,7 +289,7 @@ def _remediate(
Returns:
str | None: The SME-provided answer from Codex, or None if no answer could be found in the Codex Project.
"""
codex_answer, _ = self._project.query(question=query)
codex_answer, _ = self._project.query(question=query, metadata=metadata)
return codex_answer


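From the caller's perspective, `validate()` now takes keyword-only arguments plus an optional `metadata` dict and a `log_results` flag; when `log_results=True` and the response is flagged bad, the processed score metadata is merged with the caller's metadata before being passed to `Project.query`. A minimal usage sketch (the access key, inputs, and `conversation_id` field are hypothetical):

```python
from cleanlab_codex.validator import Validator

validator = Validator(codex_access_key="YOUR_ACCESS_KEY")  # hypothetical key

result = validator.validate(
    query="What is the return policy?",
    context="Returns are accepted within 30 days of purchase.",
    response="You can return items within 30 days.",
    metadata={"conversation_id": "abc-123"},  # caller metadata, merged with the score metadata
    log_results=True,  # also attach scores, thresholds, and label to the logged query
)

# result["is_bad_response"]: whether any evaluation fell below its threshold
# result["expert_answer"]: SME-provided answer from Codex if one exists, else None
```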
83 changes: 82 additions & 1 deletion tests/internal/test_validator.py
@@ -2,7 +2,12 @@

from cleanlab_tlm.utils.rag import TrustworthyRAGScore

from cleanlab_codex.internal.validator import get_default_evaluations
from cleanlab_codex.internal.validator import (
get_default_evaluations,
process_score_metadata,
update_scores_based_on_thresholds,
)
from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
from cleanlab_codex.validator import BadResponseThresholds


@@ -27,3 +32,79 @@ def make_is_bad_response_config(trustworthiness: float, response_helpfulness: fl

def test_get_default_evaluations() -> None:
assert {evaluation.name for evaluation in get_default_evaluations()} == {"response_helpfulness"}


def test_process_score_metadata() -> None:
# Create test scores with various metrics
thresholded_scores = {
"trustworthiness": {"score": 0.8, "is_bad": False, "log": {"explanation": "Test explanation"}},
"response_helpfulness": {"score": 0.6, "is_bad": True},
"query_ease": {"score": 0.9, "is_bad": False},
}

thresholds = BadResponseThresholds(trustworthiness=0.7, response_helpfulness=0.7)

metadata = process_score_metadata(cast(ThresholdedTrustworthyRAGScore, thresholded_scores), thresholds)

# Check scores and flags
expected_metadata = {
"trustworthiness": 0.8,
"response_helpfulness": 0.6,
"query_ease": 0.9,
"is_not_trustworthy": False,
"is_not_response_helpful": True,
"is_not_query_easy": False,
"explanation": "Test explanation",
"thresholds": {"trustworthiness": 0.7, "response_helpfulness": 0.7, "query_ease": 0.0},
"label": "unhelpful",
}

assert metadata == expected_metadata


def test_process_score_metadata_edge_cases() -> None:
"""Test edge cases for process_score_metadata."""
thresholds = BadResponseThresholds()

# Test empty scores
metadata_for_empty_scores = process_score_metadata(cast(ThresholdedTrustworthyRAGScore, {}), thresholds)
assert {"thresholds", "label"} == set(metadata_for_empty_scores.keys())

# Test missing explanation
scores = cast(ThresholdedTrustworthyRAGScore, {"trustworthiness": {"score": 0.6, "is_bad": True}})
metadata_missing_explanation = process_score_metadata(scores, thresholds)
assert "explanation" not in metadata_missing_explanation

# Test custom metric
scores = cast(ThresholdedTrustworthyRAGScore, {"my_metric": {"score": 0.3, "is_bad": True}})
metadata_custom_metric = process_score_metadata(scores, thresholds)
assert metadata_custom_metric["my_metric"] == 0.3
assert metadata_custom_metric["is_not_my_metric"] is True


def test_update_scores_based_on_thresholds() -> None:
"""Test that update_scores_based_on_thresholds correctly flags scores based on thresholds."""
raw_scores = cast(
TrustworthyRAGScore,
{
"trustworthiness": {"score": 0.6}, # Below threshold
"response_helpfulness": {"score": 0.8}, # Above threshold
"custom_metric": {"score": 0.4}, # Below custom threshold
"another_metric": {"score": 0.6}, # Uses default threshold
},
)

thresholds = BadResponseThresholds(trustworthiness=0.7, response_helpfulness=0.7, custom_metric=0.45) # type: ignore[call-arg]

scores = update_scores_based_on_thresholds(raw_scores, thresholds)

expected_is_bad = {
"trustworthiness": True,
"response_helpfulness": False,
"custom_metric": True,
"another_metric": False,
}

for metric, expected in expected_is_bad.items():
assert scores[metric]["is_bad"] is expected
assert all(scores[k]["score"] == raw_scores[k]["score"] for k in raw_scores)
4 changes: 2 additions & 2 deletions tests/test_validator.py
@@ -137,10 +137,10 @@ def test_remediate(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None
mock_project.from_access_key.return_value.query.return_value = ("expert answer", None)

validator = Validator(codex_access_key="test")
result = validator._remediate("test query") # noqa: SLF001
result = validator._remediate(query="test query") # noqa: SLF001

# Verify project.query was called
mock_project.from_access_key.return_value.query.assert_called_once_with(question="test query")
mock_project.from_access_key.return_value.query.assert_called_once_with(question="test query", metadata=None)
assert result == "expert answer"

def test_user_provided_thresholds(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None: # noqa: ARG002