From c16b7007c641a5c0a029eae1fdcd538e85bf38cf Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 9 Jan 2024 14:37:10 -0800 Subject: [PATCH] better prompt for GroundTruth feedback function + pydantic v2 valudation for feedback scores (#782) Co-authored-by: Josh Reini <60949774+joshreini1@users.noreply.github.com> --- trulens_eval/trulens_eval/feedback/prompts.py | 6 ++++-- trulens_eval/trulens_eval/utils/generated.py | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/trulens_eval/trulens_eval/feedback/prompts.py b/trulens_eval/trulens_eval/feedback/prompts.py index e3287e976..6da021d5c 100644 --- a/trulens_eval/trulens_eval/feedback/prompts.py +++ b/trulens_eval/trulens_eval/feedback/prompts.py @@ -73,11 +73,13 @@ %s -The right answer is: +The expected answer is: %s -Answer only with an integer from 1 to 10 based on how close the responses are to the right answer. +Answer only with an integer from 1 to 10 based on how semantically similar the responses are to the expected answer. +where 0 is no semantic similarity at all and 10 is perfect agreement between the responses and the expected answer. +Never elaborate. """ REMOVE_Y_N = " If so, respond Y. If not, respond N." diff --git a/trulens_eval/trulens_eval/utils/generated.py b/trulens_eval/trulens_eval/utils/generated.py index f61138d6c..dffe7239c 100644 --- a/trulens_eval/trulens_eval/utils/generated.py +++ b/trulens_eval/trulens_eval/utils/generated.py @@ -4,9 +4,21 @@ import logging import re +from pydantic import BaseModel, field_validator, ValidationError + logger = logging.getLogger(__name__) + +class Rating(BaseModel): + rating: int + + @field_validator('rating') + def check_rating(cls, v): + if not (0 <= v <= 10): + raise ValueError('Rating must be between 0 and 10') + return v + # for extracting the 0-10 rating, we are assuming the score will # always be the last part of the generated text from LLM - hence we are matching for the last # group of digits in the string @@ -22,4 +34,9 @@ def re_0_10_rating(str_val): logger.warning(f"0-10 rating regex failed to match on: '{str_val}'") return -10 # so this will be reported as -1 after division by 10 - return int(matches.group()) \ No newline at end of file + try: + rating = Rating(rating=int(matches.group())) + return rating.rating + except ValidationError as e: + logger.warning(f"Validation error: {e}") + return -10 # TODO: could consider incorporating re-asking and self-critique here with Instructor https://github.com/jxnl/instructor \ No newline at end of file