@@ -129,7 +129,7 @@ class LangchainLLMWrapper(BaseRagasLLM):
     def __init__(
         self,
-        langchain_llm: BaseLanguageModel,
+        langchain_llm: BaseLanguageModel[BaseMessage],
         run_config: t.Optional[RunConfig] = None,
         is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
         cache: t.Optional[CacheInterface] = None,
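Narrowing the parameter to `BaseLanguageModel[BaseMessage]` documents that the wrapper expects a chat-style model (one whose outputs are `BaseMessage`s) rather than any language model. A minimal sketch of a call site under the new annotation, assuming `langchain_openai` is installed and that ragas exposes the wrapper as `ragas.llms.LangchainLLMWrapper`:

```python
from langchain_openai import ChatOpenAI  # assumed installed
from ragas.llms import LangchainLLMWrapper  # assumed import path

# ChatOpenAI subclasses BaseChatModel, which is declared as
# BaseLanguageModel[BaseMessage], so it satisfies the narrowed type;
# a completion-style LLM (BaseLanguageModel[str]) would now be flagged
# by the type checker.
wrapper = LangchainLLMWrapper(langchain_llm=ChatOpenAI(model="gpt-4o-mini"))
```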
@@ -198,29 +198,36 @@ def generate_text(
         callbacks: Callbacks = None,
     ) -> LLMResult:
         # figure out the temperature to set
+        old_temperature: float | None = None
         if temperature is None:
             temperature = self.get_temperature(n=n)
+        if hasattr(self.langchain_llm, "temperature"):
+            old_temperature = self.langchain_llm.temperature  # type: ignore
+            self.langchain_llm.temperature = temperature  # type: ignore
 
         if is_multiple_completion_supported(self.langchain_llm):
-            return self.langchain_llm.generate_prompt(
+            result = self.langchain_llm.generate_prompt(
                 prompts=[prompt],
                 n=n,
-                temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
             )
         else:
             result = self.langchain_llm.generate_prompt(
                 prompts=[prompt] * n,
-                temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
             )
             # make LLMResult.generation appear as if it was n_completions
             # note that LLMResult.runs is still a list that represents each run
             generations = [[g[0] for g in result.generations]]
             result.generations = generations
-            return result
+
+        # reset the temperature to the original value
+        if old_temperature is not None:
+            self.langchain_llm.temperature = old_temperature  # type: ignore
+
+        return result
 
     async def agenerate_text(
         self,
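Since `generate_text` now mutates `temperature` on the shared `langchain_llm` instance and restores it before returning, the same save/restore idiom can be isolated for clarity. A minimal sketch using only the standard library; `temporary_attr` is a hypothetical helper, not part of ragas or this diff:

```python
from contextlib import contextmanager
from typing import Any, Iterator

@contextmanager
def temporary_attr(obj: Any, name: str, value: Any) -> Iterator[None]:
    """Temporarily set obj.<name> to value, restoring the old value on exit."""
    old = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        # restored even if the body raises, unlike the inline version above
        setattr(obj, name, old)

# usage sketch (hypothetical llm object):
# with temporary_attr(llm, "temperature", 0.01):
#     result = llm.generate_prompt(...)
```

The `try/finally` form would also restore the value when `generate_prompt` raises, which the inline set/reset in the diff does not guard against.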
@@ -230,29 +237,38 @@ async def agenerate_text(
         stop: t.Optional[t.List[str]] = None,
         callbacks: Callbacks = None,
     ) -> LLMResult:
+        # handle temperature
+        old_temperature: float | None = None
         if temperature is None:
             temperature = self.get_temperature(n=n)
+        if hasattr(self.langchain_llm, "temperature"):
+            old_temperature = self.langchain_llm.temperature  # type: ignore
+            self.langchain_llm.temperature = temperature  # type: ignore
 
-        if is_multiple_completion_supported(self.langchain_llm):
-            return await self.langchain_llm.agenerate_prompt(
+        # handle n
+        if hasattr(self.langchain_llm, "n"):
+            self.langchain_llm.n = n  # type: ignore
+            result = await self.langchain_llm.agenerate_prompt(
                 prompts=[prompt],
-                n=n,
-                temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
             )
         else:
             result = await self.langchain_llm.agenerate_prompt(
                 prompts=[prompt] * n,
-                temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
             )
             # make LLMResult.generation appear as if it was n_completions
             # note that LLMResult.runs is still a list that represents each run
             generations = [[g[0] for g in result.generations]]
             result.generations = generations
-            return result
+
+        # reset the temperature to the original value
+        if old_temperature is not None:
+            self.langchain_llm.temperature = old_temperature  # type: ignore
+
+        return result
 
     def set_run_config(self, run_config: RunConfig):
         self.run_config = run_config
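In the async path the diff switches from `is_multiple_completion_supported` to probing for an `n` attribute: models that expose `n` (for example `ChatOpenAI`) produce all completions in a single call, while other models get the prompt duplicated `n` times and their per-prompt generations regrouped. A toy illustration of that regrouping, with assumed placeholder texts:

```python
from langchain_core.outputs import Generation, LLMResult

# three duplicated prompts, one generation each (placeholder texts)
result = LLMResult(
    generations=[
        [Generation(text="a")],
        [Generation(text="b")],
        [Generation(text="c")],
    ]
)

# the wrapper's regrouping: one prompt that appears to have n completions
result.generations = [[g[0] for g in result.generations]]
assert [g.text for g in result.generations[0]] == ["a", "b", "c"]
```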