
Commit 0b3b9e4

jjmachan authored and shahules786 committed
fix: temperature needs to be added handled effectively (explodinggradients#1759)
1 parent: de6b8ea · commit: 0b3b9e4

File tree

3 files changed: +64 -14 lines changed


src/ragas/llms/base.py

+27 -11

```diff
@@ -129,7 +129,7 @@ class LangchainLLMWrapper(BaseRagasLLM):
     def __init__(
         self,
-        langchain_llm: BaseLanguageModel,
+        langchain_llm: BaseLanguageModel[BaseMessage],
        run_config: t.Optional[RunConfig] = None,
        is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
        cache: t.Optional[CacheInterface] = None,
```
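In LangChain, `BaseLanguageModel` is generic over what the model returns, and chat models subclass `BaseLanguageModel[BaseMessage]`, so the narrowed annotation documents that the wrapper expects a chat-style model. A minimal sketch of a conforming value, assuming `langchain-openai` is installed and `OPENAI_API_KEY` is set:

```python
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI

# ChatOpenAI subclasses BaseChatModel, which is a BaseLanguageModel[BaseMessage],
# so it satisfies the narrowed parameter type above.
llm: BaseLanguageModel[BaseMessage] = ChatOpenAI(model="gpt-4o")
```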
```diff
@@ -198,29 +198,36 @@ def generate_text(
        callbacks: Callbacks = None,
    ) -> LLMResult:
        # figure out the temperature to set
+        old_temperature: float | None = None
        if temperature is None:
            temperature = self.get_temperature(n=n)
+        if hasattr(self.langchain_llm, "temperature"):
+            self.langchain_llm.temperature = temperature  # type: ignore
+            old_temperature = temperature

        if is_multiple_completion_supported(self.langchain_llm):
-            return self.langchain_llm.generate_prompt(
+            result = self.langchain_llm.generate_prompt(
                prompts=[prompt],
                n=n,
-                temperature=temperature,
                stop=stop,
                callbacks=callbacks,
            )
        else:
            result = self.langchain_llm.generate_prompt(
                prompts=[prompt] * n,
-                temperature=temperature,
                stop=stop,
                callbacks=callbacks,
            )
            # make LLMResult.generation appear as if it was n_completions
            # note that LLMResult.runs is still a list that represents each run
            generations = [[g[0] for g in result.generations]]
            result.generations = generations
-            return result
+
+        # reset the temperature to the original value
+        if old_temperature is not None:
+            self.langchain_llm.temperature = old_temperature  # type: ignore
+
+        return result

    async def agenerate_text(
        self,
```
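Since `generate_prompt` no longer receives `temperature` as a keyword, the value is applied by mutating the wrapped model's attribute and reset after generation, so a shared client object is not left altered. A minimal, self-contained sketch of that set-and-restore pattern (the `FakeChatModel` here is a hypothetical stand-in; the real code guards with `hasattr` because not every LangChain model exposes `temperature`):

```python
from typing import Optional

class FakeChatModel:
    # Hypothetical stand-in for a LangChain chat model with a mutable temperature.
    temperature: float = 1.0

    def generate(self, prompt: str) -> str:
        return f"{prompt!r} answered at T={self.temperature}"

def generate_with_temperature(llm, prompt: str, temperature: Optional[float]) -> str:
    old_temperature: Optional[float] = None
    if temperature is not None and hasattr(llm, "temperature"):
        old_temperature = llm.temperature  # remember the value being clobbered
        llm.temperature = temperature
    result = llm.generate(prompt)
    if old_temperature is not None:
        llm.temperature = old_temperature  # restore the original value
    return result

llm = FakeChatModel()
print(generate_with_temperature(llm, "hi", 0.3))  # answered at T=0.3
print(llm.temperature)                            # back to 1.0
```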
```diff
@@ -230,29 +237,38 @@ async def agenerate_text(
        stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
+        # handle temperature
+        old_temperature: float | None = None
        if temperature is None:
            temperature = self.get_temperature(n=n)
+        if hasattr(self.langchain_llm, "temperature"):
+            self.langchain_llm.temperature = temperature  # type: ignore
+            old_temperature = temperature

-        if is_multiple_completion_supported(self.langchain_llm):
-            return await self.langchain_llm.agenerate_prompt(
+        # handle n
+        if hasattr(self.langchain_llm, "n"):
+            self.langchain_llm.n = n  # type: ignore
+            result = await self.langchain_llm.agenerate_prompt(
                prompts=[prompt],
-                n=n,
-                temperature=temperature,
                stop=stop,
                callbacks=callbacks,
            )
        else:
            result = await self.langchain_llm.agenerate_prompt(
                prompts=[prompt] * n,
-                temperature=temperature,
                stop=stop,
                callbacks=callbacks,
            )
            # make LLMResult.generation appear as if it was n_completions
            # note that LLMResult.runs is still a list that represents each run
            generations = [[g[0] for g in result.generations]]
            result.generations = generations
-            return result
+
+        # reset the temperature to the original value
+        if old_temperature is not None:
+            self.langchain_llm.temperature = old_temperature  # type: ignore
+
+        return result

    def set_run_config(self, run_config: RunConfig):
        self.run_config = run_config
```
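In the async path, multiple completions are requested the same way: if the model exposes an `n` attribute (as `ChatOpenAI` does), it is set directly and a single prompt is sent; otherwise the prompt is duplicated n times. A sketch of that dispatch with hypothetical stand-in models:

```python
import asyncio

class NSupportingModel:
    # Hypothetical model that, like ChatOpenAI, returns n completions per prompt.
    n: int = 1

    async def agenerate(self, prompts):
        return [f"completion {i}" for i in range(self.n) for _ in prompts]

class SingleCompletionModel:
    # Hypothetical model without an `n` attribute.
    async def agenerate(self, prompts):
        return [f"completion for {p}" for p in prompts]

async def agenerate_n(llm, prompt: str, n: int):
    if hasattr(llm, "n"):
        llm.n = n  # ask the provider for n completions of one prompt
        return await llm.agenerate([prompt])
    return await llm.agenerate([prompt] * n)  # fall back to n duplicate prompts

print(asyncio.run(agenerate_n(NSupportingModel(), "q", 3)))
print(asyncio.run(agenerate_n(SingleCompletionModel(), "q", 3)))
```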

src/ragas/optimizers/genetic.py

+2 -3

```diff
@@ -519,9 +519,8 @@ def dict_to_str(dict: t.Dict[str, t.Any]) -> str:
            ),
            expected_output=dataset[idx]["prompts"][prompt_name][
                "edited_output"
-            ] or dataset[idx]["prompts"][prompt_name][
-                "prompt_output"
-            ],
+            ]
+            or dataset[idx]["prompts"][prompt_name]["prompt_output"],
        )
        for idx in indices
    ]
```
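This hunk is a purely stylistic re-break of the expression; the `or` fallback behavior is unchanged: when `edited_output` is empty or `None`, `prompt_output` is used. Illustrated with hypothetical data:

```python
prompts = {"edited_output": "", "prompt_output": "draft answer"}

# A falsy edited_output falls through to prompt_output.
expected_output = prompts["edited_output"] or prompts["prompt_output"]
print(expected_output)  # draft answer
```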
+35 -0

```diff
@@ -0,0 +1,35 @@
+import pytest
+from langchain_anthropic import ChatAnthropic
+from langchain_aws import ChatBedrock, ChatBedrockConverse
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_google_vertexai import ChatVertexAI
+from langchain_openai import ChatOpenAI
+
+models = [
+    ChatOpenAI(model="gpt-4o"),
+    # AzureChatOpenAI(model="gpt-4o", api_version="2024-04-09"),
+    ChatGoogleGenerativeAI(model="gemini-1.5-pro"),
+    ChatAnthropic(
+        model_name="claude-3-5-sonnet-20240620",
+        timeout=10,
+        stop=["\n\n"],
+        temperature=0.5,
+    ),
+    ChatBedrock(model="anthropic.claude-3-5-sonnet-20240620"),
+    ChatBedrockConverse(model="anthropic.claude-3-5-sonnet-20240620"),
+    ChatVertexAI(model="gemini-1.5-pro"),
+]
+
+
+@pytest.mark.parametrize("model", models)
+def test_langchain_chat_models_have_temperature(model):
+    assert hasattr(model, "temperature")
+    model.temperature = 0.5
+    assert model.temperature == 0.5
+
+
+@pytest.mark.parametrize("model", models)
+def test_langchain_chat_models_have_n(model):
+    assert hasattr(model, "n")
+    model.n = 2
+    assert model.n == 2
```
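These parametrized tests pin down the assumption the wrapper now relies on: each supported chat model exposes mutable `temperature` and `n` attributes. Note that building the `models` list at import time requires the provider packages and, for some providers, credentials; a hedged sketch of a credential-free stub exercising the same contract (`StubChatModel` is hypothetical, not part of the commit):

```python
import pytest

class StubChatModel:
    # Hypothetical stand-in exposing the two attributes the wrapper mutates.
    temperature: float = 1.0
    n: int = 1

@pytest.mark.parametrize("model", [StubChatModel()])
def test_stub_exposes_wrapper_attributes(model):
    model.temperature = 0.5
    model.n = 2
    assert (model.temperature, model.n) == (0.5, 2)
```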

0 commit comments
