13 changes: 7 additions & 6 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -204,7 +204,8 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
message=message_chunk_to_message(generation.message),
generation_info=generation.generation_info,
)
-]
+],
+llm_output={},
)


@@ -596,7 +597,7 @@ def stream(
run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err

-run_manager.on_llm_end(LLMResult(generations=[[generation]]))
+run_manager.on_llm_end(LLMResult(generations=[[generation]], llm_output={}))

@override
async def astream(
@@ -725,7 +726,7 @@ async def astream(
raise err

await run_manager.on_llm_end(
-LLMResult(generations=[[generation]]),
+LLMResult(generations=[[generation]], llm_output={}),
)

# --- Custom methods ---
@@ -1148,7 +1149,7 @@ def _generate_with_cache(
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
converted_generations = self._convert_cached_generations(cache_val)
-return ChatResult(generations=converted_generations)
+return ChatResult(generations=converted_generations, llm_output={})
elif self.cache is None:
pass
else:
@@ -1266,7 +1267,7 @@ async def _agenerate_with_cache(
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
converted_generations = self._convert_cached_generations(cache_val)
-return ChatResult(generations=converted_generations)
+return ChatResult(generations=converted_generations, llm_output={})
elif self.cache is None:
pass
else:
@@ -1722,7 +1723,7 @@ def _generate(
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
-return ChatResult(generations=[generation])
+return ChatResult(generations=[generation], llm_output={})

@abstractmethod
def _call(
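To see what the generate_from_stream change means for callers, here is a small illustrative sketch (not part of this diff; the chunk contents are invented): merging streamed chunks now produces a ChatResult whose llm_output is an empty dict rather than None, so downstream code can read it without a None guard.

from langchain_core.language_models.chat_models import generate_from_stream
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

# Two fake chunks standing in for a provider's streamed output.
chunks = iter(
    [
        ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
        ChatGenerationChunk(message=AIMessageChunk(content=" world")),
    ]
)

result = generate_from_stream(chunks)
assert result.generations[0].message.content == "Hello world"
# Previously None; with this change, an empty dict.
assert result.llm_output == {}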
@@ -44,7 +44,7 @@ def _generate(
else:
self.i = 0
generation = ChatGeneration(message=response)
-return ChatResult(generations=[generation])
+return ChatResult(generations=[generation], llm_output={})

@property
@override
@@ -261,7 +261,7 @@ def _generate(
message = next(self.messages)
message_ = AIMessage(content=message) if isinstance(message, str) else message
generation = ChatGeneration(message=message_)
-return ChatResult(generations=[generation])
+return ChatResult(generations=[generation], llm_output={})

def _stream(
self,
@@ -386,7 +386,9 @@ def _generate(
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
-return ChatResult(generations=[ChatGeneration(message=messages[-1])])
+return ChatResult(
+    generations=[ChatGeneration(message=messages[-1])], llm_output={}
+)

@property
def _llm_type(self) -> str:
10 changes: 6 additions & 4 deletions libs/core/langchain_core/language_models/llms.py
@@ -564,7 +564,7 @@ def stream(
run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err

-run_manager.on_llm_end(LLMResult(generations=[[generation]]))
+run_manager.on_llm_end(LLMResult(generations=[[generation]], llm_output={}))

@override
async def astream(
@@ -635,7 +635,9 @@ async def astream(
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err

-await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
+await run_manager.on_llm_end(
+    LLMResult(generations=[[generation]], llm_output={})
+)

# --- Custom methods ---

@@ -1502,7 +1504,7 @@ def _generate(
else self._call(prompt, stop=stop, **kwargs)
)
generations.append([Generation(text=text)])
-return LLMResult(generations=generations)
+return LLMResult(generations=generations, llm_output={})

async def _agenerate(
self,
@@ -1520,4 +1522,4 @@ async def _agenerate(
else await self._acall(prompt, stop=stop, **kwargs)
)
generations.append([Generation(text=text)])
-return LLMResult(generations=generations)
+return LLMResult(generations=generations, llm_output={})
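The same holds for the simple LLM base class: with _generate and _agenerate now passing llm_output={}, callback handlers attached to a plain _call-based model also receive a dict. A minimal sketch, using a toy EchoLLM and CaptureHandler invented here for illustration:

from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import LLMResult


class EchoLLM(LLM):
    """Toy LLM that echoes the prompt; used only to observe callbacks."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(self, prompt: str, stop: list[str] | None = None, **kwargs: Any) -> str:
        return prompt


class CaptureHandler(BaseCallbackHandler):
    """Records every LLMResult passed to on_llm_end."""

    def __init__(self) -> None:
        self.results: list[LLMResult] = []

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.results.append(response)


handler = CaptureHandler()
EchoLLM().invoke("hello", config={"callbacks": [handler]})
# With this change, llm_output is an empty dict instead of None.
assert isinstance(handler.results[0].llm_output, dict)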
44 changes: 42 additions & 2 deletions libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -7,15 +7,15 @@

from typing_extensions import override

-from langchain_core.callbacks.base import AsyncCallbackHandler
+from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.language_models import (
FakeListChatModel,
FakeMessagesListChatModel,
GenericFakeChatModel,
ParrotFakeChatModel,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
-from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from tests.unit_tests.stubs import (
_any_id_ai_message,
_any_id_ai_message_chunk,
@@ -253,3 +253,43 @@ def test_fake_messages_list_chat_model_sleep_delay() -> None:
elapsed = time.time() - start

assert elapsed >= sleep_time


+def test_stream_llm_result_contains_llm_output() -> None:
+    """Test that streaming mode includes llm_output in LLMResult."""
+
+    class LLMResultCaptureHandler(BaseCallbackHandler):
+        """Callback handler that captures LLMResult from on_llm_end."""
+
+        def __init__(self) -> None:
+            self.llm_results: list[LLMResult] = []
+
+        @override
+        def on_llm_end(
+            self,
+            response: LLMResult,
+            *,
+            run_id: UUID,
+            parent_run_id: UUID | None = None,
+            **kwargs: Any,
+        ) -> None:
+            """Capture the LLMResult."""
+            self.llm_results.append(response)
+
+    model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world")]))
+    handler = LLMResultCaptureHandler()
+
+    # Consume the stream to trigger on_llm_end
+    chunks = list(model.stream("test", config={"callbacks": [handler]}))
+
+    # Verify we got chunks
+    assert len(chunks) > 0
+
+    # Verify on_llm_end was called
+    assert len(handler.llm_results) == 1
+
+    # Verify llm_output field exists in the LLMResult
+    llm_result = handler.llm_results[0]
+    assert hasattr(llm_result, "llm_output")
+    assert llm_result.llm_output is not None
+    assert isinstance(llm_result.llm_output, dict)
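An async counterpart could exercise the astream path in the same way; a possible sketch, assuming the same module-level imports as the test above and the suite's existing async test setup:

async def test_astream_llm_result_contains_llm_output() -> None:
    """Async streaming should also surface llm_output in LLMResult."""

    class AsyncLLMResultCaptureHandler(AsyncCallbackHandler):
        """Async handler that captures LLMResult from on_llm_end."""

        def __init__(self) -> None:
            self.llm_results: list[LLMResult] = []

        @override
        async def on_llm_end(
            self,
            response: LLMResult,
            *,
            run_id: UUID,
            parent_run_id: UUID | None = None,
            **kwargs: Any,
        ) -> None:
            self.llm_results.append(response)

    model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world")]))
    handler = AsyncLLMResultCaptureHandler()

    chunks = [c async for c in model.astream("test", config={"callbacks": [handler]})]

    assert len(chunks) > 0
    assert len(handler.llm_results) == 1
    assert isinstance(handler.llm_results[0].llm_output, dict)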
@@ -648,7 +648,7 @@ def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
}
]
],
"llm_output": None,
"llm_output": {},
"run": None,
"type": "LLMResult",
},
@@ -780,7 +780,7 @@ async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
}
]
],
"llm_output": None,
"llm_output": {},
"run": None,
"type": "LLMResult",
},
@@ -1030,7 +1030,7 @@ async def test_event_stream_with_simple_chain() -> None:
}
]
],
"llm_output": None,
"llm_output": {},
"run": None,
"type": "LLMResult",
},
@@ -1809,7 +1809,7 @@ async def test_with_llm() -> None:
}
]
],
"llm_output": None,
"llm_output": {},
"run": None,
"type": "LLMResult",
},