Skip to content

Commit

Permalink
Fix camel model in the new backend model ABC
Browse files Browse the repository at this point in the history
  • Loading branch information
dandansamax committed Oct 15, 2024
1 parent 05f87aa commit 6eb3145
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 37 deletions.
9 changes: 6 additions & 3 deletions crab-benchmark-v0/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,22 +238,25 @@ def get_benchmark(env: str, ubuntu_url: str):
)
elif args.model == "pixtral":
model = BackendModelConfig(
model_class="openai-json",
model_class="openai",
model_name="mistralai/Pixtral-12B-2409",
json_structre_output=True,
history_messages_len=args.history_messages_len,
base_url=args.model_base_url,
api_key=args.model_api_key,
)
elif args.model == "gpt4o-wofc":
model = BackendModelConfig(
model_class="openai-json",
model_class="openai",
model_name="gpt-4o",
json_structre_output=True,
history_messages_len=args.history_messages_len,
)
elif args.model == "llava-ov72b":
model = BackendModelConfig(
model_class="sglang-openai-json",
model_class="sglang",
model_name="lmms-lab/llava-onevision-qwen2-72b-ov-chat",
json_structre_output=True,
history_messages_len=args.history_messages_len,
base_url=args.model_base_url,
api_key=args.model_api_key,
Expand Down
83 changes: 64 additions & 19 deletions crab/agents/backend_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,43 @@


class BackendModelConfig(BaseModel):
model_class: Literal["openai", "claude", "gemini", "camel", "vllm", "sglang"]
model_class: Literal["openai", "claude", "gemini", "camel", "sglang"]
"""Specify the model class to be used. Different model classese use different
APIs.
"""

model_name: str
"""Specify the model name to be used. This value is directly passed to the API,
check model provider API documentation for more details.
"""

model_platform: str | None = None
"""Required for CamelModel. Otherwise, it is ignored. Please check CAMEL
documentation for more details.
"""

history_messages_len: int = 0
"""Number of rounds of previous messages to be used in the model input. 0 means no
history.
"""

parameters: dict[str, Any] = {}
"""Additional parameters to be passed to the model."""

json_structre_output: bool = False
"""If True, the model generate action through JSON without using "tool call" or
"function call". SGLang model only supports JSON output. OpenAI model supports both.
Other models do not support JSON output.
"""

tool_call_required: bool = True
base_url: str | None = None # Only used in OpenAIModel and VLLMModel currently
api_key: str | None = None # Only used in OpenAIModel and VLLMModel currently
"""Specify if the model enforce each round to generate tool/function calls."""

base_url: str | None = None
"""Specify the base URL of the API. Only used in OpenAI and SGLang currently."""

api_key: str | None = None
"""Specify the API key to be used. Only used in OpenAI and SGLang currently."""


def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
Expand All @@ -41,6 +71,10 @@ def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
raise Warning(
"base_url and api_key are not supported for ClaudeModel currently."
)
if model_config.json_structre_output:
raise Warning(
"json_structre_output is not supported for ClaudeModel currently."
)
return ClaudeModel(
model=model_config.model_name,
parameters=model_config.parameters,
Expand All @@ -52,38 +86,49 @@ def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
raise Warning(
"base_url and api_key are not supported for GeminiModel currently."
)
if model_config.json_structre_output:
raise Warning(
"json_structre_output is not supported for GeminiModel currently."
)
return GeminiModel(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
tool_call_required=model_config.tool_call_required,
)
case "openai":
return OpenAIModel(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
base_url=model_config.base_url,
api_key=model_config.api_key,
tool_call_required=model_config.tool_call_required,
)
case "openai-json":
return OpenAIModelJSON(
if not model_config.json_structre_output:
return OpenAIModel(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
base_url=model_config.base_url,
api_key=model_config.api_key,
tool_call_required=model_config.tool_call_required,
)
else:
return OpenAIModelJSON(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
base_url=model_config.base_url,
api_key=model_config.api_key,
)
case "sglang":
return SGlangOpenAIModelJSON(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
base_url=model_config.base_url,
api_key=model_config.api_key,
)
case "sglang-openai-json":
return SGlangOpenAIModelJSON(
case "camel":
return CamelModel(
model=model_config.model_name,
model_platform=model_config.model_platform,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
base_url=model_config.base_url,
api_key=model_config.api_key,
tool_call_required=model_config.tool_call_required,
)
case "camel":
raise NotImplementedError("Cannot support camel model currently.")
case _:
raise ValueError(f"Unsupported model name: {model_config.model_name}")
16 changes: 8 additions & 8 deletions crab/agents/backend_models/camel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,20 @@ def __init__(
model_platform: str,
parameters: dict[str, Any] | None = None,
history_messages_len: int = 0,
tool_call_required: bool = True,
) -> None:
if not CAMEL_ENABLED:
raise ImportError("Please install camel-ai to use CamelModel")
self.parameters = parameters or {}
self.model = model
self.parameters = parameters if parameters is not None else {}
self.history_messages_len = history_messages_len

self.model_type = _get_model_type(model)
self.model_platform_type = _get_model_platform_type(model_platform)
self.client: ChatAgent | None = None
self.token_usage = 0

super().__init__(
model,
parameters,
history_messages_len,
)
self.tool_call_required = tool_call_required
self.history_messages_len = history_messages_len

def get_token_usage(self) -> int:
return self.token_usage
Expand All @@ -106,7 +106,7 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
action_schema = _convert_action_to_schema(action_space)
config = self.parameters.copy()
if action_schema is not None:
config["tool_choice"] = "required"
config["tool_choice"] = "required" if self.tool_call_required else "auto"
config["tools"] = [
schema.get_openai_tool_schema() for schema in action_schema
]
Expand Down
17 changes: 10 additions & 7 deletions test/agents/backend_models/test_camel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
import pytest

from crab import action
from crab.agents.backend_models import CamelModel
from crab.agents.backend_models import BackendModelConfig, create_backend_model


@pytest.fixture
def camel_model():
return CamelModel(
model_platform="openai",
model="gpt-4o",
parameters={"max_tokens": 3000},
history_messages_len=1,
return create_backend_model(
BackendModelConfig(
model_class="camel",
model_name="gpt-4o",
model_platform="openai",
parameters={"max_tokens": 3000},
history_messages_len=1,
)
)


Expand All @@ -38,7 +41,7 @@ def add(a: int, b: int):
return a + b


@pytest.mark.skip(reason="Mock data to be added")
# @pytest.mark.skip(reason="Mock data to be added")
def test_action_chat(camel_model):
camel_model.reset("You are a helpful assistant.", [add])
message = (
Expand Down

0 comments on commit 6eb3145

Please sign in to comment.