From b638d37b15af091cfdbc92d003c1ad48d4ba1d89 Mon Sep 17 00:00:00 2001 From: Maverick-F35 <138012351+Maverick-F35@users.noreply.github.com> Date: Wed, 26 Jul 2023 10:55:36 +0530 Subject: [PATCH] handled resource path change (#861) * handled resource path change * readme changed * added new test case for stable diffusion * test cases refactored --- .../stable_diffusion_image_gen.py | 11 +++++++- superagi/tools/instagram_tool/README.MD | 2 +- .../test_stable_diffusion_image_gen.py | 25 ++++++++++++++----- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/superagi/tools/image_generation/stable_diffusion_image_gen.py b/superagi/tools/image_generation/stable_diffusion_image_gen.py index b329b3ebf..6831f5d98 100644 --- a/superagi/tools/image_generation/stable_diffusion_image_gen.py +++ b/superagi/tools/image_generation/stable_diffusion_image_gen.py @@ -8,6 +8,8 @@ from superagi.helper.resource_helper import ResourceHelper from superagi.resource_manager.file_manager import FileManager from superagi.tools.base_tool import BaseTool +from superagi.models.agent_execution import AgentExecution +from superagi.models.agent import Agent class StableDiffusionImageGenInput(BaseModel): @@ -35,6 +37,7 @@ class StableDiffusionImageGenTool(BaseTool): args_schema: Type[BaseModel] = StableDiffusionImageGenInput description: str = "Generate Images using Stable Diffusion" agent_id: int = None + agent_execution_id: int = None resource_manager: Optional[FileManager] = None class Config: @@ -70,7 +73,13 @@ def _execute(self, prompt: str, image_names: list, width: int = 512, height: int self.resource_manager.write_binary_file(image_names[i], img_byte_arr.getvalue()) for image in image_names: - image_paths.append(ResourceHelper.get_resource_path(image)) + final_path = ResourceHelper.get_agent_read_resource_path(image, agent=Agent.get_agent_from_id( + session=self.toolkit_config.session, agent_id=self.agent_id), agent_execution=AgentExecution + .get_agent_execution_from_id(session=self + .toolkit_config.session, + agent_execution_id=self + .agent_execution_id)) + image_paths.append(final_path) return f"Images downloaded and saved successfully at the following locations: {image_paths}" diff --git a/superagi/tools/instagram_tool/README.MD b/superagi/tools/instagram_tool/README.MD index fcaa034ac..293504c06 100644 --- a/superagi/tools/instagram_tool/README.MD +++ b/superagi/tools/instagram_tool/README.MD @@ -4,7 +4,7 @@ # SuperAGI Instagram Tool -The SuperAGI Instagram Tool works with the stable diffusion tool, generates an image & caption based on the goals defined by the user and posts it on their instagram business account. +The SuperAGI Instagram Tool works with the stable diffusion tool, generates an image & caption based on the goals defined by the user and posts it on their instagram business account.Currently will only work on the webapp ## ⚙️ Installation diff --git a/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py b/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py index dafd60b27..3f34cec6f 100644 --- a/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py +++ b/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py @@ -27,7 +27,13 @@ def create_sample_image_base64(): def stable_diffusion_tool(): with patch('superagi.tools.image_generation.stable_diffusion_image_gen.requests.post') as post_mock, \ patch( - 'superagi.tools.image_generation.stable_diffusion_image_gen.FileManager') as resource_manager_mock: + 'superagi.tools.image_generation.stable_diffusion_image_gen.FileManager') as resource_manager_mock, \ + patch( + 'superagi.tools.image_generation.stable_diffusion_image_gen.ResourceHelper') as resource_helper_mock, \ + patch( + 'superagi.tools.image_generation.stable_diffusion_image_gen.Agent') as agent_mock, \ + patch( + 'superagi.tools.image_generation.stable_diffusion_image_gen.AgentExecution') as agent_execution_mock: # Create a mock response object response_mock = Mock() @@ -39,16 +45,23 @@ def stable_diffusion_tool(): resource_manager_mock.write_binary_file.return_value = None + # Mock Agent and AgentExecution to return dummy values + agent_mock.get_agent_from_id.return_value = Mock() + agent_execution_mock.get_agent_execution_from_id.return_value = Mock() + yield + def test_execute(stable_diffusion_tool): tool = StableDiffusionImageGenTool() tool.resource_manager = Mock() - tool.toolkit_config.get_tool_config = mock_get_tool_config - - - result = tool._execute('prompt', ['img1.png', 'img2.png']) - assert result.startswith('Images downloaded and saved successfully') + tool.agent_id = 123 # Use a dummy agent_id for testing purposes + tool.toolkit_config.get_tool_config = lambda key: 'fake_api_key' if key == 'STABILITY_API_KEY' else 'engine_id_1' + prompt = 'Test prompt' + image_names = ['img1.png', 'img2.png'] + expected_result = 'Images downloaded and saved successfully' + result = tool._execute(prompt, image_names) + assert result.startswith(expected_result) tool.resource_manager.write_binary_file.assert_called() def test_call_stable_diffusion(stable_diffusion_tool):