Skip to content

Commit

Permalink
handled resource path change (#861)
Browse files Browse the repository at this point in the history
* handled resource path change

* readme changed

* added new test case for stable diffusion

* test cases refactored
  • Loading branch information
Aryan-Singh-14 authored Jul 26, 2023
1 parent 97c5ee9 commit b638d37
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
11 changes: 10 additions & 1 deletion superagi/tools/image_generation/stable_diffusion_image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from superagi.helper.resource_helper import ResourceHelper
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent import Agent


class StableDiffusionImageGenInput(BaseModel):
Expand Down Expand Up @@ -35,6 +37,7 @@ class StableDiffusionImageGenTool(BaseTool):
args_schema: Type[BaseModel] = StableDiffusionImageGenInput
description: str = "Generate Images using Stable Diffusion"
agent_id: int = None
agent_execution_id: int = None
resource_manager: Optional[FileManager] = None

class Config:
Expand Down Expand Up @@ -70,7 +73,13 @@ def _execute(self, prompt: str, image_names: list, width: int = 512, height: int
self.resource_manager.write_binary_file(image_names[i], img_byte_arr.getvalue())

for image in image_names:
image_paths.append(ResourceHelper.get_resource_path(image))
final_path = ResourceHelper.get_agent_read_resource_path(image, agent=Agent.get_agent_from_id(
session=self.toolkit_config.session, agent_id=self.agent_id), agent_execution=AgentExecution
.get_agent_execution_from_id(session=self
.toolkit_config.session,
agent_execution_id=self
.agent_execution_id))
image_paths.append(final_path)

return f"Images downloaded and saved successfully at the following locations: {image_paths}"

Expand Down
2 changes: 1 addition & 1 deletion superagi/tools/instagram_tool/README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# SuperAGI Instagram Tool

The SuperAGI Instagram Tool works with the stable diffusion tool, generates an image & caption based on the goals defined by the user and posts it on their instagram business account.
The SuperAGI Instagram Tool works with the stable diffusion tool, generates an image & caption based on the goals defined by the user and posts it on their instagram business account.Currently will only work on the webapp

## ⚙️ Installation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ def create_sample_image_base64():
def stable_diffusion_tool():
with patch('superagi.tools.image_generation.stable_diffusion_image_gen.requests.post') as post_mock, \
patch(
'superagi.tools.image_generation.stable_diffusion_image_gen.FileManager') as resource_manager_mock:
'superagi.tools.image_generation.stable_diffusion_image_gen.FileManager') as resource_manager_mock, \
patch(
'superagi.tools.image_generation.stable_diffusion_image_gen.ResourceHelper') as resource_helper_mock, \
patch(
'superagi.tools.image_generation.stable_diffusion_image_gen.Agent') as agent_mock, \
patch(
'superagi.tools.image_generation.stable_diffusion_image_gen.AgentExecution') as agent_execution_mock:

# Create a mock response object
response_mock = Mock()
Expand All @@ -39,16 +45,23 @@ def stable_diffusion_tool():

resource_manager_mock.write_binary_file.return_value = None

# Mock Agent and AgentExecution to return dummy values
agent_mock.get_agent_from_id.return_value = Mock()
agent_execution_mock.get_agent_execution_from_id.return_value = Mock()

yield


def test_execute(stable_diffusion_tool):
tool = StableDiffusionImageGenTool()
tool.resource_manager = Mock()
tool.toolkit_config.get_tool_config = mock_get_tool_config


result = tool._execute('prompt', ['img1.png', 'img2.png'])
assert result.startswith('Images downloaded and saved successfully')
tool.agent_id = 123 # Use a dummy agent_id for testing purposes
tool.toolkit_config.get_tool_config = lambda key: 'fake_api_key' if key == 'STABILITY_API_KEY' else 'engine_id_1'
prompt = 'Test prompt'
image_names = ['img1.png', 'img2.png']
expected_result = 'Images downloaded and saved successfully'
result = tool._execute(prompt, image_names)
assert result.startswith(expected_result)
tool.resource_manager.write_binary_file.assert_called()

def test_call_stable_diffusion(stable_diffusion_tool):
Expand Down

0 comments on commit b638d37

Please sign in to comment.