@cnemri cnemri commented Dec 4, 2024

Add Gemini model completion detection support

Description

This PR adds proper completion detection support for Google's Vertex AI Gemini models in the LangchainLLMWrapper class. Currently, Ragas systematically raises LLMDidNotFinishException with Gemini models because it doesn't correctly interpret Gemini's completion signals.

Problem

The current implementation in LangchainLLMWrapper doesn't properly handle Gemini's completion signals:

  • Gemini uses "STOP" and "MAX_TOKENS" as valid completion reasons
  • The completion status can be found in either generation_info or response_metadata
  • The current logic doesn't account for these Gemini-specific patterns
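The acceptance rules above can be sketched as a small stand-alone check. This is an illustrative sketch only: plain dicts stand in for langchain's `generation_info` / `response_metadata`, and the function name `generation_finished` is ours, not Ragas's API.

```python
# Sketch of the finish-reason decision described above, using plain dicts
# in place of langchain's generation_info / response_metadata.
# "stop" is OpenAI's signal; "STOP" and "MAX_TOKENS" are Vertex AI Gemini's.
VALID_REASONS = {"stop", "STOP", "MAX_TOKENS"}


def generation_finished(generation_info=None, response_metadata=None):
    """Return True when either source reports a valid completion reason."""
    if generation_info and generation_info.get("finish_reason") is not None:
        return generation_info["finish_reason"] in VALID_REASONS
    # Fall back to response_metadata (less reliable, per the PR description).
    if response_metadata:
        reason = response_metadata.get("finish_reason") or response_metadata.get(
            "stop_reason"
        )
        if reason is not None:
            # Anthropic-style stop_reason uses "end_turn" for a clean finish.
            return reason in VALID_REASONS | {"end_turn"}
    # No completion signal at all: default to finished.
    return True
```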

Solution

  1. Modified the is_finished method of the LangchainLLMWrapper class to support completion detection for Gemini models, adding proper handling of Gemini's completion signals

Changes

  1. Modified LangchainLLMWrapper in ragas/llms/base.py:
class LangchainLLMWrapper(BaseRagasLLM):
    def __init__(
        self,
        langchain_llm: BaseLanguageModel,
        run_config: Optional[RunConfig] = None,
        is_finished_parser: Optional[Callable[[LLMResult], bool]] = None,
    ):
        self.langchain_llm = langchain_llm
        if run_config is None:
            run_config = RunConfig()
        self.set_run_config(run_config)
        self.is_finished_parser = is_finished_parser

    def is_finished(self, response: LLMResult) -> bool:
        """
        Parse the response to check if the LLM finished by checking the finish_reason
        or stop_reason. Supports OpenAI and Vertex AI models.
        """
        if self.is_finished_parser is not None:
            return self.is_finished_parser(response)
        # if no parser is provided default to our own

        is_finished_list = []
        for g in response.flatten():
            resp = g.generations[0][0]
            if resp.generation_info is not None:
                # generation_info is provided - so we parse that
                finish_reason = resp.generation_info.get("finish_reason")
                if finish_reason is not None:
                    # OpenAI uses "stop"
                    # Vertex AI uses "STOP" or "MAX_TOKENS"
                    is_finished_list.append(
                        finish_reason in ["stop", "STOP", "MAX_TOKENS"]
                    )

            # if generation_info is empty, we parse the response_metadata
            # this is less reliable
            elif (
                isinstance(resp, ChatGeneration)
                and t.cast(ChatGeneration, resp).message is not None
            ):
                resp_message: BaseMessage = t.cast(ChatGeneration, resp).message
                if resp_message.response_metadata.get("finish_reason") is not None:
                    finish_reason = resp_message.response_metadata.get("finish_reason")
                    is_finished_list.append(
                        finish_reason in ["stop", "STOP", "MAX_TOKENS"]
                    )
                elif resp_message.response_metadata.get("stop_reason") is not None:
                    stop_reason = resp_message.response_metadata.get("stop_reason")
                    is_finished_list.append(
                        stop_reason in ["end_turn", "STOP", "MAX_TOKENS"]
                    )
            # default to True
            else:
                is_finished_list.append(True)
        return all(is_finished_list)
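The new `is_finished_parser` argument also lets callers bypass the default logic entirely with their own callable. A minimal, hedged sketch of such a parser follows; the `FakeGeneration`/`FakeResult` classes are our stand-ins that mimic just enough of langchain's `LLMResult.flatten()` shape to run, and are not the real langchain types.

```python
from dataclasses import dataclass, field
from typing import Optional


# Minimal stand-ins for langchain's Generation/LLMResult, just enough
# to exercise a custom parser (NOT the real langchain classes).
@dataclass
class FakeGeneration:
    generation_info: Optional[dict] = None


@dataclass
class FakeResult:
    generations: list = field(default_factory=list)  # list[list[FakeGeneration]]

    def flatten(self):
        # The real LLMResult.flatten() yields one LLMResult per generation;
        # for this sketch we simply yield ourselves.
        return [self]


def gemini_is_finished(response) -> bool:
    """Custom parser: accept only Gemini's STOP / MAX_TOKENS as completion."""
    for res in response.flatten():
        info = res.generations[0][0].generation_info or {}
        if info.get("finish_reason") not in ("STOP", "MAX_TOKENS"):
            return False
    return True


# With the real wrapper you would pass it as:
#   LangchainLLMWrapper(langchain_llm, is_finished_parser=gemini_is_finished)
```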

@dosubot dosubot bot added the size:S This PR changes 10-29 lines, ignoring generated files. label Dec 4, 2024
cnemri commented Dec 4, 2024

Hi @shahules786, is there any further clarification needed on this PR?

@jjmachan jjmachan left a comment
thanks a lot for taking the time to add this into Ragas @cnemri ❤️ 🙂

@jjmachan jjmachan merged commit 34a7db2 into explodinggradients:main Dec 9, 2024
16 checks passed