From 6eb3145432522b969ccc0e8d212fa79e04c3f2ce Mon Sep 17 00:00:00 2001 From: Tianqi Xu Date: Tue, 15 Oct 2024 13:39:26 +0300 Subject: [PATCH] Fix camel model in the new backend model ABC --- crab-benchmark-v0/main.py | 9 +- crab/agents/backend_models/__init__.py | 83 ++++++++++++++----- crab/agents/backend_models/camel_model.py | 16 ++-- .../agents/backend_models/test_camel_model.py | 17 ++-- 4 files changed, 88 insertions(+), 37 deletions(-) diff --git a/crab-benchmark-v0/main.py b/crab-benchmark-v0/main.py index 79e4afa..f1751ed 100644 --- a/crab-benchmark-v0/main.py +++ b/crab-benchmark-v0/main.py @@ -238,22 +238,25 @@ def get_benchmark(env: str, ubuntu_url: str): ) elif args.model == "pixtral": model = BackendModelConfig( - model_class="openai-json", + model_class="openai", model_name="mistralai/Pixtral-12B-2409", + json_structre_output=True, history_messages_len=args.history_messages_len, base_url=args.model_base_url, api_key=args.model_api_key, ) elif args.model == "gpt4o-wofc": model = BackendModelConfig( - model_class="openai-json", + model_class="openai", model_name="gpt-4o", + json_structre_output=True, history_messages_len=args.history_messages_len, ) elif args.model == "llava-ov72b": model = BackendModelConfig( - model_class="sglang-openai-json", + model_class="sglang", model_name="lmms-lab/llava-onevision-qwen2-72b-ov-chat", + json_structre_output=True, history_messages_len=args.history_messages_len, base_url=args.model_base_url, api_key=args.model_api_key, diff --git a/crab/agents/backend_models/__init__.py b/crab/agents/backend_models/__init__.py index 6c6bdab..32b21cc 100644 --- a/crab/agents/backend_models/__init__.py +++ b/crab/agents/backend_models/__init__.py @@ -25,13 +25,43 @@ class BackendModelConfig(BaseModel): - model_class: Literal["openai", "claude", "gemini", "camel", "vllm", "sglang"] + model_class: Literal["openai", "claude", "gemini", "camel", "sglang"] + """Specify the model class to be used. Different model classese use different + APIs. + """ + model_name: str + """Specify the model name to be used. This value is directly passed to the API, + check model provider API documentation for more details. + """ + + model_platform: str | None = None + """Required for CamelModel. Otherwise, it is ignored. Please check CAMEL + documentation for more details. + """ + history_messages_len: int = 0 + """Number of rounds of previous messages to be used in the model input. 0 means no + history. + """ + parameters: dict[str, Any] = {} + """Additional parameters to be passed to the model.""" + + json_structre_output: bool = False + """If True, the model generate action through JSON without using "tool call" or + "function call". SGLang model only supports JSON output. OpenAI model supports both. + Other models do not support JSON output. + """ + tool_call_required: bool = True - base_url: str | None = None # Only used in OpenAIModel and VLLMModel currently - api_key: str | None = None # Only used in OpenAIModel and VLLMModel currently + """Specify if the model enforce each round to generate tool/function calls.""" + + base_url: str | None = None + """Specify the base URL of the API. Only used in OpenAI and SGLang currently.""" + + api_key: str | None = None + """Specify the API key to be used. Only used in OpenAI and SGLang currently.""" def create_backend_model(model_config: BackendModelConfig) -> BackendModel: @@ -41,6 +71,10 @@ def create_backend_model(model_config: BackendModelConfig) -> BackendModel: raise Warning( "base_url and api_key are not supported for ClaudeModel currently." ) + if model_config.json_structre_output: + raise Warning( + "json_structre_output is not supported for ClaudeModel currently." + ) return ClaudeModel( model=model_config.model_name, parameters=model_config.parameters, @@ -52,6 +86,10 @@ def create_backend_model(model_config: BackendModelConfig) -> BackendModel: raise Warning( "base_url and api_key are not supported for GeminiModel currently." ) + if model_config.json_structre_output: + raise Warning( + "json_structre_output is not supported for GeminiModel currently." + ) return GeminiModel( model=model_config.model_name, parameters=model_config.parameters, @@ -59,31 +97,38 @@ def create_backend_model(model_config: BackendModelConfig) -> BackendModel: tool_call_required=model_config.tool_call_required, ) case "openai": - return OpenAIModel( - model=model_config.model_name, - parameters=model_config.parameters, - history_messages_len=model_config.history_messages_len, - base_url=model_config.base_url, - api_key=model_config.api_key, - tool_call_required=model_config.tool_call_required, - ) - case "openai-json": - return OpenAIModelJSON( + if not model_config.json_structre_output: + return OpenAIModel( + model=model_config.model_name, + parameters=model_config.parameters, + history_messages_len=model_config.history_messages_len, + base_url=model_config.base_url, + api_key=model_config.api_key, + tool_call_required=model_config.tool_call_required, + ) + else: + return OpenAIModelJSON( + model=model_config.model_name, + parameters=model_config.parameters, + history_messages_len=model_config.history_messages_len, + base_url=model_config.base_url, + api_key=model_config.api_key, + ) + case "sglang": + return SGlangOpenAIModelJSON( model=model_config.model_name, parameters=model_config.parameters, history_messages_len=model_config.history_messages_len, base_url=model_config.base_url, api_key=model_config.api_key, ) - case "sglang-openai-json": - return SGlangOpenAIModelJSON( + case "camel": + return CamelModel( model=model_config.model_name, + model_platform=model_config.model_platform, parameters=model_config.parameters, history_messages_len=model_config.history_messages_len, - base_url=model_config.base_url, - api_key=model_config.api_key, + tool_call_required=model_config.tool_call_required, ) - case "camel": - raise NotImplementedError("Cannot support camel model currently.") case _: raise ValueError(f"Unsupported model name: {model_config.model_name}") diff --git a/crab/agents/backend_models/camel_model.py b/crab/agents/backend_models/camel_model.py index 6631c4c..636006b 100644 --- a/crab/agents/backend_models/camel_model.py +++ b/crab/agents/backend_models/camel_model.py @@ -84,20 +84,20 @@ def __init__( model_platform: str, parameters: dict[str, Any] | None = None, history_messages_len: int = 0, + tool_call_required: bool = True, ) -> None: if not CAMEL_ENABLED: raise ImportError("Please install camel-ai to use CamelModel") - self.parameters = parameters or {} + self.model = model + self.parameters = parameters if parameters is not None else {} + self.history_messages_len = history_messages_len + self.model_type = _get_model_type(model) self.model_platform_type = _get_model_platform_type(model_platform) self.client: ChatAgent | None = None self.token_usage = 0 - - super().__init__( - model, - parameters, - history_messages_len, - ) + self.tool_call_required = tool_call_required + self.history_messages_len = history_messages_len def get_token_usage(self) -> int: return self.token_usage @@ -106,7 +106,7 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None: action_schema = _convert_action_to_schema(action_space) config = self.parameters.copy() if action_schema is not None: - config["tool_choice"] = "required" + config["tool_choice"] = "required" if self.tool_call_required else "auto" config["tools"] = [ schema.get_openai_tool_schema() for schema in action_schema ] diff --git a/test/agents/backend_models/test_camel_model.py b/test/agents/backend_models/test_camel_model.py index f1239ba..8694900 100644 --- a/test/agents/backend_models/test_camel_model.py +++ b/test/agents/backend_models/test_camel_model.py @@ -14,16 +14,19 @@ import pytest from crab import action -from crab.agents.backend_models import CamelModel +from crab.agents.backend_models import BackendModelConfig, create_backend_model @pytest.fixture def camel_model(): - return CamelModel( - model_platform="openai", - model="gpt-4o", - parameters={"max_tokens": 3000}, - history_messages_len=1, + return create_backend_model( + BackendModelConfig( + model_class="camel", + model_name="gpt-4o", + model_platform="openai", + parameters={"max_tokens": 3000}, + history_messages_len=1, + ) ) @@ -38,7 +41,7 @@ def add(a: int, b: int): return a + b -@pytest.mark.skip(reason="Mock data to be added") +# @pytest.mark.skip(reason="Mock data to be added") def test_action_chat(camel_model): camel_model.reset("You are a helpful assistant.", [add]) message = (