Skip to content

Commit 9a95d0d

Browse files
committed
tool gen from wizard
1 parent 0cad2cd commit 9a95d0d

File tree

3 files changed

+47
-13
lines changed

3 files changed

+47
-13
lines changed

agentstack/cli/cli.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
from .agentstack_data import FrameworkData, ProjectMetadata, ProjectStructure, CookiecutterData
1515
from agentstack.logger import log
16-
from ..utils import open_json_file
16+
from .. import generation
17+
from ..utils import open_json_file, term_color
1718

1819

1920
def init_project_builder(slug_name: Optional[str] = None, skip_wizard: bool = False):
@@ -32,6 +33,8 @@ def init_project_builder(slug_name: Optional[str] = None, skip_wizard: bool = Fa
3233
'agents': [],
3334
'tasks': []
3435
}
36+
37+
tools = []
3538
else:
3639
welcome_message()
3740
project_details = ask_project_details(slug_name)
@@ -46,6 +49,7 @@ def init_project_builder(slug_name: Optional[str] = None, skip_wizard: bool = Fa
4649
f"design: {design}"
4750
)
4851
insert_template(project_details, framework, design)
52+
add_tools(tools, project_details['name'])
4953

5054

5155
def welcome_message():
@@ -260,6 +264,11 @@ def insert_template(project_details: dict, framework_name: str, design: dict):
260264
shutil.copy(
261265
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env.example',
262266
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env')
267+
268+
if os.path.isdir(project_details['name']):
269+
print(term_color(f"Directory {template_path} already exists. Please check this and try again", "red"))
270+
return
271+
263272
cookiecutter(str(template_path), no_input=True, extra_context=None)
264273

265274
# TODO: inits a git repo in the directory the command was run in
@@ -287,6 +296,11 @@ def insert_template(project_details: dict, framework_name: str, design: dict):
287296
)
288297

289298

299+
def add_tools(tools: list, project_name: str):
300+
for tool in tools:
301+
generation.add_tool(tool, project_name)
302+
303+
290304
def list_tools():
291305
try:
292306
# Determine the path to the tools.json file

agentstack/generation/tool_generation.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,49 @@
11
import sys
2+
from typing import Optional
3+
24
from .gen_utils import insert_code_after_tag
35
from ..utils import snake_to_camel, open_json_file, get_framework
46
import os
57
import shutil
68
import fileinput
79

810

9-
def add_tool(tool_name: str):
11+
def add_tool(tool_name: str, path: Optional[str] = None):
1012
script_dir = os.path.dirname(os.path.abspath(__file__))
1113
tools = open_json_file(os.path.join(script_dir, '..', 'tools', 'tools.json'))
12-
framework = get_framework()
14+
framework = get_framework(path)
1315
assert_tool_exists(tool_name, tools)
1416

1517
tool_data = open_json_file(os.path.join(script_dir, '..', 'tools', f'{tool_name}.json'))
1618
tool_file_route = os.path.join(script_dir, '..', 'templates', framework, 'tools', f'{tool_name}.py')
1719

1820
os.system(tool_data['package']) # Install package
19-
shutil.copy(tool_file_route, f'src/tools/{tool_name}.py') # Move tool from package to project
20-
add_tool_to_tools_init(tool_data) # Export tool from tools dir
21-
add_tool_to_agent_definition(framework, tool_data)
22-
insert_code_after_tag('.env', '# Tools', [tool_data['env']], next_line=True) # Add env var
23-
insert_code_after_tag('.env.example', '# Tools', [tool_data['env']], next_line=True) # Add env var
21+
shutil.copy(tool_file_route, f'{path or ""}/src/tools/{tool_name}.py') # Move tool from package to project
22+
add_tool_to_tools_init(tool_data, path) # Export tool from tools dir
23+
add_tool_to_agent_definition(framework, tool_data, path)
24+
insert_code_after_tag(f'{path}/.env', '# Tools', [tool_data['env']], next_line=True) # Add env var
25+
insert_code_after_tag(f'{path}/.env.example', '# Tools', [tool_data['env']], next_line=True) # Add env var
2426

2527
print(f'\033[92m🔨 Tool {tool_name} added to agentstack project successfully\033[0m')
2628

27-
def add_tool_to_tools_init(tool_data: dict):
28-
file_path = 'src/tools/__init__.py'
29+
30+
def add_tool_to_tools_init(tool_data: dict, path: Optional[str] = None):
31+
file_path = f'{path or ""}/src/tools/__init__.py'
2932
tag = '# tool import'
3033
code_to_insert = [
3134
f"from {tool_data['name']} import {', '.join([tool_name for tool_name in tool_data['tools']])}"
3235
]
3336
insert_code_after_tag(file_path, tag, code_to_insert, next_line=True)
3437

3538

36-
def add_tool_to_agent_definition(framework: str, tool_data: dict):
39+
def add_tool_to_agent_definition(framework: str, tool_data: dict, path: Optional[str] = None):
3740
filename = ''
3841
if framework == 'crewai':
3942
filename = 'src/crew.py'
4043

44+
if path:
45+
filename = f'{path}/{filename}'
46+
4147
with fileinput.input(files=filename, inplace=True) as f:
4248
for line in f:
4349
print(line.replace('tools=[', f'tools=[tools.{", tools.".join([tool_name for tool_name in tool_data["tools"]])}, '), end='')

agentstack/utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import toml
24
import os
35
import sys
@@ -23,9 +25,12 @@ def verify_agentstack_project():
2325
sys.exit(1)
2426

2527

26-
def get_framework() -> str:
28+
def get_framework(path: Optional[str] = None) -> str:
2729
try:
28-
with open('agentstack.json', 'r') as f:
30+
file_path = 'agentstack.json'
31+
if path is not None:
32+
file_path = path + '/' + file_path
33+
with open(file_path, 'r') as f:
2934
data = json.load(f)
3035
framework = data.get('framework')
3136

@@ -56,3 +61,12 @@ def open_json_file(path) -> dict:
5661
def clean_input(input_string):
5762
special_char_pattern = re.compile(r'[^a-zA-Z0-9\s_]')
5863
return re.sub(special_char_pattern, '', input_string).lower().replace(' ', '_').replace('-', '_')
64+
65+
66+
def term_color(text: str, color: str) -> str:
67+
if color is 'red':
68+
return "\033[91m{}\033[00m".format(text)
69+
if color is 'green':
70+
return "\033[92m{}\033[00m".format(text)
71+
else:
72+
return text

0 commit comments

Comments
 (0)