Skip to content

Commit

Permalink
Add rrt connect
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronzguan committed Oct 19, 2020
1 parent 0d8ca32 commit 20dd647
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 0 deletions.
163 changes: 163 additions & 0 deletions Path Planning/rrt_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from time import time
import numpy as np
from scipy.optimize import minimize, LinearConstraint, NonlinearConstraint

from kdtree import KDTree
from franka_robot import FrankaRobot


class SimpleTree:

def __init__(self, dim):
self._parents_map = {}
self._kd = KDTree(dim)

def insert_new_node(self, point, parent=None):
node_id = self._kd.insert(point)
self._parents_map[node_id] = parent

return node_id

def get_parent(self, child_id):
return self._parents_map[child_id]

def get_point(self, node_id):
return self._kd.get_node(node_id).point

def get_nearest_node(self, point):
return self._kd.find_nearest_point(point)

def construct_path_to_root(self, leaf_node_id):
path = []
node_id = leaf_node_id
while node_id is not None:
path.append(self.get_point(node_id))
node_id = self.get_parent(node_id)

return path


class RRTConnect:

def __init__(self, fr, is_in_collision):
self._fr = fr
self._is_in_collision = is_in_collision

self._q_step_size = 0.02
self._connect_dist = 0.8
self._max_n_nodes = int(1e5)

def sample_valid_joints(self):
q = np.random.random(self._fr.num_dof) * (self._fr.joint_limits_high - self._fr.joint_limits_low) + self._fr.joint_limits_low
return q

def project_to_constraint(self, q0, constraint):
def f(q):
return constraint(q)[0]

def df(q):
c_grad = constraint(q)[1]
q_grad = self._fr.jacobian(q).T @ c_grad
return q_grad

def c_f(q):
diff_q = q - q0
return diff_q @ diff_q

def c_df(q):
diff_q = q - q0
return 0.5 * diff_q

c_joint_limits = LinearConstraint(np.eye(len(q0)), self._fr.joint_limits_low, self._fr.joint_limits_high)
c_close_to_q0 = NonlinearConstraint(c_f, 0, self._q_step_size ** 2, jac=c_df)

res = minimize(f, q0, jac=df, method='SLSQP', tol=0.1,
constraints=(c_joint_limits, c_close_to_q0))

return res.x

def _is_seg_valid(self, q0, q1):
qs = np.linspace(q0, q1, int(np.linalg.norm(q1 - q0) / self._q_step_size))
for q in qs:
if self._is_in_collision(q):
return False
return True

def extend(self, tree_0, tree_1, constraint=None):
'''
TODO: Implement extend for RRT Connect
- Only perform self.project_to_constraint if constraint is not None
- Use self._is_seg_valid, self._q_step_size, self._connect_dist
'''
target_reached = False
node_id_new = None
is_collision = True

while is_collision:
q_sample = self.sample_valid_joints()

node_id_near = tree.get_nearest_node(q_sample)[0]
q_near = tree.get_point(node_id_near)
q_new = q_near + min(self._q_step_size, np.linalg.norm(q_sample - q_near)) * (q_sample - q_near) / np.linalg.norm(q_sample - q_near)

q_new = self.project_to_constraint(q_new, constraint)

if self._is_in_collision(q_new):
is_collision = True
continue
else:
is_collision = False

# Add the q_new as vertex, and the edge between q_new and q_near as edge to the tree
node_id_new = tree_0.insert_new_node(q_new, node_id_near)
node_id_1 = tree_1.get_nearest_node(q_new)[0]
q_1 = tree.get_point(node_id_1)
# if the new state is close to the target state, then we reached the target state
if np.linalg.norm(q_new - q_1) < self._connect_dist and self._is_seg_valid(q_new, q_1):
target_reached = True

return target_reached, node_id_new, node_id_1

def plan(self, q_start, q_target, constraint=None):
tree_0 = SimpleTree(len(q_start))
tree_0.insert_new_node(q_start)

tree_1 = SimpleTree(len(q_target))
tree_1.insert_new_node(q_target)

q_start_is_tree_0 = True

s = time()
for n_nodes_sampled in range(self._max_n_nodes):
if n_nodes_sampled > 0 and n_nodes_sampled % 100 == 0:
print('RRT: Sampled {} nodes'.format(n_nodes_sampled))

reached_target, node_id_new, node_id_1 = self.extend(tree_0, tree_1, constraint)

if reached_target:
break

q_start_is_tree_0 = not q_start_is_tree_0
tree_0, tree_1 = tree_1, tree_0

print('RRT: Sampled {} nodes in {:.2f}s'.format(n_nodes_sampled, time() - s))

if not q_start_is_tree_0:
tree_0, tree_1 = tree_1, tree_0

if reached_target:
tree_0_backward_path = tree_0.construct_path_to_root(node_id_new)
tree_1_forward_path = tree_1.construct_path_to_root(node_id_1)

q0 = tree_0_backward_path[0]
q1 = tree_1_forward_path[0]
tree_01_connect_path = np.linspace(q0, q1, int(np.linalg.norm(q1 - q0) / self._q_step_size))[1:].tolist()

path = tree_0_backward_path[::-1] + tree_01_connect_path + tree_1_forward_path
print('RRT: Found a path! Path length is {}.'.format(len(path)))
else:
path = []
print('RRT: Was not able to find a path!')

return np.array(path).tolist()
131 changes: 131 additions & 0 deletions Path Planning/test_rrt_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import argparse
import numpy as np
from time import sleep
import rospy
from tqdm import tqdm
from frankapy import FrankaArm

from franka_robot import FrankaRobot
from collision_boxes_publisher import CollisionBoxesPublisher
from rrt_connect import RRTConnect


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--run_on_robot', action='store_true')
parser.add_argument('--seed', '-s', type=int, default=0)
args = parser.parse_args()

np.random.seed(args.seed)
fr = FrankaRobot()

if args.run_on_robot:
fa = FrankaArm()
else:
rospy.init_node('rrt')

'''
TODO: Replace obstacle box w/ the box specs in your workspace:
[x, y, z, r, p, y, sx, sy, sz]
'''
boxes = np.array([
# obstacle
[0, 0, 0, 0, 0, 0, 0, 0, 0],
# sides
[0.15, 0.46, 0.5, 0, 0, 0, 1.2, 0.01, 1.1],
[0.15, -0.46, 0.5, 0, 0, 0, 1.2, 0.01, 1.1],
# back
[-0.41, 0, 0.5, 0, 0, 0, 0.01, 1, 1.1],
# front
[0.75, 0, 0.5, 0, 0, 0, 0.01, 1, 1.1],
# top
[0.2, 0, 1, 0, 0, 0, 1.2, 1, 0.01],
# bottom
[0.2, 0, -0.05, 0, 0, 0, 1.2, 1, 0.01]
])
def is_in_collision(joints):
for box in boxes:
if fr.check_box_collision(joints, box):
return True
return False

desired_ee_rp = fr.ee(fr.home_joints)[3:5]
def ee_upright_constraint(q):
'''
TODO: Implement constraint function and its gradient.
This constraint should enforce the end-effector stays upright.
Hint: Use the roll and pitch angle in desired_ee_rp. The end-effector is upright in its home state.
Input:
q - a joint configuration
Output:
err - a non-negative scalar that is 0 when the constraint is satisfied
grad - a vector of length 6, where the ith element is the derivative of err w.r.t. the ith element of ee
'''
ee = fr.ee(q)
err = np.sum((np.asarray(desired_ee_rp) - np.asarray(ee[3:5]))**2)
grad = np.asarray([0, 0, 0, 2*(ee[3]-desired_ee_rp[0]), 2*(ee[4]-desired_ee_rp[1]), 0])
return err, grad

'''
TODO: Fill in start and target joint positions
'''
joints_start = None
joints_target = None

rrtc = RRTConnect(fr, is_in_collision)
constraint = None # ee_upright_constraint
plan = rrtc.plan(joints_start, joints_target, constraint)

collision_boxes_publisher = CollisionBoxesPublisher('collision_boxes')
rate = rospy.Rate(10)
i = 0
while not rospy.is_shutdown():
rate.sleep()
joints = plan[i % len(plan)]
fr.publish_joints(joints)
fr.publish_collision_boxes(joints)
collision_boxes_publisher.publish_boxes(boxes)

i += 1
if args.run_on_robot:
if i == len(plan) - 1:
while True:
inp = input('Would you like to [c]ontinue to execute the plan or [r]eplay the plan? ')
if inp in ('r', 'c'):
break
print('Please enter a valid input! Only c and r are accepted!')
if inp == 'r':
i = 0
else:
break

if args.run_on_robot:
while True:
input('Press [Enter] to run guide mode for 10s and move robot to near the strat configuration.')
fa.apply_effector_forces_torques(10, 0, 0, 0)

while True:
inp = input('Would you like to [c]ontinue or [r]erun guide mode? ')
if inp in ('r', 'c'):
break
print('Please enter a valid input! Only c and r are accepted!')

if inp == 'c':
break

print('Running plan...')
fa.goto_joints(joints_start)
forward_plan = plan[::4] # subsample plan by 1 in 4
backward_plan = forward_plan[::-1]

while True:
for joints in tqdm(forward_plan):
fa.goto_joints(joints, duration=max(float(max(joints - fa.get_joints()) / 0.1), 1))
sleep(0.1)
sleep(1)
for joints in tqdm(backward_plan):
fa.goto_joints(joints, duration=max(float(max(joints - fa.get_joints()) / 0.1), 1))
sleep(0.1)

0 comments on commit 20dd647

Please sign in to comment.