Added custom is_finished_parser logic to Google Vertex AI customization doc #1728

Merged

merged 1 commit into explodinggradients:main on Dec 9, 2024

Conversation

@cnemri (Contributor) commented on Dec 4, 2024

Title: Added a custom completion detection parser for Gemini models

Description

This PR updates the Ragas model customization how-to guide, adding proper completion detection support for Google's Vertex AI Gemini models. Currently, Ragas consistently raises LLMDidNotFinishException with Gemini models because it does not correctly interpret Gemini's completion signals.

Problem

The current implementation in LangchainLLMWrapper doesn't properly handle Gemini's completion signals:

  • Gemini uses "STOP" and "MAX_TOKENS" as valid completion reasons
  • The completion status can be found in either generation_info or response_metadata
  • The current logic doesn't account for these Gemini-specific patterns (see the probe sketch after this list)
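
For context, here is a minimal probe (illustrative only: the model name and prompt are placeholders, and it assumes the ChatVertexAI class from the langchain-google-vertexai package) showing where a Gemini response surfaces its finish reason in LangChain:

from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="gemini-1.5-pro")  # placeholder model name
result = llm.generate([[HumanMessage(content="Say hello")]])
gen = result.generations[0][0]

# Depending on the code path, the finish reason may appear in either place:
print(gen.generation_info)            # e.g. {"finish_reason": "STOP", ...}
print(gen.message.response_metadata)  # may carry "finish_reason" or "stop_reason"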

Solution

  1. Added a custom completion detection parser for Gemini models
  2. Passed the custom parser to the LangchainLLMWrapper call via its is_finished_parser argument

Changes

  1. Modified ragas/docs/howtos/customizations/customize_models.md:
from langchain_core.outputs import LLMResult, ChatGeneration

def gemini_is_finished_parser(response: LLMResult) -> bool:
    is_finished_list = []
    for g in response.flatten():
        resp = g.generations[0][0]
        reason_found = False

        # Check generation_info first
        if resp.generation_info is not None:
            finish_reason = resp.generation_info.get("finish_reason")
            if finish_reason is not None:
                is_finished_list.append(
                    finish_reason in ["STOP", "MAX_TOKENS"]
                )
                reason_found = True

        # Check response_metadata as a fallback
        if not reason_found and isinstance(resp, ChatGeneration) and resp.message is not None:
            metadata = resp.message.response_metadata
            if metadata.get("finish_reason"):
                is_finished_list.append(
                    metadata["finish_reason"] in ["STOP", "MAX_TOKENS"]
                )
                reason_found = True
            elif metadata.get("stop_reason"):
                is_finished_list.append(
                    metadata["stop_reason"] in ["STOP", "MAX_TOKENS"]
                )
                reason_found = True

        # If no finish reason was found for this generation, default to True
        if not reason_found:
            is_finished_list.append(True)

    return all(is_finished_list)

vertextai_llm = LangchainLLMWrapper(vertextai_llm, is_finished_parser=gemini_is_finished_parser)
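
As a quick sanity check (a minimal sketch with hand-built LangChain objects, not part of the PR), the parser can be exercised against synthetic results:

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

# A generation that finished normally
finished = LLMResult(generations=[[ChatGeneration(
    message=AIMessage(content="complete answer"),
    generation_info={"finish_reason": "STOP"},
)]])

# A generation that stopped for another reason (anything outside STOP/MAX_TOKENS)
truncated = LLMResult(generations=[[ChatGeneration(
    message=AIMessage(content="cut off"),
    generation_info={"finish_reason": "SAFETY"},
)]])

assert gemini_is_finished_parser(finished) is True
assert gemini_is_finished_parser(truncated) is False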

Commit 57c6cbf: Added custom is_finished_parser logic to Google Vertex AI customization doc, as the current LangchainLLMWrapper class does not natively support Gemini models.
@dosubot (bot) added the size:M label (This PR changes 30-99 lines, ignoring generated files) on Dec 4, 2024
@cnemri (Contributor, Author) commented on Dec 4, 2024

Hi @shahules786, is there any further clarification needed on this PR?

@jjmachan (Member) commented on Dec 9, 2024

@cnemri thanks again for this contribution 🙂

@jjmachan merged commit 57c6cbf into explodinggradients:main on Dec 9, 2024
5 checks passed