148 changes: 100 additions & 48 deletions adalflow/adalflow/components/model_client/openai_client.py
@@ -106,6 +106,11 @@ class OpenAIClient(ModelClient):
Users can (1) simply use the ``Embedder`` and ``Generator`` components by passing ``OpenAIClient()`` as the model_client, or
(2) use this class as an example to create their own API client, or extend it (by copying and modifying the code) in their own project.
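For example, the intended "simple use" path might look like this (a minimal sketch; import paths inferred from the file layout in this diff, model name illustrative):

```python
# Minimal sketch: OpenAIClient plugged into Generator as the model_client.
# Import paths inferred from the repository layout shown in this diff.
from adalflow.core.generator import Generator
from adalflow.components.model_client.openai_client import OpenAIClient

generator = Generator(
    model_client=OpenAIClient(),
    model_kwargs={"model": "gpt-3.5-turbo"},  # any chat model; illustrative
)
response = generator.call(prompt_kwargs={"input_str": "What is an LLM?"})
```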

Args:
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion into a string. Defaults to None, in which case `get_first_message_content` is used.
input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".

Note:
We suggest that users avoid using `response_format` to enforce the output data type, and avoid `tools` and `tool_choice` in model_kwargs when calling the API.
We do not know how OpenAI does the formatting or what prompt they add.
@@ -120,14 +125,9 @@ class OpenAIClient(ModelClient):
- prompt: Text description of the image to generate
- size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
- quality: "standard" or "hd" (DALL-E 3 only)
- n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
- n: Number of images (1 for DALL-E 3, 1-10 for DALL-E 2)
- response_format: "url" or "b64_json"
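Concretely, a model_kwargs dict for a DALL-E 3 generation call might use the parameters above like this (a sketch; values illustrative):

```python
# Sketch of model_kwargs for DALL-E 3 generation, using the parameters above.
model_kwargs = {
    "model": "dall-e-3",
    "size": "1024x1024",       # or "1024x1792" / "1792x1024"
    "quality": "standard",     # "hd" is DALL-E 3 only
    "n": 1,                    # DALL-E 3 generates exactly one image per call
    "response_format": "url",  # or "b64_json"
}
```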

Args:
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
Default is `get_first_message_content`.

References:
- Embeddings models: https://platform.openai.com/docs/guides/embeddings
- Chat models: https://platform.openai.com/docs/guides/text-generation
@@ -146,6 +146,8 @@ def __init__(

Args:
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion into a string. Defaults to None.
input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
"""
super().__init__()
self._api_key = api_key
@@ -229,7 +231,7 @@ def convert_inputs_to_api_kwargs(
self,
input: Optional[Any] = None,
model_kwargs: Dict = {},
model_type: ModelType = ModelType.UNDEFINED,
model_type: ModelType = ModelType.UNDEFINED, # Now required in practice
[Review comment] Please default to LLM

) -> Dict:
r"""
Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
@@ -243,11 +245,24 @@
- images: Optional image source(s) as path, URL, or list of them
- detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto'
- model: The model to use (must support multimodal inputs if images are provided)
model_type: The type of model (EMBEDDER or LLM)
For image generation:
- model: "dall-e-3" or "dall-e-2"
- size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
- quality: "standard" or "hd" (DALL-E 3 only)
- n: Number of images (1 for DALL-E 3, 1-10 for DALL-E 2)
- response_format: "url" or "b64_json"
For image edits (DALL-E 2 only):
- image: Path to the input image
- mask: Path to the mask image
For variations (DALL-E 2 only):
- image: Path to the input image
model_type: The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Required.

Returns:
Dict: API-specific kwargs for the model call
"""
if model_type == ModelType.UNDEFINED:
raise ValueError("model_type must be specified")

final_model_kwargs = model_kwargs.copy()
if model_type == ModelType.EMBEDDER:
@@ -308,24 +323,43 @@ def convert_inputs_to_api_kwargs(
# Ensure model is specified
if "model" not in final_model_kwargs:
raise ValueError("model must be specified for image generation")
# Set defaults for DALL-E 3 if not specified
final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
final_model_kwargs["quality"] = final_model_kwargs.get(
"quality", "standard"
)
final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
final_model_kwargs["response_format"] = final_model_kwargs.get(
"response_format", "url"
)

# Handle image edits and variations
image = final_model_kwargs.get("image")
if isinstance(image, str) and os.path.isfile(image):
final_model_kwargs["image"] = self._encode_image(image)

mask = final_model_kwargs.get("mask")
if isinstance(mask, str) and os.path.isfile(mask):
final_model_kwargs["mask"] = self._encode_image(mask)
# Set defaults for image generation
if "operation" not in final_model_kwargs:
final_model_kwargs["operation"] = "generate" # Default operation

operation = final_model_kwargs["operation"]  # keep "operation" in kwargs; call()/acall() pop it for routing

if operation == "generate":
# Set defaults for DALL-E 3 if not specified
final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
final_model_kwargs["quality"] = final_model_kwargs.get("quality", "standard")
final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")

elif operation in ["edit", "variation"]:
if "model" not in final_model_kwargs or final_model_kwargs["model"] != "dall-e-2":
raise ValueError(f"{operation} operation is only available with DALL-E 2")

# Handle image input
image_path = final_model_kwargs.get("image")
if not image_path or not os.path.isfile(image_path):
raise ValueError(f"Valid image path must be provided for {operation}")
final_model_kwargs["image"] = open(image_path, "rb")

# Handle mask for edit operation
if operation == "edit":
mask_path = final_model_kwargs.get("mask")
if not mask_path or not os.path.isfile(mask_path):
raise ValueError("Valid mask path must be provided for edit operation")
final_model_kwargs["mask"] = open(mask_path, "rb")

# Set defaults
final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")

else:
raise ValueError(f"Invalid operation: {operation}")
else:
raise ValueError(f"model_type {model_type} is not supported")
return final_model_kwargs
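As a sketch of how this routing is driven (parameter names taken from the docstring above; the file paths are hypothetical):

```python
# Hypothetical call for a DALL-E 2 image edit; "photo.png" and "mask.png"
# must be real files, or convert_inputs_to_api_kwargs raises ValueError.
client = OpenAIClient()
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="Add a red hat to the person in the photo",
    model_kwargs={
        "model": "dall-e-2",   # edit/variation are DALL-E 2 only
        "operation": "edit",
        "image": "photo.png",
        "mask": "mask.png",
    },
    model_type=ModelType.IMAGE_GENERATION,
)
```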
@@ -361,6 +395,9 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
"""
api_kwargs combines the input and model_kwargs. Supports streaming calls.
"""
if model_type == ModelType.UNDEFINED:
raise ValueError("model_type must be specified")

log.info(f"api_kwargs: {api_kwargs}")
if model_type == ModelType.EMBEDDER:
return self.sync_client.embeddings.create(**api_kwargs)
@@ -371,18 +408,25 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
return self.sync_client.chat.completions.create(**api_kwargs)
return self.sync_client.chat.completions.create(**api_kwargs)
elif model_type == ModelType.IMAGE_GENERATION:
# Determine which image API to call based on the presence of image/mask
if "image" in api_kwargs:
if "mask" in api_kwargs:
# Image edit
operation = api_kwargs.pop("operation", "generate")

try:
if operation == "generate":
response = self.sync_client.images.generate(**api_kwargs)
elif operation == "edit":
response = self.sync_client.images.edit(**api_kwargs)
else:
# Image variation
elif operation == "variation":
[Review comment] create_variation
response = self.sync_client.images.create_variation(**api_kwargs)
else:
# Image generation
response = self.sync_client.images.generate(**api_kwargs)
return response.data
else:
raise ValueError(f"Invalid operation: {operation}")

return response.data
finally:
# Clean up file handles if they exist
if "image" in api_kwargs and hasattr(api_kwargs["image"], "close"):
api_kwargs["image"].close()
if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"):
api_kwargs["mask"].close()
else:
raise ValueError(f"model_type {model_type} is not supported")
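Putting the two methods together, a sync image-generation round trip might look like this (a sketch; assumes OPENAI_API_KEY is set and that convert_inputs_to_api_kwargs carries the prompt through to the API kwargs):

```python
# Hypothetical sync round trip: build api_kwargs, then dispatch by model_type.
client = OpenAIClient()
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="A watercolor fox in the snow",
    model_kwargs={"model": "dall-e-3"},  # size/quality/n/response_format default
    model_type=ModelType.IMAGE_GENERATION,
)
images = client.call(api_kwargs=api_kwargs, model_type=ModelType.IMAGE_GENERATION)
# images is response.data: a list of image objects exposing .url or .b64_json
```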

@@ -403,27 +447,35 @@ async def acall(
"""
api_kwargs combines the input and model_kwargs.
"""
if model_type == ModelType.UNDEFINED:
raise ValueError("model_type must be specified")

if self.async_client is None:
self.async_client = self.init_async_client()
if model_type == ModelType.EMBEDDER:
return await self.async_client.embeddings.create(**api_kwargs)
elif model_type == ModelType.LLM:
return await self.async_client.chat.completions.create(**api_kwargs)
elif model_type == ModelType.IMAGE_GENERATION:
# Determine which image API to call based on the presence of image/mask
if "image" in api_kwargs:
if "mask" in api_kwargs:
# Image edit
operation = api_kwargs.pop("operation", "generate")

try:
if operation == "generate":
response = await self.async_client.images.generate(**api_kwargs)
elif operation == "edit":
response = await self.async_client.images.edit(**api_kwargs)
elif operation == "variation":
[Review comment] create_variation

response = await self.async_client.images.create_variation(**api_kwargs)
else:
# Image variation
response = await self.async_client.images.create_variation(
**api_kwargs
)
else:
# Image generation
response = await self.async_client.images.generate(**api_kwargs)
return response.data
raise ValueError(f"Invalid operation: {operation}")

return response.data
finally:
# Clean up file handles if they exist
if "image" in api_kwargs and hasattr(api_kwargs["image"], "close"):
api_kwargs["image"].close()
if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"):
api_kwargs["mask"].close()
else:
raise ValueError(f"model_type {model_type} is not supported")
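The async path mirrors this; a sketch under the same assumptions:

```python
import asyncio

async def generate_image():
    client = OpenAIClient()
    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="A watercolor fox in the snow",
        model_kwargs={"model": "dall-e-3"},
        model_type=ModelType.IMAGE_GENERATION,
    )
    # acall lazily initializes the async client on first use.
    return await client.acall(
        api_kwargs=api_kwargs, model_type=ModelType.IMAGE_GENERATION
    )

images = asyncio.run(generate_image())
```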

20 changes: 6 additions & 14 deletions adalflow/adalflow/core/generator.py
@@ -70,21 +70,12 @@ class Generator(GradComponent, CachedEngine, CallbackManager):
template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT<core-default_prompt_template>`.
prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None.
output_processors (Optional[Component], optional): The output processors after model call. It can be a single component or a chained component via ``Sequential``. Defaults to None.
trainable_params (Optional[List[str]], optional): The list of trainable parameters. Defaults to [].

Note:
The output_processors will be applied to the string output of the model completion. And the result will be stored in the data field of the output.
And we encourage you to only use it to parse the response to data format you will use later.
name (Optional[str], optional): The name of the generator. Defaults to None.
cache_path (Optional[str], optional): The path to save the cache. Defaults to None.
use_cache (bool, optional): Whether to use cache. Defaults to False.
model_type (ModelType, optional): The type of the model. Defaults to ModelType.LLM.
"""

model_type: ModelType = ModelType.LLM
model_client: ModelClient # for better type checking

_use_cache: bool = False
_kwargs: Dict[str, Any] = (
{}
) # to create teacher generator from student TODO: might reassess this

def __init__(
self,
*,
@@ -100,6 +91,7 @@ def __init__(
# args for the cache
cache_path: Optional[str] = None,
use_cache: bool = False,
model_type: ModelType = ModelType.LLM, # Add model_type parameter with default
[Review comment] Don't update it here; let it be controlled in the model client

) -> None:
r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables:
- task_desc_str
@@ -121,7 +113,6 @@
template = template or DEFAULT_ADALFLOW_SYSTEM_PROMPT

# create the cache path and initialize the cache engine

self.set_cache_path(
cache_path, model_client, model_kwargs.get("model", "default")
)
@@ -133,6 +124,7 @@
CallbackManager.__init__(self)

self.name = name or self.__class__.__name__
self.model_type = model_type # Use the passed model_type instead of getting from client
[Review comment] delete this


self._init_prompt(template, prompt_kwargs)

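For reference, the constructor change this file proposes would be used like so (a sketch; note the reviewer asks for model_type to be controlled in the model client instead, so the parameter may not survive review):

```python
# Sketch of the proposed Generator constructor with an explicit model_type.
from adalflow.core.generator import Generator
from adalflow.core.types import ModelType
from adalflow.components.model_client.openai_client import OpenAIClient

generator = Generator(
    model_client=OpenAIClient(),
    model_kwargs={"model": "dall-e-3"},
    model_type=ModelType.IMAGE_GENERATION,  # overrides the ModelType.LLM default
)
```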
5 changes: 3 additions & 2 deletions adalflow/tests/test_generator.py
@@ -15,6 +15,7 @@
from adalflow.core.model_client import ModelClient
from adalflow.components.model_client.groq_client import GroqAPIClient
from adalflow.tracing import GeneratorStateLogger
from adalflow.core.types import ModelType


class TestGenerator(IsolatedAsyncioTestCase):
@@ -32,7 +33,7 @@ def setUp(self):
)
self.mock_api_client = mock_api_client

self.generator = Generator(model_client=mock_api_client)
self.generator = Generator(model_client=mock_api_client, model_type=ModelType.LLM)
self.save_dir = "./tests/log"
self.project_name = "TestGenerator"
self.filename = "prompt_logger_test.json"
@@ -182,7 +183,7 @@ def test_groq_client_call(self, mock_call):
template = "Hello, {{ input_str }}!"

# Initialize the Generator with the mocked client
generator = Generator(model_client=self.client, template=template)
generator = Generator(model_client=self.client, template=template, model_type=ModelType.LLM)

# Call the generator and get the output
output = generator.call(prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs)