Skip to content

Commit 01040eb

Browse files
author
Logan
committed
sarsa updates
1 parent c050775 commit 01040eb

File tree

6 files changed

+992
-21
lines changed

6 files changed

+992
-21
lines changed
-2 Bytes
Loading

rl-ws/src/ml_long_project/src/control_node.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
rear_scan = None
2525
l_scan = None
2626
r_scan = None
27+
off_left = None
28+
off_right = None
2729
# Twist command that will be sent to the robot
2830
cmd = Twist()
2931

@@ -50,7 +52,7 @@
5052
command_list = []
5153

5254
def check_scan(scan_msg):
53-
global fwd_scan, rear_scan, r_scan, l_scan
55+
global fwd_scan, rear_scan, r_scan, l_scan,off_right,off_left
5456
# scan_msg.ranges is an array of 640 elements representing
5557
# distance measurements in a full circle around the robot (0=fwd, CCW?)
5658

@@ -62,6 +64,9 @@ def check_scan(scan_msg):
6264
r45_scan = scan_discrete(scan_msg.ranges[560]) # ~45 degrees right
6365
r_scan = scan_discrete(scan_msg.ranges[480]) # 90 degrees right
6466

67+
off_left = scan_discrete(scan_msg.ranges[40])
68+
off_right = scan_discrete(scan_msg.ranges[600])
69+
6570
# group and publish the relevant scan ranges
6671
scan_group = Int32MultiArray()
6772
scan_group.data = [fwd_scan, rear_scan, l45_scan, l_scan, r45_scan, r_scan]
@@ -155,11 +160,12 @@ def reset_stage():
155160
current_cmd = ""
156161

157162
def is_cmd_valid(str_cmd):
163+
global fwd_scan, rear_scan, r_scan, l_scan,off_right,off_left
158164
# check to ensure a given command (string) will not cause
159165
# the robot to move to an occupied vertex.
160166
if str_cmd == "forward":
161167
print("Trying to Move Forward", fwd_scan)
162-
return fwd_scan > 1
168+
return fwd_scan > 1.75 and off_right > 1 and off_left > 1
163169
elif str_cmd == "turn_left" or str_cmd == "turn_right" or str_cmd == "turn_180":
164170
print("Turning in place")
165171
return True
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
#!/usr/bin/env python
2+
3+
import rospy
4+
from random import randint
5+
#import time
6+
from geometry_msgs.msg import Twist, Vector3, Point, Quaternion, Pose2D
7+
from sensor_msgs.msg import LaserScan
8+
from std_msgs.msg import String, Int32MultiArray, Bool
9+
from nav_msgs.msg import OccupancyGrid
10+
from std_srvs.srv import Empty
11+
from scipy.spatial import distance
12+
import copy
13+
import pickle
14+
from getpass import getuser
15+
from nav_msgs.msg import Odometry
16+
import sys
17+
import numpy as np
18+
#from os import getcwd
19+
from datetime import datetime
20+
import pandas as pd
21+
import os
22+
23+
## Global Variables
24+
# mobile_base velocity publisher
25+
command_pub = None
26+
# String command that will be sent to the robot
27+
cmd_msg = String()
28+
29+
# current time
30+
#cur_time = 0
31+
# current state. can be either "init", "drive, "halt", "turn_r", "turn_l"
32+
cur_state = "init"
33+
34+
# Are we training or executing?
35+
train = None
36+
map_name = ""
37+
38+
vel = Twist()
39+
current_position = Pose2D(-2,-2,0)
40+
41+
goal_point = Pose2D(4,4,0) # Weird coord transform error
42+
path = []
43+
global_map =np.genfromtxt("/home/lelliott/turtlepath/rl-ws/map.csv", delimiter=',')
44+
map = {}
45+
46+
finished = False
47+
48+
49+
def reset_stage():
    """Put the simulated robot back on its starting grid cell (-2, -2)."""
    global current_position
    start_x, start_y, start_heading = -2, -2, 0
    current_position = Pose2D(start_x, start_y, start_heading)
52+
53+
def delete_q(map_name):
54+
if os.path.exists(map_name + ".sarsa"):
55+
os.remove(map_name + ".sarsa")
56+
else:
57+
print("The file does not exist")
58+
59+
def master_train(map_name,goal_point,train):
60+
global map
61+
delete_q(map_name)
62+
# train several times in real life to build a map
63+
count = 0
64+
accelerate = True
65+
episode_num = 50
66+
dt = datetime.now()
67+
training_data = [[i,0, True] for i in range(episode_num)]
68+
satisfied = False
69+
s_count = 0
70+
while not satisfied:
71+
s_count +=1
72+
count = 0
73+
for i in range(0,episode_num):
74+
path, crashed = train_sarsa_real_life(map_name,goal_point,train, count,episode_num)
75+
print(len(path))
76+
reset_stage()
77+
78+
training_data[int(i)][1] = len(path)
79+
training_data[int(i)][2] = crashed
80+
count += 1
81+
if(len(path) > 16):
82+
satisfied = False
83+
else:
84+
print("Took: " + str(s_count*episode_num))
85+
satisfied = True
86+
87+
# save data each count from training to use for plots and analysis
88+
filepath = "/home/"+getuser()+"/turtlepath/rl-ws/data/"
89+
filename = "sarsa_" + map_name + "_" + dt.strftime("%Y-%m-%d-%H-%M-%S") + "_c" + str(count)
90+
np.savetxt(filepath + filename + ".csv", training_data, delimiter=",")
91+
92+
93+
return path
94+
95+
def retrieve_q(map_name):
96+
# Save out data to map_name.sarsa so we don't have to retrain on familiar maps
97+
#pickle.dump(map, open("sarsa_data/"+map_name + ".sarsa","wb"))
98+
try:
99+
q = pickle.load( open(map_name + ".sarsa", "rb" ) )
100+
q = dict(q)
101+
except (OSError, IOError) as e:
102+
# There are states(x,y) and 4 Actions (up, down, right, left)
103+
q = {}
104+
return q
105+
106+
def write_q(map_name, q):
107+
pickle.dump(q, open(map_name + ".sarsa","wb"))
108+
109+
110+
def train_sarsa_real_life(map_name, goal_point, train, count, max_count):
111+
# set date which is used in data output filenames
112+
dt = datetime.now()
113+
global current_position
114+
alpha = .6
115+
gamma = .99
116+
eps = .2
117+
118+
q = retrieve_q(map_name)
119+
120+
if(count < 2):
121+
eps = 1 # totally random walk to build a better model
122+
elif(count == max_count -1):
123+
eps = .05
124+
125+
timeout = 0
126+
path = []
127+
# a is action
128+
# 0 is North
129+
# 1 is East
130+
# 2 is South
131+
# 3 is West
132+
# s is state and equals a point in the map
133+
s = to_str(current_position)
134+
# Choose A from S using policy dervied from Q (e.g. e-greedy)
135+
a = e_greedy(eps, q, s)
136+
path.append(a)
137+
# Loop through episode
138+
crashed = False
139+
while timeout < 5000 and s != to_str(goal_point) and not crashed:
140+
# Take action A, observe R,S'
141+
r, s_prime, crashed = execute_rl(a,s)
142+
# Choose A' from S' using policy dervied from Q (e.g. e-greedy)
143+
a_prime = e_greedy(eps, q, s_prime)
144+
145+
if not (s in q):
146+
d={0:0, 1:0, 2:0, 3:0}
147+
q[s] = d
148+
149+
if not (s_prime in q):
150+
d={0:0, 1:0, 2:0, 3:0}
151+
q[s_prime] = d
152+
153+
# Q(S,A) <- Q(S,A) + alpha[R+gamma * Q(S',A')- Q(S,A)]
154+
expectation = 0
155+
greedy_actions = 0
156+
q_max = np.max([q[s_prime][action] for action in q.get(s_prime)])
157+
for actions in q.get(s_prime):
158+
if(q[s_prime][actions] == q_max):
159+
greedy_actions +=1
160+
ng_prob = eps/4
161+
g_prob = (1-eps)/greedy_actions + ng_prob
162+
163+
for actions in q.get(s_prime):
164+
if(q[s_prime][actions] == q_max):
165+
expectation += q[s_prime][actions] * g_prob
166+
else:
167+
expectation += q[s_prime][actions] * ng_prob
168+
169+
q[s][a] = q[s][a] + alpha * (r+gamma*expectation - q[s][a])
170+
171+
# S<- S'; A<-A';
172+
s = s_prime
173+
a = a_prime
174+
path.append(a)
175+
timeout+=1
176+
177+
print(path)
178+
print("Crashed: " + str(crashed))
179+
write_q(map_name,q)
180+
return path, crashed
181+
182+
def to_str(s):
183+
return("X: "+ str(s.x) + " "+"Y: "+ str(s.y))
184+
185+
def decode_action(a):
186+
# a is action
187+
# 0 is North
188+
# 1 is East
189+
# 2 is South
190+
# 3 is West
191+
192+
if a == 0:
193+
action = "north"
194+
elif a == 1:
195+
action = "east"
196+
elif a == 2:
197+
action = "south"
198+
else:
199+
action = "west"
200+
return action
201+
202+
def e_greedy(eps, q, s):
203+
# If we have been in this state try to epsilon greedy it
204+
if s in q:
205+
p = np.random.random()
206+
if p < eps:
207+
a = np.random.randint(0,4)
208+
else:
209+
a = np.argmax([q[s][action] for action in q.get(s)])
210+
# If we've never been here we don't know where to go, so it's random
211+
else:
212+
a = np.random.randint(0,4)
213+
return a
214+
215+
def decode_string(s):
216+
if(s[8] == '-'):
217+
y = -1 * int(s[9])
218+
elif(len(s) >= 10 and s[9] == '-'):
219+
y = -1 * int(s[10])
220+
elif(s[3] == '-'):
221+
y = int(s[9])
222+
else:
223+
y = int(s[8])
224+
225+
if(s[3] == '-'):
226+
x = -1 * int(s[4])
227+
else:
228+
x = int(s[3])
229+
230+
return x,y
231+
232+
def execute_rl(a,s):
233+
global map
234+
global vel
235+
global current_position
236+
global finished
237+
prev = copy.deepcopy(current_position)
238+
#send_command(decode_action(a))
239+
x,y = decode_string(s)
240+
241+
if(a == 0):
242+
x +=1
243+
elif(a == 1):
244+
y -= 1
245+
elif(a == 2):
246+
x -=1
247+
elif(a ==3):
248+
y +=1
249+
250+
new_position = Pose2D(x,y,0)
251+
if not (check_global(new_position)):
252+
current_position = new_position
253+
254+
crashed = False
255+
#print("Next")
256+
# return it's new location
257+
reward = -1
258+
s_prime = to_str(current_position)
259+
if(s == s_prime):
260+
reward = -5
261+
if(crashed):
262+
reward = -10
263+
#print("Start: "+s+ " Finish: " + s_prime)
264+
#print(reward)
265+
if not (s in map): # If Q is not initalized for our action set it to 0
266+
d={}
267+
d[a] = {"state":s_prime,"reward":reward}
268+
map[s] = d
269+
elif not (a in map[s]):
270+
map[s][a] = {"state":s_prime,"reward":reward}
271+
else:
272+
map[s][a] = {"state":s_prime,"reward":reward}
273+
return reward, s_prime, crashed
274+
275+
def check_global(point):
276+
global global_map
277+
x = point.x + 5
278+
y = (point.y*-1 + 5)
279+
return bool(global_map[y][x])
280+
281+
master_train("test1", goal_point, True)
282+
print("done")

0 commit comments

Comments
 (0)