-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0d8ca32
commit 20dd647
Showing
2 changed files
with
294 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
from time import time | ||
import numpy as np | ||
from scipy.optimize import minimize, LinearConstraint, NonlinearConstraint | ||
|
||
from kdtree import KDTree | ||
from franka_robot import FrankaRobot | ||
|
||
|
||
class SimpleTree: | ||
|
||
def __init__(self, dim): | ||
self._parents_map = {} | ||
self._kd = KDTree(dim) | ||
|
||
def insert_new_node(self, point, parent=None): | ||
node_id = self._kd.insert(point) | ||
self._parents_map[node_id] = parent | ||
|
||
return node_id | ||
|
||
def get_parent(self, child_id): | ||
return self._parents_map[child_id] | ||
|
||
def get_point(self, node_id): | ||
return self._kd.get_node(node_id).point | ||
|
||
def get_nearest_node(self, point): | ||
return self._kd.find_nearest_point(point) | ||
|
||
def construct_path_to_root(self, leaf_node_id): | ||
path = [] | ||
node_id = leaf_node_id | ||
while node_id is not None: | ||
path.append(self.get_point(node_id)) | ||
node_id = self.get_parent(node_id) | ||
|
||
return path | ||
|
||
|
||
class RRTConnect: | ||
|
||
def __init__(self, fr, is_in_collision): | ||
self._fr = fr | ||
self._is_in_collision = is_in_collision | ||
|
||
self._q_step_size = 0.02 | ||
self._connect_dist = 0.8 | ||
self._max_n_nodes = int(1e5) | ||
|
||
def sample_valid_joints(self): | ||
q = np.random.random(self._fr.num_dof) * (self._fr.joint_limits_high - self._fr.joint_limits_low) + self._fr.joint_limits_low | ||
return q | ||
|
||
def project_to_constraint(self, q0, constraint): | ||
def f(q): | ||
return constraint(q)[0] | ||
|
||
def df(q): | ||
c_grad = constraint(q)[1] | ||
q_grad = self._fr.jacobian(q).T @ c_grad | ||
return q_grad | ||
|
||
def c_f(q): | ||
diff_q = q - q0 | ||
return diff_q @ diff_q | ||
|
||
def c_df(q): | ||
diff_q = q - q0 | ||
return 0.5 * diff_q | ||
|
||
c_joint_limits = LinearConstraint(np.eye(len(q0)), self._fr.joint_limits_low, self._fr.joint_limits_high) | ||
c_close_to_q0 = NonlinearConstraint(c_f, 0, self._q_step_size ** 2, jac=c_df) | ||
|
||
res = minimize(f, q0, jac=df, method='SLSQP', tol=0.1, | ||
constraints=(c_joint_limits, c_close_to_q0)) | ||
|
||
return res.x | ||
|
||
def _is_seg_valid(self, q0, q1): | ||
qs = np.linspace(q0, q1, int(np.linalg.norm(q1 - q0) / self._q_step_size)) | ||
for q in qs: | ||
if self._is_in_collision(q): | ||
return False | ||
return True | ||
|
||
def extend(self, tree_0, tree_1, constraint=None): | ||
''' | ||
TODO: Implement extend for RRT Connect | ||
- Only perform self.project_to_constraint if constraint is not None | ||
- Use self._is_seg_valid, self._q_step_size, self._connect_dist | ||
''' | ||
target_reached = False | ||
node_id_new = None | ||
is_collision = True | ||
|
||
while is_collision: | ||
q_sample = self.sample_valid_joints() | ||
|
||
node_id_near = tree.get_nearest_node(q_sample)[0] | ||
q_near = tree.get_point(node_id_near) | ||
q_new = q_near + min(self._q_step_size, np.linalg.norm(q_sample - q_near)) * (q_sample - q_near) / np.linalg.norm(q_sample - q_near) | ||
|
||
q_new = self.project_to_constraint(q_new, constraint) | ||
|
||
if self._is_in_collision(q_new): | ||
is_collision = True | ||
continue | ||
else: | ||
is_collision = False | ||
|
||
# Add the q_new as vertex, and the edge between q_new and q_near as edge to the tree | ||
node_id_new = tree_0.insert_new_node(q_new, node_id_near) | ||
node_id_1 = tree_1.get_nearest_node(q_new)[0] | ||
q_1 = tree.get_point(node_id_1) | ||
# if the new state is close to the target state, then we reached the target state | ||
if np.linalg.norm(q_new - q_1) < self._connect_dist and self._is_seg_valid(q_new, q_1): | ||
target_reached = True | ||
|
||
return target_reached, node_id_new, node_id_1 | ||
|
||
def plan(self, q_start, q_target, constraint=None): | ||
tree_0 = SimpleTree(len(q_start)) | ||
tree_0.insert_new_node(q_start) | ||
|
||
tree_1 = SimpleTree(len(q_target)) | ||
tree_1.insert_new_node(q_target) | ||
|
||
q_start_is_tree_0 = True | ||
|
||
s = time() | ||
for n_nodes_sampled in range(self._max_n_nodes): | ||
if n_nodes_sampled > 0 and n_nodes_sampled % 100 == 0: | ||
print('RRT: Sampled {} nodes'.format(n_nodes_sampled)) | ||
|
||
reached_target, node_id_new, node_id_1 = self.extend(tree_0, tree_1, constraint) | ||
|
||
if reached_target: | ||
break | ||
|
||
q_start_is_tree_0 = not q_start_is_tree_0 | ||
tree_0, tree_1 = tree_1, tree_0 | ||
|
||
print('RRT: Sampled {} nodes in {:.2f}s'.format(n_nodes_sampled, time() - s)) | ||
|
||
if not q_start_is_tree_0: | ||
tree_0, tree_1 = tree_1, tree_0 | ||
|
||
if reached_target: | ||
tree_0_backward_path = tree_0.construct_path_to_root(node_id_new) | ||
tree_1_forward_path = tree_1.construct_path_to_root(node_id_1) | ||
|
||
q0 = tree_0_backward_path[0] | ||
q1 = tree_1_forward_path[0] | ||
tree_01_connect_path = np.linspace(q0, q1, int(np.linalg.norm(q1 - q0) / self._q_step_size))[1:].tolist() | ||
|
||
path = tree_0_backward_path[::-1] + tree_01_connect_path + tree_1_forward_path | ||
print('RRT: Found a path! Path length is {}.'.format(len(path))) | ||
else: | ||
path = [] | ||
print('RRT: Was not able to find a path!') | ||
|
||
return np.array(path).tolist() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import argparse | ||
import numpy as np | ||
from time import sleep | ||
import rospy | ||
from tqdm import tqdm | ||
from frankapy import FrankaArm | ||
|
||
from franka_robot import FrankaRobot | ||
from collision_boxes_publisher import CollisionBoxesPublisher | ||
from rrt_connect import RRTConnect | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--run_on_robot', action='store_true') | ||
parser.add_argument('--seed', '-s', type=int, default=0) | ||
args = parser.parse_args() | ||
|
||
np.random.seed(args.seed) | ||
fr = FrankaRobot() | ||
|
||
if args.run_on_robot: | ||
fa = FrankaArm() | ||
else: | ||
rospy.init_node('rrt') | ||
|
||
''' | ||
TODO: Replace obstacle box w/ the box specs in your workspace: | ||
[x, y, z, r, p, y, sx, sy, sz] | ||
''' | ||
boxes = np.array([ | ||
# obstacle | ||
[0, 0, 0, 0, 0, 0, 0, 0, 0], | ||
# sides | ||
[0.15, 0.46, 0.5, 0, 0, 0, 1.2, 0.01, 1.1], | ||
[0.15, -0.46, 0.5, 0, 0, 0, 1.2, 0.01, 1.1], | ||
# back | ||
[-0.41, 0, 0.5, 0, 0, 0, 0.01, 1, 1.1], | ||
# front | ||
[0.75, 0, 0.5, 0, 0, 0, 0.01, 1, 1.1], | ||
# top | ||
[0.2, 0, 1, 0, 0, 0, 1.2, 1, 0.01], | ||
# bottom | ||
[0.2, 0, -0.05, 0, 0, 0, 1.2, 1, 0.01] | ||
]) | ||
def is_in_collision(joints): | ||
for box in boxes: | ||
if fr.check_box_collision(joints, box): | ||
return True | ||
return False | ||
|
||
desired_ee_rp = fr.ee(fr.home_joints)[3:5] | ||
def ee_upright_constraint(q): | ||
''' | ||
TODO: Implement constraint function and its gradient. | ||
This constraint should enforce the end-effector stays upright. | ||
Hint: Use the roll and pitch angle in desired_ee_rp. The end-effector is upright in its home state. | ||
Input: | ||
q - a joint configuration | ||
Output: | ||
err - a non-negative scalar that is 0 when the constraint is satisfied | ||
grad - a vector of length 6, where the ith element is the derivative of err w.r.t. the ith element of ee | ||
''' | ||
ee = fr.ee(q) | ||
err = np.sum((np.asarray(desired_ee_rp) - np.asarray(ee[3:5]))**2) | ||
grad = np.asarray([0, 0, 0, 2*(ee[3]-desired_ee_rp[0]), 2*(ee[4]-desired_ee_rp[1]), 0]) | ||
return err, grad | ||
|
||
''' | ||
TODO: Fill in start and target joint positions | ||
''' | ||
joints_start = None | ||
joints_target = None | ||
|
||
rrtc = RRTConnect(fr, is_in_collision) | ||
constraint = None # ee_upright_constraint | ||
plan = rrtc.plan(joints_start, joints_target, constraint) | ||
|
||
collision_boxes_publisher = CollisionBoxesPublisher('collision_boxes') | ||
rate = rospy.Rate(10) | ||
i = 0 | ||
while not rospy.is_shutdown(): | ||
rate.sleep() | ||
joints = plan[i % len(plan)] | ||
fr.publish_joints(joints) | ||
fr.publish_collision_boxes(joints) | ||
collision_boxes_publisher.publish_boxes(boxes) | ||
|
||
i += 1 | ||
if args.run_on_robot: | ||
if i == len(plan) - 1: | ||
while True: | ||
inp = input('Would you like to [c]ontinue to execute the plan or [r]eplay the plan? ') | ||
if inp in ('r', 'c'): | ||
break | ||
print('Please enter a valid input! Only c and r are accepted!') | ||
if inp == 'r': | ||
i = 0 | ||
else: | ||
break | ||
|
||
if args.run_on_robot: | ||
while True: | ||
input('Press [Enter] to run guide mode for 10s and move robot to near the strat configuration.') | ||
fa.apply_effector_forces_torques(10, 0, 0, 0) | ||
|
||
while True: | ||
inp = input('Would you like to [c]ontinue or [r]erun guide mode? ') | ||
if inp in ('r', 'c'): | ||
break | ||
print('Please enter a valid input! Only c and r are accepted!') | ||
|
||
if inp == 'c': | ||
break | ||
|
||
print('Running plan...') | ||
fa.goto_joints(joints_start) | ||
forward_plan = plan[::4] # subsample plan by 1 in 4 | ||
backward_plan = forward_plan[::-1] | ||
|
||
while True: | ||
for joints in tqdm(forward_plan): | ||
fa.goto_joints(joints, duration=max(float(max(joints - fa.get_joints()) / 0.1), 1)) | ||
sleep(0.1) | ||
sleep(1) | ||
for joints in tqdm(backward_plan): | ||
fa.goto_joints(joints, duration=max(float(max(joints - fa.get_joints()) / 0.1), 1)) | ||
sleep(0.1) |