Skip to content

Commit

Permalink
chore: decouple tests into more atomic units
Browse files Browse the repository at this point in the history
Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
  • Loading branch information
RobotSail committed Jan 8, 2025
1 parent c6b5a70 commit ab3d168
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 130 deletions.
112 changes: 78 additions & 34 deletions src/instructlab/eval/ragas.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# # SPDX-License-Identifier: Apache-2.0
# Standard
from pathlib import Path
from typing import List, Optional, TypedDict
from typing import TYPE_CHECKING, List, Optional, TypedDict

# Third Party
from langchain_community.chat_models import ChatOpenAI
from openai import Client as OpenAIClient
from openai.types.chat import ChatCompletionMessageParam
from pandas import DataFrame, read_json
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, Field
from ragas.evaluation import EvaluationDataset, EvaluationResult, RunConfig, evaluate
from ragas.metrics import Metric
from ragas.metrics._domain_specific_rubrics import ( # the rubrics we must instantiate are located inside of a file marked as private
Expand All @@ -17,6 +18,9 @@

# Local
from .evaluator import Evaluator
from .logger_config import setup_logger

logger = setup_logger(__name__)


class Sample(TypedDict):
Expand Down Expand Up @@ -56,21 +60,14 @@ class ModelConfig(BaseModel):
system_prompt: str = _DEFAULT_SYSTEM_PROMPT

# "model randomness" aka likelihood of sampling something other than the likeliest token
temperature: float = 0.0
temperature: float = Field(default=0.0, le=1.0, ge=0.0)

# Max amount of tokens to generate.
max_tokens: int = 768

# Random seed for reproducibility. Caution: this isn't supported by all model serving runtimes.
seed: int = DEFAULT_SEED

@field_validator("temperature")
@classmethod
def check_temperature(cls, v: float) -> float:
if not 0.0 <= v <= 1.0:
raise ValueError("temperature must be between 0.0 and 1.0")
return v


class RagasEvaluator(Evaluator):
# most basic implementation, we just assume that the user will bring the existing model responses
Expand All @@ -80,18 +77,42 @@ def __init__(
self,
student_model: ModelConfig | None = None,
run_config: RunConfig | None = None,
openai_client: OpenAIClient | None = None,
student_openai_client: OpenAIClient | None = None,
judge_model_name: str = DEFAULT_JUDGE_MODEL,
judge_openai_api_key: str | None = None,
):
self.student_model = student_model
self.run_config = run_config
self.openai_client = openai_client
self.student_openai_client = student_openai_client
self.judge_model_name = judge_model_name
self.judge_openai_api_key = judge_openai_api_key

@staticmethod
def _validate_dataset(df: DataFrame):
"""
Validates whether or not the given `df` is a valid dataset of `Sample` objects.
Args:
df (DataFrame): DataFrame containing the dataset to be evaluated.
"""
# We have to hardcode these fields because the automated way of resolving the required fields from a TypedDict
# is only included by default in Python3.11+. For earlier versions, the `typing_extensions` package is required.
# See: https://docs.python.org/3/whatsnew/3.11.html#pep-655-marking-individual-typeddict-items-as-required-or-not-required
required_keys = {"user_input", "reference"}
missing_keys = required_keys - set(df.columns)
if missing_keys:
raise ValueError(
f"invalid dataset provided, missing the following keys: {', '.join(missing_keys)}"
)

def run(
self,
dataset: List[Sample] | Path,
student_model: ModelConfig | None = None,
run_config: RunConfig | None = None,
openai_client: OpenAIClient | None = None,
student_openai_client: OpenAIClient | None = None,
judge_model_name: str | None = None,
judge_openai_api_key: str | None = None,
) -> EvaluationResult:
"""
Evaluates the quality of model responses against a graded rubric.
Expand All @@ -111,21 +132,31 @@ def run(
a default one is created containing extremely permissive settings when handling
timeouts. This is because by default, OpenAI tier-1 usage accounts have very high
rate limits resulting in heavy throttling during evaluations.
openai_client (openai.Client | None, optional):
student_openai_client (openai.Client | None, optional):
The client to use when generating questions from the student model, must be compatible with the OpenAI API.
This field is required when `student_model` is provided.
judge_model_name (str | None, optional):
Name of the OpenAI model to use as the judge model. Defaults to "gpt-4o" when none is specified.
judge_openai_api_key (str | None, optional):
The API key to use for evaluating the given dataset. When this isn't provided, `OPENAI_API_KEY` is read instead.
Returns:
EvaluationResult: The results of all evaluations performed by Ragas
"""
judge_model_name = (
judge_model_name if judge_model_name else self.judge_model_name
)
judge_openai_api_key = (
judge_openai_api_key if judge_openai_api_key else self.judge_openai_api_key
)
student_model = student_model if student_model else self.student_model
run_config = run_config if run_config else self.run_config
openai_client = openai_client if openai_client else self.openai_client

if not dataset:
raise ValueError(
"no dataset was provided, please specify the `dataset` argument"
)
student_openai_client = (
student_openai_client
if student_openai_client
else self.student_openai_client
)

# ensure we are in the dataframe format
input_df = None
Expand All @@ -137,22 +168,30 @@ def run(
raise TypeError(f"invalid type of dataset: {type(dataset)}")

# this should never happen, but pylint is not smart enough to detect it
assert input_df is not None
if TYPE_CHECKING:
assert input_df is not None

# ensure the dataset is in the format we expect it
self._validate_dataset(input_df)

need_to_generate_questions = "response" not in input_df.columns
if need_to_generate_questions and (not student_model or not openai_client):
raise ValueError(
"provided dataset doesn't contain the model `response`, but either `student_model` or `openai_client` wasn't provided for inference"
if need_to_generate_questions:
logger.debug(
"`response` is missing in the input dataframe columns, generating questions from the model is required."
)
if not student_model or not student_openai_client:
raise ValueError(
"provided dataset doesn't contain the model `response`, but either `student_model` or `student_openai_client` wasn't provided for inference"
)

# if the student model was provided then we always generate regardless
if student_model:
if not openai_client:
if not student_openai_client:
raise ValueError(
"`student_model` was specified but `openai_client` was not provided"
"`student_model` was specified but `student_openai_client` was not provided"
)
input_df = self._generate_answers_from_model(
input_df, student_model, openai_client
input_df, student_model, student_openai_client
)

if not run_config:
Expand All @@ -170,7 +209,8 @@ def run(

# we will be using gpt-4o for the foreseeable future, we hardcode this
# for consistency of answers
critic_lm = ChatOpenAI(model=DEFAULT_JUDGE_MODEL)

critic_lm = ChatOpenAI(model=judge_model_name, api_key=judge_openai_api_key)
results = evaluate(
dataset=evaluation_ds,
batch_size=4,
Expand All @@ -185,7 +225,7 @@ def _generate_answers_from_model(
self,
questions: DataFrame,
student_model: ModelConfig,
openai_client: OpenAIClient,
student_openai_client: OpenAIClient,
) -> DataFrame:
"""
Given a DataFrame containing `user_input` columns, generates responses from the given model
Expand All @@ -196,11 +236,14 @@ def _generate_answers_from_model(
updated_df["response"] = ""

for i, qna in updated_df.iterrows():
messages = [
student_model.system_prompt,
qna["user_input"],
messages: List[ChatCompletionMessageParam] = [
{
"role": "system",
"content": student_model.system_prompt,
},
{"role": "user", "content": qna["user_input"]},
]
response = openai_client.chat.completions.create(
response = student_openai_client.chat.completions.create(
messages=messages,
model=student_model.model_name,
# specify the seed so we can at least try to have some reproducibility when the clients support it
Expand All @@ -211,7 +254,8 @@ def _generate_answers_from_model(
updated_df.at[i, "response"] = response.choices[0].message.content
return updated_df

def _get_metrics(self) -> List[Metric]:
@staticmethod
def _get_metrics() -> List[Metric]:
# default set of metrics
return [
RubricsScore(
Expand Down
Loading

0 comments on commit ab3d168

Please sign in to comment.