Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 113 additions & 3 deletions examples/manipulation-demo-streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,118 @@
# limitations under the License.

import importlib
import sys
from pathlib import Path

import rclpy
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from launch import LaunchDescription
from launch.actions import (
IncludeLaunchDescription,
)
from launch.launch_description_sources import PythonLaunchDescriptionSource
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare
from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
from rai.communication.ros2.connectors.ros2_connector import ROS2Connector
from rai.messages import HumanMultimodalMessage

from rai_bench.manipulation_o3de import get_scenarios
from rai_bench.manipulation_o3de.benchmark import Scenario
from rai_sim.o3de.o3de_bridge import (
O3DEngineArmManipulationBridge,
O3DExROS2SimulationConfig,
)
from rai_sim.simulation_bridge import SceneConfig

manipulation_demo = importlib.import_module("manipulation-demo")


def launch_description():
launch_moveit = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
"src/examples/rai-manipulation-demo/Project/Examples/panda_moveit_config_demo.launch.py",
]
)
)

launch_robotic_manipulation = Node(
package="robotic_manipulation",
executable="robotic_manipulation",
output="screen",
parameters=[
{"use_sim_time": True},
],
)

launch_openset = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
FindPackageShare("rai_bringup"),
"/launch/openset.launch.py",
]
),
)

return LaunchDescription(
[
launch_openset,
launch_moveit,
launch_robotic_manipulation,
]
)


@st.cache_resource
def init_ros():
rclpy.init()
return "ros"


@st.cache_resource
def initialize_graph():
return manipulation_demo.create_agent()


def main():
@st.cache_resource
def initialize_o3de(scenario_path: str, o3de_config_path: str):
simulation_config = O3DExROS2SimulationConfig.load_config(
config_path=Path(o3de_config_path)
)
scene_config = SceneConfig.load_base_config(Path(scenario_path))
scenario = Scenario(
task=None,
scene_config=scene_config,
scene_config_path=scenario_path,
)
o3de = O3DEngineArmManipulationBridge(ROS2Connector())
o3de.init_simulation(simulation_config=simulation_config)
o3de.launch_robotic_stack(
required_robotic_ros2_interfaces=simulation_config.required_robotic_ros2_interfaces,
launch_description=launch_description(),
)
o3de.setup_scene(scenario.scene_config)


def main(scenario: Scenario, simulation_config: O3DExROS2SimulationConfig):
st.set_page_config(
page_title="RAI Manipulation Demo",
page_icon=":robot:",
)
st.title("RAI Manipulation Demo")
st.markdown("---")

st.sidebar.header("Tool Calls History")

if "ros" not in st.session_state:
ros = init_ros()
st.session_state["ros"] = ros

if "o3de" not in st.session_state:
o3de = initialize_o3de(scenario, simulation_config)
st.session_state["o3de"] = o3de

if "graph" not in st.session_state:
graph = initialize_graph()
st.session_state["graph"] = graph
Expand Down Expand Up @@ -70,4 +158,26 @@ def main():


if __name__ == "__main__":
main()
levels = [
"medium",
"hard",
"very_hard",
]
scenarios: list[Scenario] = get_scenarios(levels=levels)
scenario_names = [Path(s.scene_config_path).stem for s in scenarios]
print(scenario_names)

if len(sys.argv) > 1:
layout = sys.argv[1]
if layout not in scenario_names:
raise ValueError(f"Invalid layout: {layout}. Select from {scenario_names}")
else:
layout = "3carrots_1a_1t_2bc_2yc"
o3de_config_path = (
"src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml"
)

scenario_idx = scenario_names.index(layout)
scenario = str(scenarios[scenario_idx].scene_config_path)

main(scenario, o3de_config_path)