Commit d581e7d

chore: add unit tests for ragas evaluator

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>

1 parent 3a9e3f2 commit d581e7d
File tree

src/instructlab/eval/ragas.py
tests/test_ragas.py

2 files changed: +168, -7 lines changed
src/instructlab/eval/ragas.py

Lines changed: 7 additions & 7 deletions
@@ -1,3 +1,4 @@
+# # SPDX-License-Identifier: Apache-2.0
 # Standard
 from pathlib import Path
 from typing import List, Optional, TypedDict
@@ -53,7 +54,7 @@ class ModelConfig(BaseModel):
 
     # name of the model to use.
     model_name: str
-
+
     # The system prompt to be used when applying the chat template.
     system_prompt: str = _DEFAULT_SYSTEM_PROMPT
 
@@ -67,7 +68,7 @@ class ModelConfig(BaseModel):
     # Max amount of tokens to generate.
     max_tokens: int = 768
 
-    # Random seed for reproducibility. Caution: this isn't supported by all model serving runtimes.
+    # Random seed for reproducibility. Caution: this isn't supported by all model serving runtimes.
     seed: int = DEFAULT_SEED
 
     @field_validator("temperature")
@@ -126,15 +127,14 @@ def run(
                 "no dataset was provided, please specify the `dataset` argument"
             )
 
-        if type(dataset) not in (list, Path):
-            raise TypeError(f"invalid type of dataset: {type(dataset)}")
-
         # ensure we are in the dataframe format
         input_df = None
         if isinstance(dataset, list):
            input_df = DataFrame(dataset)
         elif isinstance(dataset, Path):
            input_df = read_json(dataset, orient="records", lines=True)
+        else:
+            raise TypeError(f"invalid type of dataset: {type(dataset)}")
 
         # this should never happen, but pylint is not smart enough to detect it
         assert input_df is not None
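
Note on the hunk above: the up-front type() check is replaced by an else branch on the isinstance chain, so unsupported dataset types are rejected at the same point where supported ones are converted. A minimal standalone sketch of that dispatch (the helper name load_dataset is illustrative, not part of the module):

# Illustrative sketch of the dispatch shown above; not the actual run() method.
from pathlib import Path
from pandas import DataFrame, read_json

def load_dataset(dataset):
    # a list of {"user_input": ..., "reference": ...} records becomes a DataFrame
    if isinstance(dataset, list):
        return DataFrame(dataset)
    # a Path is read as JSONL (one record per line)
    if isinstance(dataset, Path):
        return read_json(dataset, orient="records", lines=True)
    # anything else is rejected, mirroring the new else branch
    raise TypeError(f"invalid type of dataset: {type(dataset)}")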
@@ -192,8 +192,8 @@ def _generate_answers_from_model(
 
         for i, qna in updated_df.iterrows():
             messages = [
-                student_model.system_prompt,
-                qna["user_input"],
+                {"role": "system", "content": student_model.system_prompt},
+                {"role": "user", "content": qna["user_input"]},
             ]
             response = client.chat.completions.create(
                 messages=messages,
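
The hunk above switches the prompt from bare strings to the role/content message objects that OpenAI-compatible chat completion endpoints expect. A minimal sketch of the resulting payload shape, with placeholder values:

# Placeholder values; the real code takes these from ModelConfig.system_prompt and the dataset row.
system_prompt = "You are a helpful assistant."
user_input = "What is the capital of France?"

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_input},
]
# this list is what gets passed to client.chat.completions.create(messages=messages, ...)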

tests/test_ragas.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+# # SPDX-License-Identifier: Apache-2.0
+# Standard
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+import unittest
+
+# Third Party
+from pandas import DataFrame
+from ragas.callbacks import ChainRun
+from ragas.dataset_schema import EvaluationDataset, EvaluationResult
+import pandas as pd
+
+# First Party
+from instructlab.eval.ragas import ModelConfig, RagasEvaluator, RunConfig, Sample
+
+
+class TestRagasEvaluator(unittest.TestCase):
+    @patch("instructlab.eval.ragas.get_openai_client")
+    def test_generate_answers_from_model(self, mock_get_openai_client):
+        # mock the OpenAI client to always return "london" for chat completions
+        mock_client = MagicMock()
+        mock_response = MagicMock()
+        mock_response.choices[0].message.content = "London"
+        mock_client.chat.completions.create.return_value = mock_response
+        mock_get_openai_client.return_value = mock_client
+
+        # get answers
+        questions = pd.DataFrame({"user_input": ["What is the capital of France?"]})
+        student_model = ModelConfig(
+            base_url="https://api.openai.com",
+            model_name="gpt-3.5-turbo",
+            api_key="test-api-key",
+        )
+        evaluator = RagasEvaluator()
+        result_df = evaluator._generate_answers_from_model(questions, student_model)
+
+        # what we expect to see
+        expected_df = questions.copy()
+        expected_df["response"] = ["London"]
+
+        # perform the assertions
+        pd.testing.assert_frame_equal(result_df, expected_df)
+        mock_get_openai_client.assert_called_once_with(
+            model_api_base=student_model.base_url, api_key=student_model.api_key
+        )
+        mock_client.chat.completions.create.assert_called_once_with(
+            messages=[student_model.system_prompt, "What is the capital of France?"],
+            model=student_model.model_name,
+            seed=42,
+            max_tokens=student_model.max_tokens,
+            temperature=student_model.temperature,
+        )
+
+    @patch("instructlab.eval.ragas.read_json")
+    @patch("instructlab.eval.ragas.evaluate")
+    @patch("instructlab.eval.ragas.ChatOpenAI")
+    @patch.object(RagasEvaluator, "_generate_answers_from_model")
+    @patch.object(RagasEvaluator, "_get_metrics")
+    def test_run(
+        self,
+        mock_get_metrics: MagicMock,
+        mock_generate_answers_from_model: MagicMock,
+        mock_ChatOpenAI: MagicMock,
+        mock_evaluate: MagicMock,
+        mock_read_json: MagicMock,
+    ):
+        ########################################################################
+        # SETUP EVERYTHING WE NEED FOR THE TESTS
+        ########################################################################
+
+        # These are the variables which will control the flow of the test.
+        # Since we have to re-construct some Ragas components under the hood,
+
+        student_model_response = "Paris"
+        user_question = "What is the capital of France?"
+        golden_answer = "The capital of France is Paris."
+        base_ds = [{"user_input": user_question, "reference": golden_answer}]
+        mocked_metric = "mocked-metric"
+        mocked_metric_score = 4.0
+
+        # The following section takes care of mocking function return calls.
+        # Ragas is tricky because it has some complex data structures under the hood,
+        # so what we have to do is configure the intermediate outputs that we expect
+        # to receive from Ragas.
+
+        mock_get_metrics.return_value = [mocked_metric]
+        interim_df = DataFrame(
+            {
+                "user_input": [user_question],
+                "response": [student_model_response],
+                "reference": [golden_answer],
+            }
+        )
+        mock_generate_answers_from_model.return_value = interim_df.copy()
+        mocked_evaluation_ds = EvaluationDataset.from_pandas(interim_df)
+        mock_ChatOpenAI.return_value = MagicMock()
+
+        # Ragas requires this value to instantiate an EvaluationResult object, so we must provide it.
+        # It isn't functionally used for our purposes though.
+
+        _unimportant_ragas_traces = {
+            "default": ChainRun(
+                run_id="42",
+                parent_run_id=None,
+                name="root",
+                inputs={"system": "null", "user": "null"},
+                outputs={"assistant": "null"},
+                metadata={"user_id": 1337},
+            )
+        }
+        mock_evaluate.return_value = EvaluationResult(
+            scores=[{mocked_metric: mocked_metric_score}],
+            dataset=mocked_evaluation_ds,
+            ragas_traces=_unimportant_ragas_traces,
+        )
+
+        ########################################################################
+        # Run the tests
+        ########################################################################
+
+        # Configure all other inputs that Ragas does not depend on for proper mocking
+        student_model = ModelConfig(
+            base_url="https://api.openai.com",
+            model_name="pt-3.5-turbo",
+            api_key="test-api-key",
+        )
+        run_config = RunConfig(max_retries=3, max_wait=60, seed=42, timeout=30)
+        evaluator = RagasEvaluator()
+
+        ########################################################################
+        # Test case: directly passing a dataset
+        ########################################################################
+        result = evaluator.run(
+            dataset=base_ds, student_model=student_model, run_config=run_config
+        )
+
+        self.assertIsInstance(result, EvaluationResult)
+        mock_generate_answers_from_model.assert_called_once()
+        mock_evaluate.assert_called_once()
+        mock_ChatOpenAI.assert_called_once_with(model="gpt-4o")
+
+        ########################################################################
+        # Test case: passing a dataset in via Path to JSONL file
+        ########################################################################
+        mock_read_json.return_value = DataFrame(base_ds)
+        result = evaluator.run(
+            dataset=Path("dummy_path.jsonl"),
+            student_model=student_model,
+            run_config=run_config,
+        )
+
+        self.assertIsInstance(result, EvaluationResult)
+        mock_read_json.assert_called_once_with(
+            Path("dummy_path.jsonl"), orient="records", lines=True
+        )
+        mock_generate_answers_from_model.assert_called()
+        mock_evaluate.assert_called()
+
+
+if __name__ == "__main__":
+    unittest.main()
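
Taken together, the two test cases also document the intended call pattern for the evaluator. A usage sketch mirroring the tests (the endpoint, API key, and file path are the illustrative values from the tests, not working credentials):

from pathlib import Path

from instructlab.eval.ragas import ModelConfig, RagasEvaluator, RunConfig

student_model = ModelConfig(
    base_url="https://api.openai.com",  # any OpenAI-compatible endpoint
    model_name="gpt-3.5-turbo",
    api_key="test-api-key",  # placeholder; supply a real key
)
run_config = RunConfig(max_retries=3, max_wait=60, seed=42, timeout=30)
evaluator = RagasEvaluator()

# Records can be passed directly...
result = evaluator.run(
    dataset=[
        {
            "user_input": "What is the capital of France?",
            "reference": "The capital of France is Paris.",
        }
    ],
    student_model=student_model,
    run_config=run_config,
)

# ...or via a Path to a JSONL file with the same fields.
result = evaluator.run(
    dataset=Path("dummy_path.jsonl"),
    student_model=student_model,
    run_config=run_config,
)

Because the module ends in unittest.main(), the tests themselves can be run directly with python tests/test_ragas.py or via python -m unittest tests/test_ragas.py.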
