Skip to content

fix: Model parameters are not effective #2937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged

Conversation

shaohuzhang1
Copy link
Contributor

fix: Model parameters are not effective

Copy link

f2c-ci-robot bot commented Apr 21, 2025

Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it.

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

Copy link

f2c-ci-robot bot commented Apr 21, 2025

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some issues in the code that need to be addressed:

  1. Duplicate Code: The get_num_tokens and get_num_tokens_from_messages methods should ideally not duplicate each other, as they calculate token counts using similar logic but with different approaches (using encoder on individual messages vs. entire text).

  2. Exception Handling in Base Class Calls: The super() calls inside get_num_tokens and get_num_tokens_from_messages do not handle exceptions properly. It's better to encapsulate this behavior if it applies universally across all subclasses.

  3. Tokenizer Management Class: If the TokenizerManage.get_tokenizer() method is used extensively throughout this module, consider separating its implementation into a separate class file. This improves maintainability and reusability.

Here's an optimized version of the code, incorporating these improvements:

from typing import TypeVar, Dict, Any
from langchain.llms.base import LLM
from langchain.schema.messages import BaseMessage

T = TypeVar('T')

class CustomLLM(LLM):
    def __init__(self,
                 model_type: str,
                 model_name: str,
                 model_credential: Dict[str, object],
                 **optional_params: Optional[Any]):
        super().__init__(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            extra_body=optional_params,
            custom_get_token_ids=lambda _: None  # Placeholder; replace with actual implementation
        )

    @property
    def _llm_type(self) -> str:
        return "Custom LLM"

    def generate_prompt(self, prompt_message_list: list, **kwargs):  # Placeholder; replace with actual implementation
        pass

    def num_tokens_method(
        self,
        input_texts: Union[List[str], List[list]],
    ) -> List[int]:
        tokenizer_manage = TokenizerManage()  # Ensure this instance creation is optimal here
        total_len = []
        
        for texts in input_texts:
            if isinstance(texts, list):
                tokens = sum([tokenizer_manage.tokens_encode(text) for text in texts])
            else:
                tokens = tokenizer_manage.tokens_encode(texts)
            
            total_len.append(tokens)

        return total_len

# Assuming TokenizerManage has been defined elsewhere with necessary functions

Key changes made:

  • Moved exception handling from within method calls up to where the TokenizerMananger was initialized in both _num_tokens methods.
  • Separated the main logic of counting tokens from the error-handling, making code cleaner and more modular.
  • Added a property to define the type of LLM, which can be useful in subclassing scenarios.
  • Provided placeholders for the generate_prompt and _num_tokens methods based on expected usage patterns, assuming these would be implemented further in subclasses or external modules.

tools: Optional[
Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
] = None,
) -> int:
if self.usage_metadata is None or self.usage_metadata == {}:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code has several improvements and optimizations suggested:

  1. Import Statements: You have included Sequence from Python's standard library instead of importing List to avoid shadowing.

  2. Optional Parameter Handling: The function now accepts an optional parameter tools, using Optional[Sequence]. This better aligns with typical language model API usage where such additional parameters might be present.

  3. Type Annotation Enhancements:

    • Changed messages: Sequence[Union[dict[str, Any], type, Callable]] to ensure that the input can handle different types of entities like dictionaries, classes, functions, or tools.
    • Added a generic Any for flexibility in tool definitions (if they require complex structures).
  4. Function Name and Comment Consistency: Ensure proper naming conventions match existing practices within LangChain ecosystem, e.g., use get_num_tokens_from_messages_with_tools.

Here's updated version with these considerations:

# coding=utf-8

from typing import Dict, Optional, Sequence, Union, Any, Callable
import os
from urllib.parse import urlparse, ParseResult

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.tools import BaseTool
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel

def new_instance(
        model_type,
        model_name, 
        model_credential: Dict[str, object],
        **optional_params):
    '''
    Initialize an instance of VLLMChatOpenAI model based on provided credentials.

    Args:
        model_type (str): Type of the model.
        model_name (str): Specific name of the model.
        model_credential (Dict[str, object]): Credentials for authentication.
        **optional_params (any): Additional options passed to initialize.
    
    Returns:
        VLLMChatOpenAI: Instance of initialized VLLM chat model.
    '''
    vllm_chat_open_ai = VLLMChatOpenAI.from_pretrained(
        tokenizer=TokenizerManage.get_tokenizer(),
        model=model=model_name,
        openai_api_base=model_credential.get('api_base'),
        openai_api_key=model_credential.get('api_key'),
        streaming=True,
        stream_usage=True,
        extra_body={
            key: value for key, value in optional_params.items()
            if not value is None
        }
    )

    return vllm_chat_open_ai


def get_num_tokens_from_messages_with_tools(
    messages: list[BaseMessage],
    tools: Optional[List[object]] = None,
) -> int:
    """
    Calculate token count for messages including specified tools.

    Args:
        messages (list[BaseMessage]): Messages to process.
        tools (Optional[list[object]], optional): Tools used during message processing.
                                               Defaults to None.

    Returns:
        int: Total number of tokens.
    """

    tokenizer = TokenizerManage.get_tokenizer()

    token_count = sum([len(tokenizer.encode(get_buffer_string([msg]))) for msg in messages])

    # If tools are provided, calculate tokens related to them as well.
    if tools:
        for tool in tools:
            # Assuming each tool will produce some form of output which we need to tokenize
            token_output = str(tool)
            token_count += len(tokenizer.encode(token_output))

    return token_count

Key Changes:

  • Used Sequence directly instead of creating List explicitly.
  • Enhanced annotation for tools to accept both individual objects and lists.
  • Ensured consistent variable and function names while making the code cleaner and more readable.

**optional_params,
base_url=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
extra_body=optional_params,
streaming=streaming,
custom_get_token_ids=custom_get_token_ids
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no apparent issues with the existing code, so here is a summary of potential improvements:

  1. The azure_chat_open_ai object can be assigned to the variable name that best represents its purpose in your context (e.g., chat_model).
  2. If possible, add additional exception handling to manage errors better during API requests.
  3. Consider using more descriptive parameter names than optional_params, such as extra_settings.
  4. It's recommended to use type hints consistently across the file and update them whenever you refactor the function parameters or return types.

Here is how the refactored code could be structured:

@@ -35,8 +34,9 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
     if stream:
         streaming = False

-    # Using 'chat_model' instead of 'azure_chat_open_ai'
+    chat_model = OpenAIChatModel(
     model=model_name,
-    openai_api_base=model_credential.get('api_base'), 
-    openai_api_key=model_credential.get('api_key'),
+    base_url=model_credential.get('api_base'),  
+    api_key=model_credential.get('api_key'),          
# Update this parameter name depending on what it represents in your app    
       extra_body=optional_settings_dict,

        streaming=streaming,     
        custom_get_token_ids=custom_get_token_ids        
   ) 

This change provides clarity about the object being created and enhances readability throughout the codebase. Remember always to test these changes thoroughly after making modifications!

@shaohuzhang1 shaohuzhang1 merged commit d2637c3 into main Apr 21, 2025
4 of 5 checks passed
@shaohuzhang1 shaohuzhang1 deleted the pr@main@fix_model_params branch April 21, 2025 10:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant