-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #435 from TransformerOptimus/resource_manager_refa…
…ctoring1 Resource manager refactoring
- Loading branch information
Showing
33 changed files
with
359 additions
and
303 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,3 +132,4 @@ tiktoken==0.4.0 | |
psycopg2==2.9.6 | ||
slack-sdk==3.21.3 | ||
pytest==7.3.2 | ||
pytest-cov==4.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from sqlalchemy.orm import Session | ||
|
||
from superagi.helper.resource_helper import ResourceHelper | ||
from superagi.helper.s3_helper import S3Helper | ||
from superagi.lib.logger import logger | ||
|
||
|
||
class ResourceManager: | ||
def __init__(self, session: Session): | ||
self.session = session | ||
|
||
def write_binary_file(self, file_name: str, data): | ||
final_path = ResourceHelper.get_resource_path(file_name) | ||
|
||
try: | ||
with open(final_path, mode="wb") as img: | ||
img.write(data) | ||
img.close() | ||
self.write_to_s3(file_name, final_path) | ||
logger.info(f"Binary {file_name} saved successfully") | ||
except Exception as err: | ||
return f"Error: {err}" | ||
|
||
def write_to_s3(self, file_name, final_path): | ||
with open(final_path, 'rb') as img: | ||
resource = ResourceHelper.make_written_file_resource(file_name=file_name, | ||
agent_id=self.agent_id, channel="OUTPUT") | ||
if resource is not None: | ||
self.session.add(resource) | ||
self.session.commit() | ||
self.session.flush() | ||
if resource.storage_type == "S3": | ||
s3_helper = S3Helper() | ||
s3_helper.upload_file(img, path=resource.path) | ||
|
||
def write_file(self, file_name: str, content): | ||
final_path = ResourceHelper.get_resource_path(file_name) | ||
|
||
try: | ||
with open(final_path, mode="w") as file: | ||
file.write(content) | ||
file.close() | ||
self.write_to_s3(file_name, final_path) | ||
logger.info(f"{file_name} saved successfully") | ||
except Exception as err: | ||
return f"Error: {err}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,92 +1,60 @@ | ||
from typing import Type, Optional | ||
|
||
import requests | ||
from pydantic import BaseModel, Field | ||
|
||
from superagi.llms.base_llm import BaseLlm | ||
from superagi.resource_manager.manager import ResourceManager | ||
from superagi.tools.base_tool import BaseTool | ||
from superagi.config.config import get_config | ||
import os | ||
import requests | ||
from superagi.models.db import connect_db | ||
from superagi.helper.resource_helper import ResourceHelper | ||
from superagi.helper.s3_helper import S3Helper | ||
from sqlalchemy.orm import sessionmaker | ||
from superagi.lib.logger import logger | ||
|
||
|
||
|
||
class ImageGenInput(BaseModel): | ||
class DalleImageGenInput(BaseModel): | ||
prompt: str = Field(..., description="Prompt for Image Generation to be used by Dalle.") | ||
size: int = Field(..., description="Size of the image to be Generated. default size is 512") | ||
num: int = Field(..., description="Number of Images to be generated. default num is 2") | ||
image_name: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") | ||
image_names: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") | ||
|
||
|
||
class ImageGenTool(BaseTool): | ||
class DalleImageGenTool(BaseTool): | ||
""" | ||
Dalle Image Generation tool | ||
Attributes: | ||
name : The name. | ||
description : The description. | ||
args_schema : The args schema. | ||
name : Name of the tool | ||
description : The description | ||
args_schema : The args schema | ||
llm : The llm | ||
agent_id : The agent id | ||
resource_manager : Manages the file resources | ||
""" | ||
name: str = "Dalle Image Generation" | ||
args_schema: Type[BaseModel] = ImageGenInput | ||
name: str = "DalleImageGeneration" | ||
args_schema: Type[BaseModel] = DalleImageGenInput | ||
description: str = "Generate Images using Dalle" | ||
llm: Optional[BaseLlm] = None | ||
agent_id: int = None | ||
resource_manager: Optional[ResourceManager] = None | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2): | ||
def _execute(self, prompt: str, image_names: list, size: int = 512, num: int = 2): | ||
""" | ||
Execute the Dalle Image Generation tool. | ||
Args: | ||
prompt : The prompt for image generation. | ||
size : The size of the image to be generated. | ||
num : The number of images to be generated. | ||
image_name (list): The name of the image to be generated. | ||
image_names (list): The name of the image to be generated. | ||
Returns: | ||
Image generated successfully. or error message. | ||
""" | ||
engine = connect_db() | ||
Session = sessionmaker(bind=engine) | ||
session = Session() | ||
if size not in [256, 512, 1024]: | ||
size = min([256, 512, 1024], key=lambda x: abs(x - size)) | ||
response = self.llm.generate_image(prompt, size, num) | ||
response = response.__dict__ | ||
response = response['_previous']['data'] | ||
for i in range(num): | ||
image = image_name[i] | ||
final_path = image | ||
root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') | ||
if root_dir is not None: | ||
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir | ||
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" | ||
final_path = root_dir + image | ||
else: | ||
final_path = os.getcwd() + "/" + image | ||
url = response[i]['url'] | ||
data = requests.get(url).content | ||
try: | ||
with open(final_path, mode="wb") as img: | ||
img.write(data) | ||
with open(final_path, 'rb') as img: | ||
resource = ResourceHelper.make_written_file_resource(file_name=image_name[i], | ||
agent_id=self.agent_id, file=img,channel="OUTPUT") | ||
if resource is not None: | ||
session.add(resource) | ||
session.commit() | ||
session.flush() | ||
if resource.storage_type == "S3": | ||
s3_helper = S3Helper() | ||
s3_helper.upload_file(img, path=resource.path) | ||
logger.info(f"Image {image} saved successfully") | ||
except Exception as err: | ||
session.close() | ||
return f"Error: {err}" | ||
session.close() | ||
data = requests.get(response[i]['url']).content | ||
self.resource_manager.write_binary_file(image_names[i], data) | ||
return "Images downloaded successfully" |
Oops, something went wrong.