Fix multimodal #319
base: main
Changes from all commits: 605f1f0, fd80974, 4145595, d121c18, 922681a, 8b6defa
@@ -106,6 +106,11 @@ class OpenAIClient(ModelClient):
     Users can (1) simply use the ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client,
     or (2) use this class as an example to create their own API client, or extend it (by copying and modifying the code) in their own project.

+    Args:
+        api_key (Optional[str], optional): OpenAI API key. Defaults to None.
+        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
+        input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
+
     Note:
         We suggest that users do not use `response_format` to enforce an output data type, or `tools` and `tool_choice` in model_kwargs when calling the API.
         We do not know how OpenAI does the formatting or what prompt they add.
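As a usage sketch of the arguments added above (import path assumed from the AdalFlow package layout):

    from adalflow.components.model_client import OpenAIClient

    # With api_key=None the client typically falls back to the OPENAI_API_KEY
    # environment variable; input_type="messages" passes chat-style message
    # lists through as-is, while "text" wraps a plain prompt.
    client = OpenAIClient(api_key=None, input_type="text")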
@@ -120,14 +125,9 @@ class OpenAIClient(ModelClient):
         - prompt: Text description of the image to generate
         - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
         - quality: "standard" or "hd" (DALL-E 3 only)
-        - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
+        - n: Number of images (1 for DALL-E 3, 1-10 for DALL-E 2)
         - response_format: "url" or "b64_json"

-    Args:
-        api_key (Optional[str], optional): OpenAI API key. Defaults to None.
-        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
-            Default is `get_first_message_content`.
-
     References:
         - Embeddings models: https://platform.openai.com/docs/guides/embeddings
         - Chat models: https://platform.openai.com/docs/guides/text-generation
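For illustration, the documented generation parameters map onto a model_kwargs dict such as (values taken from the options listed above):

    model_kwargs = {
        "model": "dall-e-3",
        "size": "1024x1024",
        "quality": "standard",
        "n": 1,
        "response_format": "url",
    }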
@@ -146,6 +146,8 @@ def __init__(

         Args:
             api_key (Optional[str], optional): OpenAI API key. Defaults to None.
             chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
+            input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
         """
         super().__init__()
         self._api_key = api_key
@@ -229,7 +231,7 @@ def convert_inputs_to_api_kwargs(
         self,
         input: Optional[Any] = None,
         model_kwargs: Dict = {},
-        model_type: ModelType = ModelType.UNDEFINED,
+        model_type: ModelType = ModelType.UNDEFINED,  # now required in practice
     ) -> Dict:
         r"""
         Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
@@ -243,11 +245,24 @@ def convert_inputs_to_api_kwargs(
                 - images: Optional image source(s) as path, URL, or list of them
                 - detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto'
                 - model: The model to use (must support multimodal inputs if images are provided)
-            model_type: The type of model (EMBEDDER or LLM)
+            For image generation:
+                - model: "dall-e-3" or "dall-e-2"
+                - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
+                - quality: "standard" or "hd" (DALL-E 3 only)
+                - n: Number of images (1 for DALL-E 3, 1-10 for DALL-E 2)
+                - response_format: "url" or "b64_json"
+            For image edits (DALL-E 2 only):
+                - image: Path to the input image
+                - mask: Path to the mask image
+            For variations (DALL-E 2 only):
+                - image: Path to the input image
+            model_type: The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Required.

         Returns:
             Dict: API-specific kwargs for the model call
         """
+        if model_type == ModelType.UNDEFINED:
+            raise ValueError("model_type must be specified")
+
         final_model_kwargs = model_kwargs.copy()
         if model_type == ModelType.EMBEDDER:
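A minimal sketch of calling this converter once model_type is mandatory (ModelType import path assumed):

    from adalflow.core.types import ModelType

    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="A watercolor fox in the snow",
        model_kwargs={"model": "dall-e-3"},
        model_type=ModelType.IMAGE_GENERATION,  # ModelType.UNDEFINED now raises ValueError
    )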
@@ -308,24 +323,43 @@ def convert_inputs_to_api_kwargs(
             # Ensure model is specified
             if "model" not in final_model_kwargs:
                 raise ValueError("model must be specified for image generation")
-            # Set defaults for DALL-E 3 if not specified
-            final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
-            final_model_kwargs["quality"] = final_model_kwargs.get(
-                "quality", "standard"
-            )
-            final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
-            final_model_kwargs["response_format"] = final_model_kwargs.get(
-                "response_format", "url"
-            )
-
-            # Handle image edits and variations
-            image = final_model_kwargs.get("image")
-            if isinstance(image, str) and os.path.isfile(image):
-                final_model_kwargs["image"] = self._encode_image(image)
-
-            mask = final_model_kwargs.get("mask")
-            if isinstance(mask, str) and os.path.isfile(mask):
-                final_model_kwargs["mask"] = self._encode_image(mask)
+            # Set defaults for image generation
+            if "operation" not in final_model_kwargs:
+                final_model_kwargs["operation"] = "generate"  # default operation
+
+            operation = final_model_kwargs.pop("operation")
+
+            if operation == "generate":
+                # Set defaults for DALL-E 3 if not specified
+                final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
+                final_model_kwargs["quality"] = final_model_kwargs.get("quality", "standard")
+                final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
+                final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")
+
+            elif operation in ["edit", "variation"]:
+                if "model" not in final_model_kwargs or final_model_kwargs["model"] != "dall-e-2":
+                    raise ValueError(f"{operation} operation is only available with DALL-E 2")
+
+                # Handle image input
+                image_path = final_model_kwargs.get("image")
+                if not image_path or not os.path.isfile(image_path):
+                    raise ValueError(f"Valid image path must be provided for {operation}")
+                final_model_kwargs["image"] = open(image_path, "rb")
+
+                # Handle mask for edit operation
+                if operation == "edit":
+                    mask_path = final_model_kwargs.get("mask")
+                    if not mask_path or not os.path.isfile(mask_path):
+                        raise ValueError("Valid mask path must be provided for edit operation")
+                    final_model_kwargs["mask"] = open(mask_path, "rb")
+
+                # Set defaults
+                final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
+                final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
+                final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")
+
+            else:
+                raise ValueError(f"Invalid operation: {operation}")
         else:
             raise ValueError(f"model_type {model_type} is not supported")
         return final_model_kwargs
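Under this operation-based scheme, an edit request could be assembled as follows (file paths are illustrative placeholders, and it is assumed that input is mapped to the prompt field as for generation):

    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="Add a red scarf to the cat",
        model_kwargs={
            "model": "dall-e-2",   # edits and variations require DALL-E 2
            "operation": "edit",
            "image": "cat.png",    # opened as a binary file handle by the converter
            "mask": "cat_mask.png",
        },
        model_type=ModelType.IMAGE_GENERATION,
    )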
@@ -361,6 +395,9 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
        """
        kwargs is the combined input and model_kwargs. Supports streaming calls.
        """
+       if model_type == ModelType.UNDEFINED:
+           raise ValueError("model_type must be specified")
+
        log.info(f"api_kwargs: {api_kwargs}")
        if model_type == ModelType.EMBEDDER:
            return self.sync_client.embeddings.create(**api_kwargs)
@@ -371,18 +408,25 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE | |
return self.sync_client.chat.completions.create(**api_kwargs) | ||
return self.sync_client.chat.completions.create(**api_kwargs) | ||
elif model_type == ModelType.IMAGE_GENERATION: | ||
# Determine which image API to call based on the presence of image/mask | ||
if "image" in api_kwargs: | ||
if "mask" in api_kwargs: | ||
# Image edit | ||
operation = api_kwargs.pop("operation", "generate") | ||
|
||
try: | ||
if operation == "generate": | ||
response = self.sync_client.images.generate(**api_kwargs) | ||
elif operation == "edit": | ||
response = self.sync_client.images.edit(**api_kwargs) | ||
else: | ||
# Image variation | ||
elif operation == "variation": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. create_variation |
||
response = self.sync_client.images.create_variation(**api_kwargs) | ||
else: | ||
# Image generation | ||
response = self.sync_client.images.generate(**api_kwargs) | ||
return response.data | ||
else: | ||
raise ValueError(f"Invalid operation: {operation}") | ||
|
||
return response.data | ||
finally: | ||
# Clean up file handles if they exist | ||
if "image" in api_kwargs and hasattr(api_kwargs["image"], "close"): | ||
api_kwargs["image"].close() | ||
if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"): | ||
api_kwargs["mask"].close() | ||
else: | ||
raise ValueError(f"model_type {model_type} is not supported") | ||
|
||
|
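A variation call would then flow through the new dispatch, for example (a sketch; the path is illustrative):

    kwargs = client.convert_inputs_to_api_kwargs(
        model_kwargs={
            "model": "dall-e-2",
            "operation": "variation",
            "image": "cat.png",
        },
        model_type=ModelType.IMAGE_GENERATION,
    )
    # The file handle opened during conversion is closed in the finally block.
    images = client.call(api_kwargs=kwargs, model_type=ModelType.IMAGE_GENERATION)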
@@ -403,27 +447,35 @@ async def acall(
        """
        kwargs is the combined input and model_kwargs
        """
+       if model_type == ModelType.UNDEFINED:
+           raise ValueError("model_type must be specified")
+
        if self.async_client is None:
            self.async_client = self.init_async_client()
        if model_type == ModelType.EMBEDDER:
            return await self.async_client.embeddings.create(**api_kwargs)
        elif model_type == ModelType.LLM:
            return await self.async_client.chat.completions.create(**api_kwargs)
        elif model_type == ModelType.IMAGE_GENERATION:
-           # Determine which image API to call based on the presence of image/mask
-           if "image" in api_kwargs:
-               if "mask" in api_kwargs:
-                   # Image edit
-                   response = await self.async_client.images.edit(**api_kwargs)
-               else:
-                   # Image variation
-                   response = await self.async_client.images.create_variation(
-                       **api_kwargs
-                   )
-           else:
-               # Image generation
-               response = await self.async_client.images.generate(**api_kwargs)
-           return response.data
+           operation = api_kwargs.pop("operation", "generate")
+
+           try:
+               if operation == "generate":
+                   response = await self.async_client.images.generate(**api_kwargs)
+               elif operation == "edit":
+                   response = await self.async_client.images.edit(**api_kwargs)
+               elif operation == "variation":
+                   response = await self.async_client.images.create_variation(**api_kwargs)

[Review comment] create_variation

+               else:
+                   raise ValueError(f"Invalid operation: {operation}")
+
+               return response.data
+           finally:
+               # Clean up file handles if they exist
+               if "image" in api_kwargs and hasattr(api_kwargs["image"], "close"):
+                   api_kwargs["image"].close()
+               if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"):
+                   api_kwargs["mask"].close()
        else:
            raise ValueError(f"model_type {model_type} is not supported")
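The async path mirrors the sync one; reusing the kwargs built above, it might be driven like this (a sketch):

    import asyncio

    async def main():
        images = await client.acall(
            api_kwargs=kwargs, model_type=ModelType.IMAGE_GENERATION
        )
        print(images[0].url)  # with response_format="url"

    asyncio.run(main())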
@@ -70,21 +70,12 @@ class Generator(GradComponent, CachedEngine, CallbackManager):
        template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT<core-default_prompt_template>`.
        prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None.
        output_processors (Optional[Component], optional): The output processors after the model call. It can be a single component or a chained component via ``Sequential``. Defaults to None.
-       trainable_params (Optional[List[str]], optional): The list of trainable parameters. Defaults to [].
-
-   Note:
-       The output_processors will be applied to the string output of the model completion, and the result will be stored in the data field of the output.
-       We encourage you to only use it to parse the response into the data format you will use later.
+       name (Optional[str], optional): The name of the generator. Defaults to None.
+       cache_path (Optional[str], optional): The path to save the cache. Defaults to None.
+       use_cache (bool, optional): Whether to use the cache. Defaults to False.
+       model_type (ModelType, optional): The type of the model. Defaults to ModelType.LLM.
    """

-   model_type: ModelType = ModelType.LLM
    model_client: ModelClient  # for better type checking

    _use_cache: bool = False
    _kwargs: Dict[str, Any] = {}  # to create a teacher generator from a student; TODO: might reassess this

    def __init__(
        self,
        *,
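For context, a Generator built against this docstring might look like the following (import paths assumed; model name illustrative):

    from adalflow.core import Generator
    from adalflow.components.model_client import OpenAIClient

    gen = Generator(
        model_client=OpenAIClient(),
        model_kwargs={"model": "gpt-3.5-turbo"},
        use_cache=False,
    )
    # model_type stays ModelType.LLM by default, per the review feedback below.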
@@ -100,6 +91,7 @@ def __init__(
        # args for the cache
        cache_path: Optional[str] = None,
        use_cache: bool = False,
+       model_type: ModelType = ModelType.LLM,  # add model_type parameter with a default

[Review comment] Don't update it here; let it be controlled in the model client.

    ) -> None:
        r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables:
        - task_desc_str
@@ -121,7 +113,6 @@ def __init__(
        template = template or DEFAULT_ADALFLOW_SYSTEM_PROMPT

        # create the cache path and initialize the cache engine

        self.set_cache_path(
            cache_path, model_client, model_kwargs.get("model", "default")
        )
@@ -133,6 +124,7 @@ def __init__(
        CallbackManager.__init__(self)

        self.name = name or self.__class__.__name__
+       self.model_type = model_type  # use the passed model_type instead of getting it from the client

[Review comment] Delete this.

        self._init_prompt(template, prompt_kwargs)
[Review comment] Please default to LLM.