Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ros2 #1

Merged
merged 2 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
ROS2 integration first commit
  • Loading branch information
RoboCoachian committed Sep 26, 2023
commit 215a62dc551028618d21c03ef34d916dc58aa681
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,6 @@ catkin_ws/
# Generated figures
*.gv
*.pdf

# Docker scripts
docker_run
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "roscribe"
version = "0.0.2"
version = "0.0.3"
description = "Translate natural language into robot software."
readme = "README.md"
authors = [{ name = "RoboCoach Technologies", email = "robocoachtechnologies@gmail.com" }]
Expand Down
149 changes: 116 additions & 33 deletions roscribe/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,88 @@
import roscribe.ui as ui


def catkin_ws_generator(project_name):
if not os.path.exists('catkin_ws'):
os.mkdir('catkin_ws')

if not os.path.exists('catkin_ws/src'):
os.mkdir('catkin_ws/src')

os.mkdir(f'catkin_ws/src/{project_name}')
os.mkdir(f'catkin_ws/src/{project_name}/src')
os.mkdir(f'catkin_ws/src/{project_name}/launch')


def code_generator(task, node_topic_list, curr_node, summary, project_name, llm, verbose=False):
gen_code_prompt = get_gen_code_prompt()
ROS_WS_NAME = 'ros_ws'

SETUP_PY_TEMPLATE = """
setup.py
```python
from setuptools import setup

package_name = '{package_name}'

setup(
name=package_name,
version='0.0.1',
packages=[package_name],
data_files=[
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='TODO',
maintainer_email='TODO',
description='TODO: Package description',
license='TODO: License declaration',
tests_require=['pytest'],
entry_points={
'console_scripts': [
{node_list}
],
},
)
```
"""


SETUP_CFG_TEMPLATE = """
setup.cfg
```cfg
[develop]
script_dir=$base/lib/{package_name}
[install]
install_scripts=$base/lib/{package_name}
```
"""


def make_setup_py(node_topic_dict, package_name):
node_list_str = ""
for node in node_topic_dict.keys():
node_list_str += f"'{node} = {package_name}.{node}:main', "

setup_py = SETUP_PY_TEMPLATE.format(package_name=package_name, node_list=node_list_str)
return setup_py


def make_setup_cfg(package_name):
setup_cfg = SETUP_CFG_TEMPLATE.format(package_name=package_name)
return setup_cfg


def ros_ws_generator(project_name, ros_version):
if not os.path.exists(ROS_WS_NAME):
os.mkdir(ROS_WS_NAME)

if not os.path.exists(f'{ROS_WS_NAME}/src'):
os.mkdir(f'{ROS_WS_NAME}/src')

os.mkdir(f'{ROS_WS_NAME}/src/{project_name}')
os.mkdir(f'{ROS_WS_NAME}/src/{project_name}/launch')

if ros_version == 'ros1':
os.mkdir(f'{ROS_WS_NAME}/src/{project_name}/src')
elif ros_version == 'ros2':
os.mkdir(f'{ROS_WS_NAME}/src/{project_name}/{project_name}')
open(f'{ROS_WS_NAME}/src/{project_name}/{project_name}/__init__.py', 'x')

os.mkdir(f'{ROS_WS_NAME}/src/{project_name}/resource')
open(f'{ROS_WS_NAME}/src/{project_name}/resource/{project_name}', 'x')


def code_generator(task, node_topic_list, curr_node, summary, project_name, ros_version, llm, verbose=False):
gen_code_prompt = get_gen_code_prompt(ros_version)

gen_code_chain = LLMChain(
llm=llm,
Expand All @@ -33,13 +101,13 @@ def code_generator(task, node_topic_list, curr_node, summary, project_name, llm,
gen_code_output = gen_code_chain.predict(task=task, node_topic_list=node_topic_list,
curr_node=curr_node, summary=summary)

to_files(gen_code_output, project_name, 'src')
to_files(gen_code_output, project_name, 'impl', ros_version)

print(ui.GEN_NODE_CODE_MSG.format(node=curr_node))


def launch_generator(task, node_topic_list, project_name, llm, verbose=False):
gen_launch_prompt = get_gen_launch_prompt()
def launch_generator(task, node_topic_list, project_name, ros_version, llm, verbose=False):
gen_launch_prompt = get_gen_launch_prompt(ros_version)

gen_launch_chain = LLMChain(
llm=llm,
Expand All @@ -54,20 +122,25 @@ def launch_generator(task, node_topic_list, project_name, llm, verbose=False):
print(ui.GEN_LAUNCH_MSG)


def install_generator(task, node_topic_list, project_name, llm, verbose=False):
gen_cmake_prompt = get_gen_cmake_prompt()
def install_generator(task, node_topic_dict, node_topic_list, project_name, ros_version, llm, verbose=False):
if ros_version == 'ros1':
gen_cmake_prompt = get_gen_cmake_prompt()

gen_cmake_chain = LLMChain(
llm=llm,
prompt=gen_cmake_prompt,
verbose=verbose
)
gen_cmake_chain = LLMChain(
llm=llm,
prompt=gen_cmake_prompt,
verbose=verbose
)

gen_install_output = gen_cmake_chain.predict(task=task, node_topic_list=node_topic_list,
project_name=project_name)

gen_cmake_output = gen_cmake_chain.predict(task=task, node_topic_list=node_topic_list, project_name=project_name)
elif ros_version == 'ros2':
gen_install_output = make_setup_py(node_topic_dict, project_name)

to_files(gen_cmake_output, project_name, 'install')
to_files(gen_install_output, project_name, 'install', ros_version)

gen_package_prompt = get_gen_package_prompt()
gen_package_prompt = get_gen_package_prompt(ros_version)

gen_package_chain = LLMChain(
llm=llm,
Expand All @@ -83,7 +156,7 @@ def install_generator(task, node_topic_list, project_name, llm, verbose=False):
print(ui.GEN_INSTALL_MSG)


def to_files(chat, project_name, mode):
def to_files(chat, project_name, mode, ros_version='ros1'):
workspace = dict()

files = get_code_from_chat(chat)
Expand All @@ -96,14 +169,24 @@ def to_files(chat, project_name, mode):
for filename in workspace.keys():
code = workspace[filename]

if mode == 'src':
with open(f'catkin_ws/src/{project_name}/src/{filename}', 'w') as file:
file.write(code)
if mode == 'impl':
if ros_version == 'ros1':
with open(f'{ROS_WS_NAME}/src/{project_name}/src/{filename}', 'w') as file:
file.write(code)
elif ros_version == 'ros2':
with open(f'{ROS_WS_NAME}/src/{project_name}/{project_name}/{filename}', 'w') as file:
file.write(code)

elif mode == 'launch':
with open(f'catkin_ws/src/{project_name}/launch/{filename}', 'w') as file:
with open(f'{ROS_WS_NAME}/src/{project_name}/launch/{filename}', 'w') as file:
file.write(code)

elif mode == 'install':
with open(f'catkin_ws/src/{project_name}/{filename}', 'w') as file:
with open(f'{ROS_WS_NAME}/src/{project_name}/{filename}', 'w') as file:
file.write(code)

if ros_version == 'ros2' and filename == 'setup.py':
with open(f'{ROS_WS_NAME}/src/{project_name}/setup.cfg', 'w') as file:
file.write(make_setup_cfg(project_name))
else:
print('Invalid file storage mode!')
26 changes: 16 additions & 10 deletions roscribe/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from roscribe.prompt import get_project_name_prompt, get_task_spec_prompt, get_task_spec_summarize_prompt,\
get_gen_node_prompt, get_gen_topic_prompt, get_node_qa_prompt, get_node_qa_sum_prompt
from roscribe.parser import make_node_list, make_node_topic_dict, make_node_topic_list_str, modify_node_dict
from roscribe.generator import catkin_ws_generator, code_generator, launch_generator, install_generator
from roscribe.generator import ros_ws_generator, code_generator, launch_generator, install_generator
from roscribe.visualization import show_node_graph

import roscribe.ui as ui
Expand All @@ -17,6 +17,10 @@ def main(verbose=False):
print(ui.WELCOME_MSG)

task_message = input("Your Robot Software: ") # User-specified task
ros_version = input("ROS1 or ROS2? ").replace(" ", "").lower() # User-specified ROS version
while ros_version not in ['ros1', 'ros2']:
print(ui.VALID_ROS_VER)
ros_version = input("ROS1 or ROS2? ").replace(" ", "").lower()

project_name_prompt = get_project_name_prompt()
project_name_chain = LLMChain(
Expand All @@ -25,9 +29,9 @@ def main(verbose=False):
verbose=verbose)
project_name = project_name_chain.predict(task=task_message)

catkin_ws_generator(project_name)
ros_ws_generator(project_name, ros_version)

task_spec_prompt, task_spec_end_str = get_task_spec_prompt(task_message)
task_spec_prompt, task_spec_end_str = get_task_spec_prompt(task_message, ros_version)
task_spec_memory = ConversationBufferMemory()
task_spec_chain = ConversationChain(
llm=llm,
Expand Down Expand Up @@ -57,7 +61,7 @@ def main(verbose=False):
task_spec_memory.return_messages = True
task_spec_sum_output = task_spec_summary_chain.predict(input=task_spec_memory.load_memory_variables({}))

node_gen_prompt, node_gen_parser = get_gen_node_prompt()
node_gen_prompt, node_gen_parser = get_gen_node_prompt(ros_version)
node_gen_chain = LLMChain(
llm=llm,
prompt=node_gen_prompt,
Expand All @@ -67,7 +71,7 @@ def main(verbose=False):
node_gen_list = node_gen_parser.parse(node_gen_output).ros_nodes
node_list_str = make_node_list(node_gen_list)

topic_gen_prompt, topic_gen_parser = get_gen_topic_prompt()
topic_gen_prompt, topic_gen_parser = get_gen_topic_prompt(ros_version)
topic_gen_chain = LLMChain(
llm=llm,
prompt=topic_gen_prompt,
Expand Down Expand Up @@ -134,8 +138,10 @@ def main(verbose=False):
print(ui.QA_MSG_INIT)

for node in node_topic_dict.keys():
node_spec_prompt, node_spec_end_str = get_node_qa_prompt(
task=task_message, node_topic_list=node_topic_list_str, curr_node=node)
node_spec_prompt, node_spec_end_str = get_node_qa_prompt(task=task_message,
node_topic_list=node_topic_list_str,
curr_node=node,
ros_version=ros_version)
node_spec_memory = ConversationBufferMemory()
node_spec_chain = ConversationChain(
llm=llm,
Expand Down Expand Up @@ -165,13 +171,13 @@ def main(verbose=False):
node_spec_memory.return_messages = True
sum_output = summary_chain.predict(input=node_spec_memory.load_memory_variables({}))

code_generator(task_message, node_topic_list_str, node, sum_output, project_name, llm, verbose)
code_generator(task_message, node_topic_list_str, node, sum_output, project_name, ros_version, llm, verbose)

print(ui.LAUNCH_INSTALL_MSG)

launch_generator(task_message, node_topic_list_str, project_name, llm)
launch_generator(task_message, node_topic_list_str, project_name, ros_version, llm, verbose)

install_generator(task_message, node_topic_list_str, project_name, llm)
install_generator(task_message, node_topic_dict, node_topic_list_str, project_name, ros_version, llm, verbose)

print(ui.FAREWELL_MSG)

Expand Down
6 changes: 3 additions & 3 deletions roscribe/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from langchain.output_parsers import PydanticOutputParser


MOD_INPUT_SCHEMA = Schema((str, str, [(str, str)], [(str, str)]))


# Data structure for the output of ROS nodes
class NodeList(BaseModel):
ros_nodes: Dict[str, str] = Field(description="dictionary containing ROS node names as keys and ROS node descriptions as values")
Expand Down Expand Up @@ -79,9 +82,6 @@ def make_node_topic_list_str(node_topic_dict):
return node_topic_list_str


MOD_INPUT_SCHEMA = Schema((str, str, [(str, str)], [(str, str)]))


def modify_node_dict(mod_input, node_topic_dict):
try:
mod_tuple = ast.literal_eval(mod_input)
Expand Down
Loading