From e185097dbf1927cd90e652ed164a6ae6f2d733e3 Mon Sep 17 00:00:00 2001 From: TransformerOptimus Date: Wed, 21 Jun 2023 10:21:39 +0530 Subject: [PATCH 1/5] adding email tests and fixing json cleaner test --- superagi/tools/email/send_email.py | 2 +- tests/helper/test_json_cleaner.py | 2 +- tests/tools/email/__init__.py | 0 tests/tools/email/test_send_email.py | 70 ++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 tests/tools/email/__init__.py create mode 100644 tests/tools/email/test_send_email.py diff --git a/superagi/tools/email/send_email.py b/superagi/tools/email/send_email.py index 2d6641cf4..d475e2a71 100644 --- a/superagi/tools/email/send_email.py +++ b/superagi/tools/email/send_email.py @@ -57,7 +57,7 @@ def _execute(self, to: str, subject: str, body: str) -> str: body += f"\n{signature}" message.set_content(body) draft_folder = get_config('EMAIL_DRAFT_MODE_WITH_FOLDER') - send_to_draft = draft_folder is not None or draft_folder != "YOUR_DRAFTS_FOLDER" + send_to_draft = draft_folder is not None and draft_folder != "YOUR_DRAFTS_FOLDER" if message["To"] == "example@example.com" or send_to_draft: conn = ImapEmail().imap_open(draft_folder, email_sender, email_password) conn.append( diff --git a/tests/helper/test_json_cleaner.py b/tests/helper/test_json_cleaner.py index be64eaa6f..8579a9900 100644 --- a/tests/helper/test_json_cleaner.py +++ b/tests/helper/test_json_cleaner.py @@ -40,4 +40,4 @@ def test_clean_newline_spaces_json(): def test_has_newline_in_string(): test_str = r'{key: "value\n"\n \n}' result = JsonCleaner.check_and_clean_json(test_str) - assert result == '{key: "value\\n"}' + assert result == '{key: "value"}' diff --git a/tests/tools/email/__init__.py b/tests/tools/email/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tools/email/test_send_email.py b/tests/tools/email/test_send_email.py new file mode 100644 index 000000000..16d802477 --- /dev/null +++ b/tests/tools/email/test_send_email.py @@ -0,0 +1,70 @@ +from unittest.mock import MagicMock + +import pytest +import imaplib +import time +from email.message import EmailMessage + +from superagi.config.config import get_config +from superagi.helper.imap_email import ImapEmail +from superagi.tools.email import send_email +from superagi.tools.email.send_email import SendEmailTool + +def test_send_to_draft(mocker): + + mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config', autospec=True) + mock_get_config.side_effect = [ + 'test_sender@test.com', # EMAIL_ADDRESS + 'password', # EMAIL_PASSWORD + 'Test Signature', # EMAIL_SIGNATURE + "Draft", # EMAIL_DRAFT_MODE_WITH_FOLDER + 'smtp_host', # EMAIL_SMTP_HOST + 'smtp_port' # EMAIL_SMTP_PORT + ] + + + # Mocking the ImapEmail call + mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail') + mock_imap_instance = mock_imap_email.return_value.imap_open.return_value + + # Mocking the SMTP call + mock_smtp = mocker.patch('smtplib.SMTP') + smtp_instance = mock_smtp.return_value + + # Test the SendEmailTool's execute method + send_email_tool = SendEmailTool() + result = send_email_tool._execute('mukunda@contlo.com', 'Test Subject', 'Test Body') + + # Assert the return value + assert result == 'Email went to Draft' + +def test_send_to_mailbox(mocker): + # Mocking the get_config calls + mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config') + mock_get_config.side_effect = [ + 'test_sender@test.com', # EMAIL_ADDRESS + 'password', # EMAIL_PASSWORD + 'Test Signature', # EMAIL_SIGNATURE + "YOUR_DRAFTS_FOLDER", # EMAIL_DRAFT_MODE_WITH_FOLDER + 'smtp_host', # EMAIL_SMTP_HOST + 'smtp_port' # EMAIL_SMTP_PORT + ] + + # mock_get_config.return_value = 'True' + # Mocking the ImapEmail call + mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail') + mock_imap_instance = mock_imap_email.return_value.imap_open.return_value + + # Mocking the SMTP call + mock_smtp = mocker.patch('smtplib.SMTP') + smtp_instance = mock_smtp.return_value + + # Test the SendEmailTool's execute method + send_email_tool = SendEmailTool() + result = send_email_tool._execute('test_receiver@test.com', 'Test Subject', 'Test Body') + + # Assert that the ImapEmail was not called (no draft mode) + mock_imap_email.assert_not_called() + + # Assert the return value + assert result == 'Email was sent to test_receiver@test.com' \ No newline at end of file From 883e020f870359cdbaf79e9e5e5cd4b17baf42fa Mon Sep 17 00:00:00 2001 From: TransformerOptimus Date: Wed, 21 Jun 2023 12:08:03 +0530 Subject: [PATCH 2/5] refactoring resource manager --- superagi/helper/resource_helper.py | 39 +++++----- superagi/jobs/agent_executor.py | 8 +- .../resource_manager}/__init__.py | 0 superagi/resource_manager/manager.py | 52 +++++++++++++ superagi/tools/file/append_file.py | 9 +-- superagi/tools/file/delete_file.py | 9 +-- superagi/tools/file/write_file.py | 35 +-------- .../tools/image_generation/dalle_image_gen.py | 55 ++++---------- .../stable_diffusion_image_gen.py | 75 ++++++------------- superagi/tools/thinking/tools.py | 8 +- .../__init__.py | 0 .../vector_store}/__init__.py | 0 .../vector_store/test_weaviate.py | 0 tests/{tools/email => unit_tests}/__init__.py | 0 tests/unit_tests/agent/__init__.py | 0 .../{ => unit_tests}/agent/test_task_queue.py | 0 .../unit_tests/agent_permissions/__init__.py | 0 ...est_check_permission_in_restricted_mode.py | 0 .../test_handle_wait_for_permission.py | 0 tests/unit_tests/helper/__init__.py | 0 .../helper/test_github_helper.py | 0 .../helper/test_json_cleaner.py | 0 .../unit_tests/helper/test_resource_helper.py | 33 ++++++++ tests/unit_tests/resource_manager/__init__.py | 0 .../resource_manager/test_resource_manager.py | 37 +++++++++ tests/unit_tests/tools/email/__init__.py | 0 .../tools/email/test_send_email.py | 0 .../{ => unit_tests}/tools/image_gen_test.py | 0 .../tools/stable_diffusion_image_gen_test.py | 3 +- 29 files changed, 195 insertions(+), 168 deletions(-) rename {tests/agent => superagi/resource_manager}/__init__.py (100%) create mode 100644 superagi/resource_manager/manager.py rename tests/{agent_permissions => integration_tests}/__init__.py (100%) rename tests/{helper => integration_tests/vector_store}/__init__.py (100%) rename tests/{ => integration_tests}/vector_store/test_weaviate.py (100%) rename tests/{tools/email => unit_tests}/__init__.py (100%) create mode 100644 tests/unit_tests/agent/__init__.py rename tests/{ => unit_tests}/agent/test_task_queue.py (100%) create mode 100644 tests/unit_tests/agent_permissions/__init__.py rename tests/{ => unit_tests}/agent_permissions/test_check_permission_in_restricted_mode.py (100%) rename tests/{ => unit_tests}/agent_permissions/test_handle_wait_for_permission.py (100%) create mode 100644 tests/unit_tests/helper/__init__.py rename tests/{ => unit_tests}/helper/test_github_helper.py (100%) rename tests/{ => unit_tests}/helper/test_json_cleaner.py (100%) create mode 100644 tests/unit_tests/helper/test_resource_helper.py create mode 100644 tests/unit_tests/resource_manager/__init__.py create mode 100644 tests/unit_tests/resource_manager/test_resource_manager.py create mode 100644 tests/unit_tests/tools/email/__init__.py rename tests/{ => unit_tests}/tools/email/test_send_email.py (100%) rename tests/{ => unit_tests}/tools/image_gen_test.py (100%) rename tests/{ => unit_tests}/tools/stable_diffusion_image_gen_test.py (98%) diff --git a/superagi/helper/resource_helper.py b/superagi/helper/resource_helper.py index 025e2e3b2..de7e62b54 100644 --- a/superagi/helper/resource_helper.py +++ b/superagi/helper/resource_helper.py @@ -8,14 +8,13 @@ class ResourceHelper: @staticmethod - def make_written_file_resource(file_name: str, agent_id: int, file, channel): + def make_written_file_resource(file_name: str, agent_id: int, channel: str): """ Function to create a Resource object for a written file. Args: file_name (str): The name of the file. agent_id (int): The ID of the agent. - file (FileStorage): The file. channel (str): The channel of the file. Returns: @@ -32,25 +31,14 @@ def make_written_file_resource(file_name: str, agent_id: int, file, channel): else: file_type = "application/misc" - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - - if root_dir is not None: - root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + file_name - else: - final_path = os.getcwd() + "/" + file_name - + final_path = ResourceHelper.get_resource_path(file_name) file_size = os.path.getsize(final_path) if storage_type == "S3": file_name_parts = file_name.split('.') - file_name = file_name_parts[0] + '_' + str(datetime.datetime.now()).replace(' ', '').replace('.', '').replace( - ':', '') + '.' + file_name_parts[1] - if channel == "INPUT": - path = 'input' - else: - path = 'output' + file_name = file_name_parts[0] + '_' + str(datetime.datetime.now()).replace(' ', '') \ + .replace('.', '').replace(':', '') + '.' + file_name_parts[1] + path = 'input' if (channel == "INPUT") else 'output' logger.info(path + "/" + file_name) resource = Resource(name=file_name, path=path + "/" + file_name, storage_type=storage_type, size=file_size, @@ -58,3 +46,20 @@ def make_written_file_resource(file_name: str, agent_id: int, file, channel): channel="OUTPUT", agent_id=agent_id) return resource + + @staticmethod + def get_resource_path(file_name: str): + """Get final path of the resource. + + Args: + file_name (str): The name of the file. + """ + root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') + + if root_dir is not None: + root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir + root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" + final_path = root_dir + file_name + else: + final_path = os.getcwd() + "/" + file_name + return final_path diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index 636be0aea..a1aa835d0 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -17,6 +17,7 @@ from superagi.models.organisation import Organisation from superagi.models.project import Project from superagi.models.tool import Tool +from superagi.resource_manager.manager import ResourceManager from superagi.tools.thinking.tools import ThinkingTool from superagi.vector_store.embedding.openai import OpenAiEmbedding from superagi.vector_store.vector_factory import VectorFactory @@ -164,7 +165,7 @@ def execute_next_action(self, agent_execution_id): print(user_tools) tools = self.set_default_params_tools(tools, parsed_config, agent_execution.agent_id, - model_api_key=model_api_key) + model_api_key=model_api_key, session=session) @@ -205,7 +206,7 @@ def execute_next_action(self, agent_execution_id): # finally: engine.dispose() - def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key): + def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key, session): """ Set the default parameters for the tools. @@ -232,6 +233,9 @@ def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key tool.image_llm = OpenAi(model=parsed_config["model"], api_key=model_api_key) if hasattr(tool, 'agent_id'): tool.agent_id = agent_id + if hasattr(tool, 'resource_manager'): + tool.resource_manager = ResourceManager(session=session) + new_tools.append(tool) return tools diff --git a/tests/agent/__init__.py b/superagi/resource_manager/__init__.py similarity index 100% rename from tests/agent/__init__.py rename to superagi/resource_manager/__init__.py diff --git a/superagi/resource_manager/manager.py b/superagi/resource_manager/manager.py new file mode 100644 index 000000000..d46b4eb39 --- /dev/null +++ b/superagi/resource_manager/manager.py @@ -0,0 +1,52 @@ +from sqlalchemy.orm import Session + +from superagi.helper.resource_helper import ResourceHelper +from superagi.helper.s3_helper import S3Helper +from superagi.lib.logger import logger + + +class ResourceManager: + def __init__(self, session: Session): + self.session = session + + def write_binary_file(self, file_name: str, data): + final_path = ResourceHelper.get_resource_path(file_name) + + try: + with open(final_path, mode="wb") as img: + img.write(data) + img.close() + with open(final_path, 'rb') as img: + resource = ResourceHelper.make_written_file_resource(file_name=file_name, + agent_id=self.agent_id, channel="OUTPUT") + if resource is not None: + self.session.add(resource) + self.session.commit() + self.session.flush() + if resource.storage_type == "S3": + s3_helper = S3Helper() + s3_helper.upload_file(img, path=resource.path) + logger.info(f"Binary {file_name} saved successfully") + except Exception as err: + return f"Error: {err}" + + def write_file(self, file_name: str, content): + final_path = ResourceHelper.get_resource_path(file_name) + + try: + with open(final_path, mode="w") as file: + file.write(content) + file.close() + with open(final_path, 'rb') as img: + resource = ResourceHelper.make_written_file_resource(file_name=file_name, + agent_id=self.agent_id, channel="OUTPUT") + if resource is not None: + self.session.add(resource) + self.session.commit() + self.session.flush() + if resource.storage_type == "S3": + s3_helper = S3Helper() + s3_helper.upload_file(img, path=resource.path) + logger.info(f"{file_name} saved successfully") + except Exception as err: + return f"Error: {err}" \ No newline at end of file diff --git a/superagi/tools/file/append_file.py b/superagi/tools/file/append_file.py index b2f62d30a..72f3fae5f 100644 --- a/superagi/tools/file/append_file.py +++ b/superagi/tools/file/append_file.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field from superagi.config.config import get_config +from superagi.helper.resource_helper import ResourceHelper from superagi.tools.base_tool import BaseTool @@ -38,13 +39,7 @@ def _execute(self, file_name: str, content: str): Returns: file written to successfully. or error message. """ - final_path = file_name - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - if root_dir is not None: - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + file_name - else: - final_path = os.getcwd() + "/" + file_name + final_path = ResourceHelper.get_resource_path(file_name) try: directory = os.path.dirname(final_path) os.makedirs(directory, exist_ok=True) diff --git a/superagi/tools/file/delete_file.py b/superagi/tools/file/delete_file.py index cba875cb5..3917f0a1a 100644 --- a/superagi/tools/file/delete_file.py +++ b/superagi/tools/file/delete_file.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field +from superagi.helper.resource_helper import ResourceHelper from superagi.tools.base_tool import BaseTool from superagi.config.config import get_config @@ -36,13 +37,7 @@ def _execute(self, file_name: str, content: str): Returns: file deleted successfully. or error message. """ - final_path = file_name - root_dir = get_config('RESOURCES_INPUT_ROOT_DIR') - if root_dir is not None: - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + file_name - else: - final_path = os.getcwd() + "/" + file_name + final_path = ResourceHelper.get_resource_path(file_name) try: os.remove(final_path) return "File deleted successfully." diff --git a/superagi/tools/file/write_file.py b/superagi/tools/file/write_file.py index d11247e06..233cea4e1 100644 --- a/superagi/tools/file/write_file.py +++ b/superagi/tools/file/write_file.py @@ -1,6 +1,8 @@ import os from typing import Type from pydantic import BaseModel, Field + +from superagi.resource_manager.manager import ResourceManager from superagi.tools.base_tool import BaseTool from superagi.config.config import get_config from sqlalchemy.orm import sessionmaker @@ -32,6 +34,7 @@ class WriteFileTool(BaseTool): args_schema: Type[BaseModel] = WriteFileInput description: str = "Writes text to a file" agent_id: int = None + resource_manager: ResourceManager = None def _execute(self, file_name: str, content: str): """ @@ -44,35 +47,5 @@ def _execute(self, file_name: str, content: str): Returns: file written to successfully. or error message. """ - engine = connect_db() - Session = sessionmaker(bind=engine) - session = Session() - - final_path = file_name - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - if root_dir is not None: - root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + file_name - else: - final_path = os.getcwd() + "/" + file_name + self.resource_manager.write_file(file_name, content) - try: - with open(final_path, 'w', encoding="utf-8") as file: - file.write(content) - file.close() - with open(final_path, 'rb') as file: - resource = ResourceHelper.make_written_file_resource(file_name=file_name, - agent_id=self.agent_id,file=file,channel="OUTPUT") - if resource is not None: - session.add(resource) - session.commit() - session.flush() - if resource.storage_type == "S3": - s3_helper = S3Helper() - s3_helper.upload_file(file, path=resource.path) - logger.info("Resource Uploaded to S3!") - session.close() - return f"File written to successfully - {file_name}" - except Exception as err: - return f"Error: {err}" diff --git a/superagi/tools/image_generation/dalle_image_gen.py b/superagi/tools/image_generation/dalle_image_gen.py index cd80847f1..19962ddd4 100644 --- a/superagi/tools/image_generation/dalle_image_gen.py +++ b/superagi/tools/image_generation/dalle_image_gen.py @@ -1,16 +1,11 @@ from typing import Type, Optional + +import requests from pydantic import BaseModel, Field + from superagi.llms.base_llm import BaseLlm +from superagi.resource_manager.manager import ResourceManager from superagi.tools.base_tool import BaseTool -from superagi.config.config import get_config -import os -import requests -from superagi.models.db import connect_db -from superagi.helper.resource_helper import ResourceHelper -from superagi.helper.s3_helper import S3Helper -from sqlalchemy.orm import sessionmaker -from superagi.lib.logger import logger - class ImageGenInput(BaseModel): @@ -25,15 +20,19 @@ class ImageGenTool(BaseTool): Dalle Image Generation tool Attributes: - name : The name. - description : The description. - args_schema : The args schema. + name : Name of the tool + description : The description + args_schema : The args schema + llm : The llm + agent_id : The agent id + resource_manager : Manages the file resources """ - name: str = "Dalle Image Generation" + name: str = "DalleImageGeneration" args_schema: Type[BaseModel] = ImageGenInput description: str = "Generate Images using Dalle" llm: Optional[BaseLlm] = None agent_id: int = None + resource_manager: ResourceManager = None class Config: arbitrary_types_allowed = True @@ -51,9 +50,6 @@ def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2) Returns: Image generated successfully. or error message. """ - engine = connect_db() - Session = sessionmaker(bind=engine) - session = Session() if size not in [256, 512, 1024]: size = min([256, 512, 1024], key=lambda x: abs(x - size)) response = self.llm.generate_image(prompt, size, num) @@ -61,32 +57,7 @@ def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2) response = response['_previous']['data'] for i in range(num): image = image_name[i] - final_path = image - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - if root_dir is not None: - root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + image - else: - final_path = os.getcwd() + "/" + image url = response[i]['url'] data = requests.get(url).content - try: - with open(final_path, mode="wb") as img: - img.write(data) - with open(final_path, 'rb') as img: - resource = ResourceHelper.make_written_file_resource(file_name=image_name[i], - agent_id=self.agent_id, file=img,channel="OUTPUT") - if resource is not None: - session.add(resource) - session.commit() - session.flush() - if resource.storage_type == "S3": - s3_helper = S3Helper() - s3_helper.upload_file(img, path=resource.path) - logger.info(f"Image {image} saved successfully") - except Exception as err: - session.close() - return f"Error: {err}" - session.close() + self.resource_manager.write_binary_file(image, data) return "Images downloaded successfully" diff --git a/superagi/tools/image_generation/stable_diffusion_image_gen.py b/superagi/tools/image_generation/stable_diffusion_image_gen.py index be16014ad..f6f7561e8 100644 --- a/superagi/tools/image_generation/stable_diffusion_image_gen.py +++ b/superagi/tools/image_generation/stable_diffusion_image_gen.py @@ -1,17 +1,12 @@ -from typing import Type, Optional +import base64 +from typing import Type + +import requests from pydantic import BaseModel, Field -from superagi.tools.base_tool import BaseTool + from superagi.config.config import get_config -import os -from PIL import Image -from io import BytesIO -import requests -import base64 -from superagi.models.db import connect_db -from superagi.helper.resource_helper import ResourceHelper -from superagi.helper.s3_helper import S3Helper -from sqlalchemy.orm import sessionmaker -from superagi.lib.logger import logger +from superagi.resource_manager.manager import ResourceManager +from superagi.tools.base_tool import BaseTool class StableDiffusionImageGenInput(BaseModel): @@ -25,17 +20,24 @@ class StableDiffusionImageGenInput(BaseModel): class StableDiffusionImageGenTool(BaseTool): + """ + Stable diffusion Image Generation tool + + Attributes: + name : Name of the tool + description : The description + args_schema : The args schema + agent_id : The agent id + resource_manager : Manages the file resources + """ name: str = "Stable Diffusion Image Generation" args_schema: Type[BaseModel] = StableDiffusionImageGenInput description: str = "Generate Images using Stable Diffusion" agent_id: int = None + resource_manager: ResourceManager = None def _execute(self, prompt: str, image_name: list, width: int = 512, height: int = 512, num: int = 2, steps: int = 50): - engine = connect_db() - Session = sessionmaker(bind=engine) - session = Session() - api_key = get_config("STABILITY_API_KEY") if api_key is None: @@ -56,23 +58,12 @@ def _execute(self, prompt: str, image_name: list, width: int = 512, height: int for i in range(num): image_base64 = base64_strings[i] img_data = base64.b64decode(image_base64) - final_img = Image.open(BytesIO(img_data)) - image_format = final_img.format + # final_img = Image.open(BytesIO(img_data)) + # image_format = final_img.format image = image_name[i] - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') + self.resource_manager.write_binary_file(image_name[i], img_data) - final_path = self.build_file_path(image, root_dir) - - try: - self.upload_to_s3(final_img, final_path, image_format, image_name[i], session) - - logger.info(f"Image {image} saved successfully") - except Exception as err: - session.close() - print(f"Error in _execute: {err}") - return f"Error: {err}" - session.close() return "Images downloaded and saved successfully" def call_stable_diffusion(self, api_key, width, height, num, prompt, steps): @@ -102,27 +93,3 @@ def call_stable_diffusion(self, api_key, width, height, num, prompt, steps): }, ) return response - - def upload_to_s3(self, final_img, final_path, image_format, file_name, session): - with open(final_path, mode="wb") as img: - final_img.save(img, format=image_format) - with open(final_path, 'rb') as img: - resource = ResourceHelper.make_written_file_resource(file_name=file_name, - agent_id=self.agent_id, file=img, channel="OUTPUT") - logger.info(resource) - if resource is not None: - session.add(resource) - session.commit() - session.flush() - if resource.storage_type == "S3": - s3_helper = S3Helper() - s3_helper.upload_file(img, path=resource.path) - - def build_file_path(self, image, root_dir): - if root_dir is not None: - root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + image - else: - final_path = os.getcwd() + "/" + image - return final_path diff --git a/superagi/tools/thinking/tools.py b/superagi/tools/thinking/tools.py index 50c4c699a..7e4c68eee 100644 --- a/superagi/tools/thinking/tools.py +++ b/superagi/tools/thinking/tools.py @@ -1,15 +1,11 @@ -import os -import openai from typing import Type, Optional, List from pydantic import BaseModel, Field from superagi.agent.agent_prompt_builder import AgentPromptBuilder -from superagi.tools.base_tool import BaseTool -from superagi.config.config import get_config -from superagi.llms.base_llm import BaseLlm -from pydantic import BaseModel, Field, PrivateAttr from superagi.lib.logger import logger +from superagi.llms.base_llm import BaseLlm +from superagi.tools.base_tool import BaseTool class ThinkingSchema(BaseModel): diff --git a/tests/agent_permissions/__init__.py b/tests/integration_tests/__init__.py similarity index 100% rename from tests/agent_permissions/__init__.py rename to tests/integration_tests/__init__.py diff --git a/tests/helper/__init__.py b/tests/integration_tests/vector_store/__init__.py similarity index 100% rename from tests/helper/__init__.py rename to tests/integration_tests/vector_store/__init__.py diff --git a/tests/vector_store/test_weaviate.py b/tests/integration_tests/vector_store/test_weaviate.py similarity index 100% rename from tests/vector_store/test_weaviate.py rename to tests/integration_tests/vector_store/test_weaviate.py diff --git a/tests/tools/email/__init__.py b/tests/unit_tests/__init__.py similarity index 100% rename from tests/tools/email/__init__.py rename to tests/unit_tests/__init__.py diff --git a/tests/unit_tests/agent/__init__.py b/tests/unit_tests/agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/agent/test_task_queue.py b/tests/unit_tests/agent/test_task_queue.py similarity index 100% rename from tests/agent/test_task_queue.py rename to tests/unit_tests/agent/test_task_queue.py diff --git a/tests/unit_tests/agent_permissions/__init__.py b/tests/unit_tests/agent_permissions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/agent_permissions/test_check_permission_in_restricted_mode.py b/tests/unit_tests/agent_permissions/test_check_permission_in_restricted_mode.py similarity index 100% rename from tests/agent_permissions/test_check_permission_in_restricted_mode.py rename to tests/unit_tests/agent_permissions/test_check_permission_in_restricted_mode.py diff --git a/tests/agent_permissions/test_handle_wait_for_permission.py b/tests/unit_tests/agent_permissions/test_handle_wait_for_permission.py similarity index 100% rename from tests/agent_permissions/test_handle_wait_for_permission.py rename to tests/unit_tests/agent_permissions/test_handle_wait_for_permission.py diff --git a/tests/unit_tests/helper/__init__.py b/tests/unit_tests/helper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/helper/test_github_helper.py b/tests/unit_tests/helper/test_github_helper.py similarity index 100% rename from tests/helper/test_github_helper.py rename to tests/unit_tests/helper/test_github_helper.py diff --git a/tests/helper/test_json_cleaner.py b/tests/unit_tests/helper/test_json_cleaner.py similarity index 100% rename from tests/helper/test_json_cleaner.py rename to tests/unit_tests/helper/test_json_cleaner.py diff --git a/tests/unit_tests/helper/test_resource_helper.py b/tests/unit_tests/helper/test_resource_helper.py new file mode 100644 index 000000000..f9ad6aa23 --- /dev/null +++ b/tests/unit_tests/helper/test_resource_helper.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import patch +from superagi.helper.resource_helper import ResourceHelper # Replace with actual import +@pytest.fixture +def resource_helper(): + with patch('superagi.helper.resource_helper.get_config') as get_config_mock, \ + patch('superagi.helper.resource_helper.os.getcwd') as get_cwd_mock, \ + patch('superagi.helper.resource_helper.os.path.getsize') as getsize_mock: + + get_config_mock.return_value = '/fake/path' + get_cwd_mock.return_value = '/fake/cwd' + getsize_mock.return_value = 100 + + yield + +def test_make_written_file_resource(resource_helper): + file_name = 'test.png' + agent_id = 1 + channel = 'INPUT' + result = ResourceHelper.make_written_file_resource(file_name, agent_id, channel) + + assert result.name == file_name + assert result.path == '/fake/path/' + file_name + assert result.size == 100 + assert result.type == 'image/png' + assert result.channel == 'OUTPUT' + assert result.agent_id == agent_id + +def test_get_resource_path(resource_helper): + file_name = 'test.png' + result = ResourceHelper.get_resource_path(file_name) + + assert result == '/fake/path/test.png' diff --git a/tests/unit_tests/resource_manager/__init__.py b/tests/unit_tests/resource_manager/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_tests/resource_manager/test_resource_manager.py b/tests/unit_tests/resource_manager/test_resource_manager.py new file mode 100644 index 000000000..e1f6e6f0e --- /dev/null +++ b/tests/unit_tests/resource_manager/test_resource_manager.py @@ -0,0 +1,37 @@ +import pytest +from unittest.mock import Mock, patch +from superagi.models.resource import Resource +from superagi.helper.resource_helper import ResourceHelper +from superagi.helper.s3_helper import S3Helper +from superagi.lib.logger import logger + +from superagi.resource_manager.manager import ResourceManager + +@pytest.fixture +def resource_manager(): + session_mock = Mock() + resource_manager = ResourceManager(session_mock) + resource_manager.agent_id = 1 # replace with actual value + return resource_manager + + +def test_write_binary_file(resource_manager): + with patch.object(ResourceHelper, 'get_resource_path', return_value='test_path'), \ + patch.object(ResourceHelper, 'make_written_file_resource', + return_value=Resource(name='test.png', storage_type='S3')), \ + patch.object(S3Helper, 'upload_file'), \ + patch.object(logger, 'info') as logger_mock: + result = resource_manager.write_binary_file('test.png', b'data') + assert result is None + logger_mock.assert_called_once_with("Binary test.png saved successfully") + + +def test_write_file(resource_manager): + with patch.object(ResourceHelper, 'get_resource_path', return_value='test_path'), \ + patch.object(ResourceHelper, 'make_written_file_resource', + return_value=Resource(name='test.txt', storage_type='S3')), \ + patch.object(S3Helper, 'upload_file'), \ + patch.object(logger, 'info') as logger_mock: + result = resource_manager.write_file('test.txt', 'content') + assert result is None + logger_mock.assert_called_once_with("test.txt saved successfully") diff --git a/tests/unit_tests/tools/email/__init__.py b/tests/unit_tests/tools/email/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tools/email/test_send_email.py b/tests/unit_tests/tools/email/test_send_email.py similarity index 100% rename from tests/tools/email/test_send_email.py rename to tests/unit_tests/tools/email/test_send_email.py diff --git a/tests/tools/image_gen_test.py b/tests/unit_tests/tools/image_gen_test.py similarity index 100% rename from tests/tools/image_gen_test.py rename to tests/unit_tests/tools/image_gen_test.py diff --git a/tests/tools/stable_diffusion_image_gen_test.py b/tests/unit_tests/tools/stable_diffusion_image_gen_test.py similarity index 98% rename from tests/tools/stable_diffusion_image_gen_test.py rename to tests/unit_tests/tools/stable_diffusion_image_gen_test.py index ae2ebc673..bff79bf96 100644 --- a/tests/tools/stable_diffusion_image_gen_test.py +++ b/tests/unit_tests/tools/stable_diffusion_image_gen_test.py @@ -18,8 +18,7 @@ def test_stable_diffusion_image_gen_tool_execute(self, mock_get_config, mock_req tool = StableDiffusionImageGenTool() prompt = 'Artificial Intelligence' image_names = ['image1.png', 'image2.png'] - height = 512 - width = 512 + height, width = 512, 512 num = 2 steps = 50 From 207686299f5371e460567883df55bd644de7b26e Mon Sep 17 00:00:00 2001 From: TransformerOptimus Date: Wed, 21 Jun 2023 15:09:00 +0530 Subject: [PATCH 3/5] minor changes --- superagi/tools/file/write_file.py | 4 ++-- superagi/tools/image_generation/dalle_image_gen.py | 3 ++- .../tools/image_generation/stable_diffusion_image_gen.py | 7 +++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/superagi/tools/file/write_file.py b/superagi/tools/file/write_file.py index 233cea4e1..67df6fa32 100644 --- a/superagi/tools/file/write_file.py +++ b/superagi/tools/file/write_file.py @@ -1,5 +1,5 @@ import os -from typing import Type +from typing import Type, Optional from pydantic import BaseModel, Field from superagi.resource_manager.manager import ResourceManager @@ -34,7 +34,7 @@ class WriteFileTool(BaseTool): args_schema: Type[BaseModel] = WriteFileInput description: str = "Writes text to a file" agent_id: int = None - resource_manager: ResourceManager = None + resource_manager: Optional[ResourceManager] = None def _execute(self, file_name: str, content: str): """ diff --git a/superagi/tools/image_generation/dalle_image_gen.py b/superagi/tools/image_generation/dalle_image_gen.py index 19962ddd4..01ba59ad5 100644 --- a/superagi/tools/image_generation/dalle_image_gen.py +++ b/superagi/tools/image_generation/dalle_image_gen.py @@ -6,6 +6,7 @@ from superagi.llms.base_llm import BaseLlm from superagi.resource_manager.manager import ResourceManager from superagi.tools.base_tool import BaseTool +from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool class ImageGenInput(BaseModel): @@ -32,7 +33,7 @@ class ImageGenTool(BaseTool): description: str = "Generate Images using Dalle" llm: Optional[BaseLlm] = None agent_id: int = None - resource_manager: ResourceManager = None + resource_manager: Optional[ResourceManager] = None class Config: arbitrary_types_allowed = True diff --git a/superagi/tools/image_generation/stable_diffusion_image_gen.py b/superagi/tools/image_generation/stable_diffusion_image_gen.py index f6f7561e8..c7278adeb 100644 --- a/superagi/tools/image_generation/stable_diffusion_image_gen.py +++ b/superagi/tools/image_generation/stable_diffusion_image_gen.py @@ -1,5 +1,5 @@ import base64 -from typing import Type +from typing import Type, Optional import requests from pydantic import BaseModel, Field @@ -34,7 +34,10 @@ class StableDiffusionImageGenTool(BaseTool): args_schema: Type[BaseModel] = StableDiffusionImageGenInput description: str = "Generate Images using Stable Diffusion" agent_id: int = None - resource_manager: ResourceManager = None + resource_manager: Optional[ResourceManager] = None + + class Config: + arbitrary_types_allowed = True def _execute(self, prompt: str, image_name: list, width: int = 512, height: int = 512, num: int = 2, steps: int = 50): From dd7d8cf920129ca845e866c050a032d31b424d06 Mon Sep 17 00:00:00 2001 From: TransformerOptimus Date: Wed, 21 Jun 2023 17:00:45 +0530 Subject: [PATCH 4/5] fixing stable diffusion and dalle unit tests --- superagi/resource_manager/manager.py | 34 +++++----- superagi/tools/file/write_file.py | 3 + .../tools/image_generation/dalle_image_gen.py | 20 +++--- .../stable_diffusion_image_gen.py | 24 ++++--- tests/unit_tests/tools/image_gen_test.py | 46 -------------- .../tools/stable_diffusion_image_gen_test.py | 63 ------------------- .../unit_tests/tools/test_dalle_image_gen.py | 27 ++++++++ .../tools/test_stable_diffusion_image_gen.py | 51 +++++++++++++++ 8 files changed, 114 insertions(+), 154 deletions(-) delete mode 100644 tests/unit_tests/tools/image_gen_test.py delete mode 100644 tests/unit_tests/tools/stable_diffusion_image_gen_test.py create mode 100644 tests/unit_tests/tools/test_dalle_image_gen.py create mode 100644 tests/unit_tests/tools/test_stable_diffusion_image_gen.py diff --git a/superagi/resource_manager/manager.py b/superagi/resource_manager/manager.py index d46b4eb39..588a58f99 100644 --- a/superagi/resource_manager/manager.py +++ b/superagi/resource_manager/manager.py @@ -16,20 +16,23 @@ def write_binary_file(self, file_name: str, data): with open(final_path, mode="wb") as img: img.write(data) img.close() - with open(final_path, 'rb') as img: - resource = ResourceHelper.make_written_file_resource(file_name=file_name, - agent_id=self.agent_id, channel="OUTPUT") - if resource is not None: - self.session.add(resource) - self.session.commit() - self.session.flush() - if resource.storage_type == "S3": - s3_helper = S3Helper() - s3_helper.upload_file(img, path=resource.path) + self.write_to_s3(file_name, final_path) logger.info(f"Binary {file_name} saved successfully") except Exception as err: return f"Error: {err}" + def write_to_s3(self, file_name, final_path): + with open(final_path, 'rb') as img: + resource = ResourceHelper.make_written_file_resource(file_name=file_name, + agent_id=self.agent_id, channel="OUTPUT") + if resource is not None: + self.session.add(resource) + self.session.commit() + self.session.flush() + if resource.storage_type == "S3": + s3_helper = S3Helper() + s3_helper.upload_file(img, path=resource.path) + def write_file(self, file_name: str, content): final_path = ResourceHelper.get_resource_path(file_name) @@ -37,16 +40,7 @@ def write_file(self, file_name: str, content): with open(final_path, mode="w") as file: file.write(content) file.close() - with open(final_path, 'rb') as img: - resource = ResourceHelper.make_written_file_resource(file_name=file_name, - agent_id=self.agent_id, channel="OUTPUT") - if resource is not None: - self.session.add(resource) - self.session.commit() - self.session.flush() - if resource.storage_type == "S3": - s3_helper = S3Helper() - s3_helper.upload_file(img, path=resource.path) + self.write_to_s3(file_name, final_path) logger.info(f"{file_name} saved successfully") except Exception as err: return f"Error: {err}" \ No newline at end of file diff --git a/superagi/tools/file/write_file.py b/superagi/tools/file/write_file.py index 67df6fa32..43a4344f4 100644 --- a/superagi/tools/file/write_file.py +++ b/superagi/tools/file/write_file.py @@ -36,6 +36,9 @@ class WriteFileTool(BaseTool): agent_id: int = None resource_manager: Optional[ResourceManager] = None + class Config: + arbitrary_types_allowed = True + def _execute(self, file_name: str, content: str): """ Execute the write file tool. diff --git a/superagi/tools/image_generation/dalle_image_gen.py b/superagi/tools/image_generation/dalle_image_gen.py index 01ba59ad5..2b120efc2 100644 --- a/superagi/tools/image_generation/dalle_image_gen.py +++ b/superagi/tools/image_generation/dalle_image_gen.py @@ -6,17 +6,15 @@ from superagi.llms.base_llm import BaseLlm from superagi.resource_manager.manager import ResourceManager from superagi.tools.base_tool import BaseTool -from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool - -class ImageGenInput(BaseModel): +class DalleImageGenInput(BaseModel): prompt: str = Field(..., description="Prompt for Image Generation to be used by Dalle.") size: int = Field(..., description="Size of the image to be Generated. default size is 512") num: int = Field(..., description="Number of Images to be generated. default num is 2") - image_name: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") + image_names: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") -class ImageGenTool(BaseTool): +class DalleImageGenTool(BaseTool): """ Dalle Image Generation tool @@ -29,7 +27,7 @@ class ImageGenTool(BaseTool): resource_manager : Manages the file resources """ name: str = "DalleImageGeneration" - args_schema: Type[BaseModel] = ImageGenInput + args_schema: Type[BaseModel] = DalleImageGenInput description: str = "Generate Images using Dalle" llm: Optional[BaseLlm] = None agent_id: int = None @@ -38,7 +36,7 @@ class ImageGenTool(BaseTool): class Config: arbitrary_types_allowed = True - def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2): + def _execute(self, prompt: str, image_names: list, size: int = 512, num: int = 2): """ Execute the Dalle Image Generation tool. @@ -46,7 +44,7 @@ def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2) prompt : The prompt for image generation. size : The size of the image to be generated. num : The number of images to be generated. - image_name (list): The name of the image to be generated. + image_names (list): The name of the image to be generated. Returns: Image generated successfully. or error message. @@ -57,8 +55,6 @@ def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2) response = response.__dict__ response = response['_previous']['data'] for i in range(num): - image = image_name[i] - url = response[i]['url'] - data = requests.get(url).content - self.resource_manager.write_binary_file(image, data) + data = requests.get(response[i]['url']).content + self.resource_manager.write_binary_file(image_names[i], data) return "Images downloaded successfully" diff --git a/superagi/tools/image_generation/stable_diffusion_image_gen.py b/superagi/tools/image_generation/stable_diffusion_image_gen.py index c7278adeb..3f615d650 100644 --- a/superagi/tools/image_generation/stable_diffusion_image_gen.py +++ b/superagi/tools/image_generation/stable_diffusion_image_gen.py @@ -1,9 +1,10 @@ import base64 +from io import BytesIO from typing import Type, Optional import requests +from PIL import Image from pydantic import BaseModel, Field - from superagi.config.config import get_config from superagi.resource_manager.manager import ResourceManager from superagi.tools.base_tool import BaseTool @@ -15,8 +16,8 @@ class StableDiffusionImageGenInput(BaseModel): width: int = Field(..., description="Width of the image to be Generated. default width is 512") num: int = Field(..., description="Number of Images to be generated. default num is 2") steps: int = Field(..., description="Number of diffusion steps to run. default steps are 50") - image_name: list = Field(..., - description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") + image_names: list = Field(..., + description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") class StableDiffusionImageGenTool(BaseTool): @@ -39,7 +40,7 @@ class StableDiffusionImageGenTool(BaseTool): class Config: arbitrary_types_allowed = True - def _execute(self, prompt: str, image_name: list, width: int = 512, height: int = 512, num: int = 2, + def _execute(self, prompt: str, image_names: list, width: int = 512, height: int = 512, num: int = 2, steps: int = 50): api_key = get_config("STABILITY_API_KEY") @@ -61,11 +62,12 @@ def _execute(self, prompt: str, image_name: list, width: int = 512, height: int for i in range(num): image_base64 = base64_strings[i] img_data = base64.b64decode(image_base64) - # final_img = Image.open(BytesIO(img_data)) - # image_format = final_img.format + final_img = Image.open(BytesIO(img_data)) + image_format = final_img.format + img_byte_arr = BytesIO() + final_img.save(img_byte_arr, format=image_format) - image = image_name[i] - self.resource_manager.write_binary_file(image_name[i], img_data) + self.resource_manager.write_binary_file(image_names[i], img_byte_arr.getvalue()) return "Images downloaded and saved successfully" @@ -84,11 +86,7 @@ def call_stable_diffusion(self, api_key, width, height, num, prompt, steps): "Authorization": f"Bearer {api_key}" }, json={ - "text_prompts": [ - { - "text": prompt - } - ], + "text_prompts": [{"text": prompt}], "height": height, "width": width, "samples": num, diff --git a/tests/unit_tests/tools/image_gen_test.py b/tests/unit_tests/tools/image_gen_test.py deleted file mode 100644 index f96454545..000000000 --- a/tests/unit_tests/tools/image_gen_test.py +++ /dev/null @@ -1,46 +0,0 @@ -import os -import unittest -from unittest.mock import patch, MagicMock - -from superagi.tools.image_generation.dalle_image_gen import ImageGenTool - - -class TestImageGenTool(unittest.TestCase): - - @patch('openai.Image.create') - @patch('requests.get') - @patch('superagi.tools.image_generation.dalle_image_gen.get_config') - def test_image_gen_tool_execute(self, mock_get_config, mock_requests_get, mock_openai_create): - # Setup - tool = ImageGenTool() - prompt = 'Artificial Intelligence' - image_names = ['image1.png', 'image2.png'] - size = 512 - num = 2 - - # Mock responses - mock_get_config.return_value = "/tmp" - mock_openai_create.return_value = MagicMock(_previous=MagicMock(data=[ - {"url": "https://example.com/image1.png"}, - {"url": "https://example.com/image2.png"} - ])) - mock_requests_get.return_value.content = b"image_data" - - # Run the method under test - response = tool._execute(prompt, image_names, size, num) - - # Assert the method ran correctly - self.assertEqual(response, "Images downloaded successfully") - for image_name in image_names: - path = "/tmp/" + image_name - self.assertTrue(os.path.exists(path)) - with open(path, "rb") as file: - self.assertEqual(file.read(), b"image_data") - - # Clean up - for image_name in image_names: - os.remove("/tmp/" + image_name) - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/tests/unit_tests/tools/stable_diffusion_image_gen_test.py b/tests/unit_tests/tools/stable_diffusion_image_gen_test.py deleted file mode 100644 index bff79bf96..000000000 --- a/tests/unit_tests/tools/stable_diffusion_image_gen_test.py +++ /dev/null @@ -1,63 +0,0 @@ -import os -import unittest -from unittest.mock import patch, MagicMock -from PIL import Image -from io import BytesIO -import base64 -from superagi.config.config import get_config - -from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool - - -class TestStableDiffusionImageGenTool(unittest.TestCase): - - @patch('requests.post') - @patch('superagi.tools.image_generation.stable_diffusion_image_gen.get_config') - def test_stable_diffusion_image_gen_tool_execute(self, mock_get_config, mock_requests_post): - # Setup - tool = StableDiffusionImageGenTool() - prompt = 'Artificial Intelligence' - image_names = ['image1.png', 'image2.png'] - height, width = 512, 512 - num = 2 - steps = 50 - - # Create a temporary directory for image storage - temp_dir = get_config("RESOURCES_OUTPUT_ROOT_DIR") - - # Mock responses - mock_configs = {"STABILITY_API_KEY": "api_key", "ENGINE_ID": "engine_id", "RESOURCES_OUTPUT_ROOT_DIR": temp_dir} - mock_get_config.side_effect = lambda k: mock_configs[k] - - # Prepare sample image bytes - img = Image.new("RGB", (width, height), "white") - buffer = BytesIO() - img.save(buffer, "PNG") - buffer.seek(0) - img_data = buffer.getvalue() - encoded_image_data = base64.b64encode(img_data).decode() - - # Use the proper base64-encoded string - mock_requests_post.return_value = MagicMock(status_code=200, json=lambda: { - "artifacts": [ - {"base64": encoded_image_data}, - {"base64": encoded_image_data} - ] - }) - - # Run the method under test - response = tool._execute(prompt, image_names, width, height, num, steps) - self.assertEqual(response, f"Images downloaded successfully") - - for image_name in image_names: - path = os.path.join(temp_dir, image_name) - self.assertTrue(os.path.exists(path)) - with open(path, "rb") as file: - self.assertEqual(file.read(), img_data) - - # Clean up - for image_name in image_names: - os.remove(os.path.join(temp_dir, image_name)) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/unit_tests/tools/test_dalle_image_gen.py b/tests/unit_tests/tools/test_dalle_image_gen.py new file mode 100644 index 000000000..82653d5a2 --- /dev/null +++ b/tests/unit_tests/tools/test_dalle_image_gen.py @@ -0,0 +1,27 @@ +from unittest.mock import Mock, patch +import pytest +from superagi.tools.image_generation.dalle_image_gen import DalleImageGenTool + + +class MockBaseLlm: + def generate_image(self, prompt, size, num): + return Mock(_previous={"data": [{"url": f"https://example.com/image_{i}.png"} for i in range(num)]}) + + +class TestDalleImageGenTool: + + @pytest.fixture + def tool(self): + tool = DalleImageGenTool() + tool.llm = MockBaseLlm() + response_mock = Mock() + tool.resource_manager = response_mock + return tool + + @patch("requests.get") + def test_execute(self, mock_get, tool): + mock_get.return_value = Mock(content=b"fake image data") + response = tool._execute("test prompt", ["test1.png", "test2.png"], size=512, num=2) + assert response == "Images downloaded successfully" + mock_get.assert_called_with("https://example.com/image_1.png") + assert tool.resource_manager.write_binary_file.call_count == 2 diff --git a/tests/unit_tests/tools/test_stable_diffusion_image_gen.py b/tests/unit_tests/tools/test_stable_diffusion_image_gen.py new file mode 100644 index 000000000..6d2dc75c0 --- /dev/null +++ b/tests/unit_tests/tools/test_stable_diffusion_image_gen.py @@ -0,0 +1,51 @@ +import base64 +from io import BytesIO +from unittest.mock import patch, Mock + +import pytest +from PIL import Image + +from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool + + +def create_sample_image_base64(): + image = Image.new('RGBA', size=(50, 50), color=(73, 109, 137)) + byte_arr = BytesIO() + image.save(byte_arr, format='PNG') + encoded_image = base64.b64encode(byte_arr.getvalue()) + return encoded_image.decode('utf-8') + + +@pytest.fixture +def stable_diffusion_tool(): + with patch('superagi.tools.image_generation.stable_diffusion_image_gen.get_config') as get_config_mock, \ + patch('superagi.tools.image_generation.stable_diffusion_image_gen.requests.post') as post_mock, \ + patch('superagi.tools.image_generation.stable_diffusion_image_gen.ResourceManager') as resource_manager_mock: + get_config_mock.return_value = 'fake_api_key' + + # Create a mock response object + response_mock = Mock() + response_mock.status_code = 200 + response_mock.json.return_value = { + 'artifacts': [{'base64': create_sample_image_base64()} for _ in range(2)] + } + post_mock.return_value = response_mock + + resource_manager_mock.write_binary_file.return_value = None + + yield + +def test_execute(stable_diffusion_tool): + tool = StableDiffusionImageGenTool() + tool.resource_manager = Mock() + result = tool._execute('prompt', ['img1.png', 'img2.png']) + + assert result == 'Images downloaded and saved successfully' + tool.resource_manager.write_binary_file.assert_called() + +def test_call_stable_diffusion(stable_diffusion_tool): + tool = StableDiffusionImageGenTool() + response = tool.call_stable_diffusion('fake_api_key', 512, 512, 2, 'prompt', 50) + + assert response.status_code == 200 + assert 'artifacts' in response.json() \ No newline at end of file From 450326017274cf906ce36265678bc75772caff40 Mon Sep 17 00:00:00 2001 From: TransformerOptimus Date: Wed, 21 Jun 2023 17:45:57 +0530 Subject: [PATCH 5/5] adding coverage package --- requirements.txt | 1 + superagi/tools/file/write_file.py | 10 ++-------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index 71d7a6dd8..96b1cb87a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -132,3 +132,4 @@ tiktoken==0.4.0 psycopg2==2.9.6 slack-sdk==3.21.3 pytest==7.3.2 +pytest-cov==4.1.0 diff --git a/superagi/tools/file/write_file.py b/superagi/tools/file/write_file.py index 43a4344f4..f425a3186 100644 --- a/superagi/tools/file/write_file.py +++ b/superagi/tools/file/write_file.py @@ -1,18 +1,12 @@ -import os from typing import Type, Optional + from pydantic import BaseModel, Field from superagi.resource_manager.manager import ResourceManager from superagi.tools.base_tool import BaseTool -from superagi.config.config import get_config -from sqlalchemy.orm import sessionmaker -from superagi.models.db import connect_db -from superagi.helper.resource_helper import ResourceHelper -# from superagi.helper.s3_helper import upload_to_s3 -from superagi.helper.s3_helper import S3Helper -from superagi.lib.logger import logger +# from superagi.helper.s3_helper import upload_to_s3 class WriteFileInput(BaseModel):