Skip to content

Commit

Permalink
better prompt for GroundTruth feedback function + pydantic v2 valudat…
Browse files Browse the repository at this point in the history
…ion for feedback scores (truera#782)

Co-authored-by: Josh Reini <60949774+joshreini1@users.noreply.github.com>
  • Loading branch information
daniel-huang-1230 and joshreini1 authored Jan 9, 2024
1 parent e36b997 commit c16b700
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
6 changes: 4 additions & 2 deletions trulens_eval/trulens_eval/feedback/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@
%s
The right answer is:
The expected answer is:
%s
Answer only with an integer from 1 to 10 based on how close the responses are to the right answer.
Answer only with an integer from 1 to 10 based on how semantically similar the responses are to the expected answer.
where 0 is no semantic similarity at all and 10 is perfect agreement between the responses and the expected answer.
Never elaborate.
"""

REMOVE_Y_N = " If so, respond Y. If not, respond N."
Expand Down
19 changes: 18 additions & 1 deletion trulens_eval/trulens_eval/utils/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,21 @@

import logging
import re
from pydantic import BaseModel, field_validator, ValidationError


logger = logging.getLogger(__name__)


class Rating(BaseModel):
rating: int

@field_validator('rating')
def check_rating(cls, v):
if not (0 <= v <= 10):
raise ValueError('Rating must be between 0 and 10')
return v

# for extracting the 0-10 rating, we are assuming the score will
# always be the last part of the generated text from LLM - hence we are matching for the last
# group of digits in the string
Expand All @@ -22,4 +34,9 @@ def re_0_10_rating(str_val):
logger.warning(f"0-10 rating regex failed to match on: '{str_val}'")
return -10 # so this will be reported as -1 after division by 10

return int(matches.group())
try:
rating = Rating(rating=int(matches.group()))
return rating.rating
except ValidationError as e:
logger.warning(f"Validation error: {e}")
return -10 # TODO: could consider incorporating re-asking and self-critique here with Instructor https://github.com/jxnl/instructor

0 comments on commit c16b700

Please sign in to comment.