Skip to content

Commit

Permalink
Merge branch 'main' into parse-tags
Browse files Browse the repository at this point in the history
  • Loading branch information
WaelKarkoub committed Mar 26, 2024
2 parents e801354 + af9b300 commit f95a678
Show file tree
Hide file tree
Showing 20 changed files with 3,773 additions and 1,279 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ jobs:
- name: Coverage
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py --skip-openai
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py test/agentchat/contrib/capabilities/test_vision_capability.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
Expand Down
211 changes: 211 additions & 0 deletions autogen/agentchat/contrib/capabilities/vision_capability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import copy
from typing import Callable, Dict, List, Optional, Union

from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.img_utils import (
convert_base64_to_data_uri,
get_image_data,
get_pil_image,
gpt4v_formatter,
message_formatter_pil_to_b64,
)
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.agentchat.conversable_agent import colored
from autogen.code_utils import content_str
from autogen.oai.client import OpenAIWrapper

# Default prompt sent to the LMM client when asking it to caption an image.
# Used by VisionCapability unless the caller overrides `description_prompt`.
DEFAULT_DESCRIPTION_PROMPT = (
    "Write a detailed caption for this image. "
    "Pay special attention to any details that might be useful or relevant "
    "to the ongoing conversation."
)


class VisionCapability(AgentCapability):
    """Adds vision capability to a regular ConversableAgent, even if the agent itself does not have
    multimodal capability (e.g. a GPT-3.5-turbo, Llama, Orca, or Mistral agent). This capability
    invokes an LMM client to describe (caption) any image in an incoming message before the text
    reaches the agent's actual client.

    The capability hooks into the ConversableAgent's `process_last_received_message`.

    Some technical details:
    When the agent (who has the vision capability) receives a message, it will:
    1. _process_received_message:
        a. _append_oai_message
    2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
        a. hook process_last_received_message (NOTE: this is where the vision capability is hooked.)
        b. hook process_all_messages_before_reply
    3. send:
        a. hook process_message_before_send
        b. _append_oai_message
    """

    def __init__(
        self,
        lmm_config: Dict,
        description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
        custom_caption_func: Optional[Callable] = None,
    ) -> None:
        """
        Initializes a new instance, setting up the configuration for interacting with
        a Language Multimodal (LMM) client and specifying optional parameters for image
        description and captioning.

        Args:
            lmm_config (Dict): Configuration for the LMM client, which is used to call
                the LMM service for describing the image. This must be a dictionary containing
                the necessary configuration parameters. If `lmm_config` is False or an empty dictionary,
                it is considered invalid, and initialization will assert.
            description_prompt (Optional[str], optional): The prompt to use for generating
                descriptions of the image. This parameter allows customization of the
                prompt passed to the LMM service. Defaults to `DEFAULT_DESCRIPTION_PROMPT` if not provided.
            custom_caption_func (Optional[Callable], optional): A callable that, if provided, will be used
                to generate captions for images. This allows for custom captioning logic outside
                of the standard LMM service interaction.
                The callable should take three parameters as input:
                    1. an image URL (or local location)
                    2. image_data (a PIL image)
                    3. lmm_client (to call remote LMM)
                and then return a description (as string).
                If not provided, captioning will rely on the LMM client configured via `lmm_config`.
                If provided, the default `self._get_image_caption` method will not be run.

        Raises:
            AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func` is provided,
                an AssertionError is raised to indicate that the Vision Capability requires
                one of these to be valid for operation.
        """
        self._lmm_config = lmm_config
        self._description_prompt = description_prompt
        self._parent_agent = None

        # Only build an LMM client when a truthy config was supplied; False or {}
        # means "no client", in which case custom_caption_func must be given.
        if lmm_config:
            self._lmm_client = OpenAIWrapper(**lmm_config)
        else:
            self._lmm_client = None

        self._custom_caption_func = custom_caption_func
        assert (
            self._lmm_config or custom_caption_func
        ), "Vision Capability requires a valid lmm_config or custom_caption_func."

    def add_to_agent(self, agent: ConversableAgent) -> None:
        """Attach this capability to `agent`: extend its system message and register
        the captioning hook on `process_last_received_message`."""
        self._parent_agent = agent

        # Append extra info to the system message.
        agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")

        # Register a hook for processing the last message.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

    def process_last_received_message(self, content: Union[str, List[dict]]) -> str:
        """
        Processes the last received message content by normalizing and augmenting it
        with descriptions of any included images. The function supports input content
        as either a string or a list of dictionaries, where each dictionary represents
        a content item (e.g., text, image). If the content contains image URLs, it
        fetches the image data, generates a caption for each image, and inserts the
        caption into the augmented content.

        The function aims to transform the content into a format compatible with GPT-4V
        multimodal inputs, specifically by formatting strings into PIL-compatible
        images if needed and appending text descriptions for images. This allows for
        a more accessible presentation of the content, especially in contexts where
        images cannot be displayed directly.

        Args:
            content (Union[str, List[dict]]): The last received message content, which
                can be a plain text string or a list of dictionaries representing
                different types of content items (e.g., text, image_url).

        Returns:
            str: The augmented message content

        Raises:
            AssertionError: If an item in the content list is not a dictionary.

        Examples:
            Assuming `self._get_image_caption(img_data)` returns
            "A beautiful sunset over the mountains" for the image.

            - Input as String:
                content = "Check out this cool photo!"
                Output: "Check out this cool photo!"
                (Content is a string without an image, remains unchanged.)

            - Input as String, with image location:
                content = "What's weather in this cool photo: <img http://example.com/photo.jpg>"
                Output: "What's weather in this cool photo: <img http://example.com/photo.jpg> in case you can not see, the caption of this image is:
                A beautiful sunset over the mountains\n"
                (Caption added after the image)

            - Input as List with Text Only:
                content = [{"type": "text", "text": "Here's an interesting fact."}]
                Output: "Here's an interesting fact."
                (No images in the content, it remains unchanged.)

            - Input as List with Image URL:
                content = [
                    {"type": "text", "text": "What's weather in this cool photo:"},
                    {"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}}
                ]
                Output: "What's weather in this cool photo: <img http://example.com/photo.jpg> in case you can not see, the caption of this image is:
                A beautiful sunset over the mountains\n"
                (Caption added after the image)
        """
        # Work on a copy so the caller's message structure is never mutated.
        # (Previously the deepcopy result was discarded, which made it a no-op.)
        content = copy.deepcopy(content)

        # Normalize the content into the gpt-4v format for multimodal.
        # We want to keep the URL format to keep it concise.
        if isinstance(content, str):
            content = gpt4v_formatter(content, img_format="url")

        aug_content: str = ""
        for item in content:
            assert isinstance(item, dict)
            if item["type"] == "text":
                aug_content += item["text"]
            elif item["type"] == "image_url":
                img_url = item["image_url"]["url"]
                img_caption = ""

                # Custom captioning takes precedence over the configured LMM client;
                # if neither is usable the caption stays empty.
                if self._custom_caption_func:
                    img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
                elif self._lmm_client:
                    img_data = get_image_data(img_url)
                    img_caption = self._get_image_caption(img_data)

                aug_content += f"<img {img_url}> in case you can not see, the caption of this image is: {img_caption}\n"
            else:
                # Fixed typo: the expected types are `text` or `image_url` (was `test`).
                print(f"Warning: the input type should either be `text` or `image_url`. Skip {item['type']} here.")

        return aug_content

    def _get_image_caption(self, img_data: str) -> str:
        """Ask the configured LMM client to caption a single image.

        Args:
            img_data (str): base64 encoded image data.

        Returns:
            str: caption for the given image.
        """
        response = self._lmm_client.create(
            context=None,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self._description_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": convert_base64_to_data_uri(img_data),
                            },
                        },
                    ],
                }
            ],
        )
        description = response.choices[0].message.content
        return content_str(description)
6 changes: 6 additions & 0 deletions autogen/agentchat/contrib/img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
# Already a PIL Image object
return image_file

# Remove quotes if existed
if image_file.startswith('"') and image_file.endswith('"'):
image_file = image_file[1:-1]
if image_file.startswith("'") and image_file.endswith("'"):
image_file = image_file[1:-1]

if image_file.startswith("http://") or image_file.startswith("https://"):
# A URL file
response = requests.get(image_file)
Expand Down
81 changes: 49 additions & 32 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,35 @@
import json
import logging
import re
import warnings
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
import warnings

from openai import BadRequestError

from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
from ..formatting_utils import colored

from ..oai.client import OpenAIWrapper, ModelClient
from ..runtime_logging import logging_enabled, log_new_agent
from .._pydantic import model_dump
from ..cache.cache import Cache
from ..code_utils import (
UNKNOWN,
content_str,
check_can_use_docker_or_throw,
content_str,
decide_use_docker,
execute_code,
extract_code,
infer_lang,
)
from .utils import gather_usage_summary, consolidate_chat_info
from .chat import ChatResult, initiate_chats, a_initiate_chats


from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
from ..formatting_utils import colored
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from ..oai.client import ModelClient, OpenAIWrapper
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent, LLMAgent
from .._pydantic import model_dump
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary

__all__ = ("ConversableAgent",)

Expand Down Expand Up @@ -698,8 +696,8 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
id_key = "name"
else:
id_key = "tool_call_id"

func_print = f"***** Response from calling {message['role']} \"{message[id_key]}\" *****"
id = message.get(id_key, "No id found")
func_print = f"***** Response from calling {message['role']} ({id}) *****"
print(colored(func_print, "green"), flush=True)
print(message["content"], flush=True)
print(colored("*" * len(func_print), "green"), flush=True)
Expand All @@ -716,7 +714,7 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
if "function_call" in message and message["function_call"]:
function_call = dict(message["function_call"])
func_print = (
f"***** Suggested function Call: {function_call.get('name', '(No function name found)')} *****"
f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****"
)
print(colored(func_print, "green"), flush=True)
print(
Expand All @@ -728,9 +726,9 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
print(colored("*" * len(func_print), "green"), flush=True)
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
id = tool_call.get("id", "(No id found)")
id = tool_call.get("id", "No tool call id found")
function_call = dict(tool_call.get("function", {}))
func_print = f"***** Suggested tool Call ({id}): {function_call.get('name', '(No function name found)')} *****"
func_print = f"***** Suggested tool call ({id}): {function_call.get('name', '(No function name found)')} *****"
print(colored(func_print, "green"), flush=True)
print(
"Arguments: \n",
Expand Down Expand Up @@ -1311,6 +1309,12 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
)
for tool_call in extracted_response.get("tool_calls") or []:
tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
# Remove id and type if they are not present.
# This is to make the tool call object compatible with Mistral API.
if tool_call.get("id") is None:
tool_call.pop("id")
if tool_call.get("type") is None:
tool_call.pop("type")
return extracted_response

async def a_generate_oai_reply(
Expand Down Expand Up @@ -1527,7 +1531,6 @@ def generate_tool_calls_reply(
message = messages[-1]
tool_returns = []
for tool_call in message.get("tool_calls", []):
id = tool_call["id"]
function_call = tool_call.get("function", {})
func = self._function_map.get(function_call.get("name", None), None)
if inspect.iscoroutinefunction(func):
Expand All @@ -1545,13 +1548,24 @@ def generate_tool_calls_reply(
loop.close()
else:
_, func_return = self.execute_function(function_call)
tool_returns.append(
{
"tool_call_id": id,
content = func_return.get("content", "")
if content is None:
content = ""
tool_call_id = tool_call.get("id", None)
if tool_call_id is not None:
tool_call_response = {
"tool_call_id": tool_call_id,
"role": "tool",
"content": func_return.get("content", ""),
"content": content,
}
)
else:
# Do not include tool_call_id if it is not present.
# This is to make the tool call object compatible with Mistral API.
tool_call_response = {
"role": "tool",
"content": content,
}
tool_returns.append(tool_call_response)
if tool_returns:
return True, {
"role": "tool",
Expand Down Expand Up @@ -2603,22 +2617,25 @@ def process_last_received_message(self, messages):
return messages # Last message contains a context key.
if "content" not in last_message:
return messages # Last message has no content.
user_text = last_message["content"]
if not isinstance(user_text, str):
return messages # Last message content is not a string. TODO: Multimodal agents will use a dict here.
if user_text == "exit":

user_content = last_message["content"]
if not isinstance(user_content, str) and not isinstance(user_content, list):
# if the user_content is a string, it is for regular LLM
# if the user_content is a list, it should follow the multimodal LMM format.
return messages
if user_content == "exit":
return messages # Last message is an exit command.

# Call each hook (in order of registration) to process the user's message.
processed_user_text = user_text
processed_user_content = user_content
for hook in hook_list:
processed_user_text = hook(processed_user_text)
if processed_user_text == user_text:
processed_user_content = hook(processed_user_content)
if processed_user_content == user_content:
return messages # No hooks actually modified the user's message.

# Replace the last user message with the expanded one.
messages = messages.copy()
messages[-1]["content"] = processed_user_text
messages[-1]["content"] = processed_user_content
return messages

def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
Expand Down
Loading

0 comments on commit f95a678

Please sign in to comment.