Commit f12fbd7
feat: add ability for answers to be generated from user questions
When a provided dataset is missing the `response` field, we need to generate those responses. This commit ensures that in this case we error out if a student model is not configured. Otherwise, when a student model is configured, we always generate the responses, regardless of whether `response` is already present in the dataframe.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
1 parent df441c1 commit f12fbd7
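
For context, a minimal usage sketch of the behavior this commit introduces (not part of the commit itself). It assumes the package is importable as `instructlab.eval.ragas`; the endpoint URL, model name, and dataset path are hypothetical, and `ModelConfig` / `RagasEvaluator` are defined in the diff below.

# Standard
from pathlib import Path

# First Party
from instructlab.eval.ragas import ModelConfig, RagasEvaluator

# Hypothetical OpenAI-compatible endpoint serving the student model under evaluation.
student = ModelConfig(
    base_url="http://localhost:8000/v1",
    model_name="my-student-model",
)

evaluator = RagasEvaluator(student_model=student)

# If records in the jsonl file lack a `response` column, answers are generated
# with the student model; without a configured student model, run() raises ValueError.
results = evaluator.run(dataset=Path("questions.jsonl"))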

File tree

1 file changed (+131 −17 lines)


src/instructlab/eval/ragas.py

Lines changed: 131 additions & 17 deletions
@@ -1,47 +1,110 @@
 # Standard
 from pathlib import Path
-from typing import List, TypedDict
+from typing import List, Optional, TypedDict

 # Third Party
 from langchain_community.chat_models import ChatOpenAI
+from pandas import DataFrame, read_json
+from pydantic import BaseModel, ConfigDict, field_validator
 from ragas.evaluation import EvaluationDataset, EvaluationResult, RunConfig, evaluate
 from ragas.metrics._domain_specific_rubrics import (  # the rubrics we must instantiate are located inside of a file marked as private
     DEFAULT_WITH_REFERENCE_RUBRICS,
     RubricsScore,
 )
-import pandas as pd

 # Local
 from .evaluator import Evaluator
+from .mt_bench_common import get_openai_client


 class Sample(TypedDict):
+    """
+    TypedDict of a sample that we accept when doing eval with Ragas.
+    We specifically use TypedDict here to be flexible with the input data we accept.
+    """
+
     # question
     user_input: str

     # model answer
-    response: str
+    response: Optional[str]

     # golden answer
     reference: str


+# default system prompt we'll use when none is provided. Make it private as we don't intend this to be a public object
+_DEFAULT_SYSTEM_PROMPT = """You are an advanced AI assistant designed to provide precise and accurate information.
+Your primary goal is to answer queries with the most up-to-date and factual information available.
+Focus on delivering clear, concise, and correct responses.
+If you're uncertain about any aspect of the query, state your level of confidence and provide the most accurate information you can.
+Your responses should prioritize accuracy over all other considerations."""
+
+DEFAULT_SEED = 1337
+DEFAULT_JUDGE_MODEL = "gpt-4o"
+
+
+class ModelConfig(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
+    # URL of the OpenAI server where the model shall be hosted
+    base_url: str
+
+    # name of the model to use
+    model_name: str
+    system_prompt: str = _DEFAULT_SYSTEM_PROMPT
+
+    # We do NOT read from OPENAI_API_KEY for the student model for security reasons (e.g. sending the API key to another client)
+    # To provide an OpenAI key, you must set it here; else the default is used.
+    api_key: str = "no-api-key"
+
+    # "model randomness" aka likelihood of sampling something other than the likeliest token
+    temperature: float = 0.0
+
+    max_tokens: int = 768
+
+    # Random seed for reproducibility. This is not supported everywhere and therefore is unreliable.
+    seed: int = DEFAULT_SEED
+
+    @field_validator("temperature")
+    @classmethod
+    def check_temperature(cls, v: float) -> float:
+        if not 0.0 <= v <= 1.0:
+            raise ValueError("temperature must be between 0.0 and 1.0")
+        return v
+
+
 class RagasEvaluator(Evaluator):
     # most basic implementation, we just assume that the user will bring the existing model responses
     name = "ragas"

-    def __init__(self):
-        pass
+    def __init__(
+        self,
+        student_model: ModelConfig | None = None,
+        run_config: RunConfig | None = None,
+    ):
+        self.student_model = student_model
+        self.run_config = run_config

     def run(
-        self, dataset: List[Sample] | Path = None, run_config: RunConfig | None = None
+        self,
+        dataset: List[Sample] | Path,
+        student_model: ModelConfig | None = None,
+        run_config: RunConfig | None = None,
     ) -> EvaluationResult:
         """
         Evaluates the quality of model responses against a graded rubric.

+        When the `dataset` lacks the `response` field, then `student_model` must be provided
+        in order to generate the answers.
+
         Args:
             dataset (List[Sample] | Path):
-                List of model questions and answers
+                Can be either a list of `Sample` objects or a path to a jsonl file containing
+                records matching `Sample`.
+            student_model: (StudentModelConfig):
+                When this parameter is provided, we'll attempt to use the described model in order to
+                generate the responses from the given list of questions.
             run_config (RunConfig | None, optional):
                 Configuration to use when running evaluations. If none is provided, then
                 a default one is created containing extremely permissive settings when handling
@@ -51,26 +114,44 @@ def run(
         Returns:
             EvaluationResult: The results of all evaluations performed by Ragas
         """
+        student_model = student_model if student_model else self.student_model
+        run_config = run_config if run_config else self.run_config
+
         if not dataset:
             raise ValueError(
                 "no dataset was provided, please specify the `dataset` argument"
             )
-        if isinstance(dataset, Path):
-            input_ds = EvaluationDataset.from_pandas(
-                pd.read_json(dataset, lines=True, orient="records")
+
+        if type(dataset) not in (list, Path):
+            raise TypeError(f"invalid type of dataset: {type(dataset)}")
+
+        # ensure we are in the dataframe format
+        input_df = None
+        if isinstance(dataset, list):
+            input_df = DataFrame(dataset)
+        elif isinstance(dataset, Path):
+            input_df = read_json(dataset, orient="records", lines=True)
+
+        # this should never happen, but pylint is not smart enough to detect it
+        assert input_df is not None
+
+        need_to_generate_questions = "response" not in input_df.columns
+        if need_to_generate_questions and not student_model:
+            raise ValueError(
+                "provided dataset doesn't contain the model `response`, but no `student_model` was provided for inference"
             )
-        elif isinstance(dataset, list):
-            input_ds = EvaluationDataset.from_list(dataset)
-        else:
-            raise TypeError(f"invalid type passed for dataset: {type(dataset)}")
+
+        # if the student model was provided then we always generate regardless
+        if student_model:
+            input_df = self._generate_answers_from_model(input_df, student_model)

         if not run_config:
             # we set extreme timeout/retry values by default since OpenAI tier-1 rate limits
             # are horrible and will result in half of our evaluation results being NaN or 0
             run_config = RunConfig(
                 max_retries=120,
                 max_wait=7200,
-                seed=42,
+                seed=DEFAULT_SEED,
                 timeout=3600,
             )

@@ -81,15 +162,48 @@ def run(
             )
         ]

+        evaluation_ds = EvaluationDataset.from_pandas(input_df)
+
         # we will be using gpt-4o for the foreseeable future, we hardcode this
         # for consistency of answers
-        critic_lm = ChatOpenAI(model="gpt-4o")
+        critic_lm = ChatOpenAI(model=DEFAULT_JUDGE_MODEL)
         results = evaluate(
-            dataset=input_ds,
+            dataset=evaluation_ds,
             batch_size=4,
             run_config=run_config,
             llm=critic_lm,
             metrics=metrics,
             show_progress=True,
         )
         return results
+
+    def _generate_answers_from_model(
+        self, questions: DataFrame, student_model: ModelConfig
+    ) -> DataFrame:
+        """
+        Given a DataFrame containing `user_input` columns, generates responses from the given model
+        and returns a new DataFrame containing its answers in the `response` column.
+        """
+        client = get_openai_client(
+            model_api_base=student_model.base_url, api_key=student_model.api_key
+        )
+
+        # initialize response to write into
+        updated_df = questions.copy()
+        updated_df["response"] = ""
+
+        for i, qna in updated_df.iterrows():
+            messages = [
+                student_model.system_prompt,
+                qna["user_input"],
+            ]
+            response = client.chat.completions.create(
+                messages=messages,
+                model=student_model.model_name,
+                # specify the seed so we can at least try to have some reproducibility when the clients support it
+                seed=42,
+                max_tokens=student_model.max_tokens,
+                temperature=student_model.temperature,
+            )
+            updated_df.at[i, "response"] = response.choices[0].message.content
+        return updated_df
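
For reference, a dataset passed to `run()` as a list is expected to contain records shaped like the `Sample` TypedDict above. A small hypothetical example follows (the question/answer text is made up); with a student model configured, the `response` key could be omitted and would be filled in via `_generate_answers_from_model`.

# Hypothetical in-memory dataset matching the `Sample` TypedDict.
dataset = [
    {
        "user_input": "What is the capital of France?",
        "response": "Paris is the capital of France.",
        "reference": "The capital of France is Paris.",
    },
]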
