1
+ #!/usr/bin/env python
2
+
3
+ import rospy
4
+ from random import randint
5
+ #import time
6
+ from geometry_msgs .msg import Twist , Vector3 , Point , Quaternion , Pose2D
7
+ from sensor_msgs .msg import LaserScan
8
+ from std_msgs .msg import String , Int32MultiArray , Bool
9
+ from nav_msgs .msg import OccupancyGrid
10
+ from std_srvs .srv import Empty
11
+ from scipy .spatial import distance
12
+ import copy
13
+ import pickle
14
+ from getpass import getuser
15
+ from nav_msgs .msg import Odometry
16
+ import sys
17
+ import numpy as np
18
+ #from os import getcwd
19
+ from datetime import datetime
20
+ import pandas as pd
21
+ import os
22
+
23
## Global Variables
# mobile_base velocity publisher
command_pub = None
# String command that will be sent to the robot
cmd_msg = String()

# current state. can be either "init", "drive", "halt", "turn_r", "turn_l"
cur_state = "init"

# Are we training or executing?
train = None
map_name = ""

vel = Twist()
current_position = Pose2D(-2, -2, 0)

goal_point = Pose2D(4, 4, 0)  # Weird coord transform error
path = []
# Build the map path from the current user's home directory instead of the
# hard-coded "lelliott" username, consistent with the data-output path that
# master_train() already builds with getuser().
global_map = np.genfromtxt("/home/" + getuser() + "/turtlepath/rl-ws/map.csv", delimiter=',')
# NOTE(review): `map` shadows the builtin; kept unchanged because other
# functions reference it via `global map`.
map = {}

finished = False
def reset_stage():
    """Put the simulated robot back at its starting pose (-2, -2, 0)."""
    global current_position
    current_position = Pose2D(-2, -2, 0)
def delete_q(map_name):
    """Remove any saved Q-table file for this map so training starts fresh."""
    q_file = map_name + ".sarsa"
    if not os.path.exists(q_file):
        print("The file does not exist")
    else:
        os.remove(q_file)
def master_train(map_name, goal_point, train):
    """Run rounds of training episodes until a round's final episode yields a
    short path (<= 16 actions), saving per-episode stats to CSV.

    map_name  -- basename for the <map_name>.sarsa Q-table file
    goal_point -- Pose2D target passed through to the episode trainer
    train     -- forwarded to train_sarsa_real_life (unused there)

    Returns the action path of the last episode run.
    """
    global map
    # Start from scratch: drop any previously saved Q-table for this map.
    delete_q(map_name)
    # train several times in real life to build a map
    count = 0
    accelerate = True  # NOTE(review): assigned but never read - appears unused
    episode_num = 50
    dt = datetime.now()  # timestamp reused in every output filename
    # One row per episode: [episode index, path length, crashed flag].
    training_data = [[i, 0, True] for i in range(episode_num)]
    satisfied = False
    s_count = 0  # number of episode rounds completed so far
    while not satisfied:
        s_count += 1
        count = 0  # episode index within this round (mirrors i)
        for i in range(0, episode_num):
            path, crashed = train_sarsa_real_life(map_name, goal_point, train, count, episode_num)
            print(len(path))
            reset_stage()

            training_data[int(i)][1] = len(path)
            training_data[int(i)][2] = crashed
            count += 1
            # NOTE(review): `satisfied` is overwritten on every episode, so
            # only the LAST episode of the round decides whether we stop.
            if (len(path) > 16):
                satisfied = False
            else:
                print("Took: " + str(s_count * episode_num))
                satisfied = True

        # save data each count from training to use for plots and analysis
        # NOTE(review): original indentation was lost in the paste; assumed to
        # run once per round (inside the while loop). The filename is constant
        # across rounds, so each round overwrites the same CSV - confirm intent.
        filepath = "/home/" + getuser() + "/turtlepath/rl-ws/data/"
        filename = "sarsa_" + map_name + "_" + dt.strftime("%Y-%m-%d-%H-%M-%S") + "_c" + str(count)
        np.savetxt(filepath + filename + ".csv", training_data, delimiter=",")

    return path
def retrieve_q(map_name):
    """Load the saved Q-table for this map, or an empty one if none exists.

    Returns a dict mapping state string -> {action index: Q-value}.
    """
    # Save out data to map_name.sarsa so we don't have to retrain on familiar maps
    try:
        # Use a context manager so the file handle is closed promptly;
        # the original called pickle.load(open(...)) and leaked the handle.
        with open(map_name + ".sarsa", "rb") as f:
            q = dict(pickle.load(f))
    except (OSError, IOError):
        # No saved table yet. States (x, y) each get 4 actions
        # (up, down, right, left), created lazily elsewhere - start empty.
        q = {}
    return q
def write_q(map_name, q):
    """Persist Q-table `q` to <map_name>.sarsa so later runs can reuse it."""
    # Context manager guarantees the file is flushed and closed;
    # the original pickle.dump(..., open(...)) leaked the handle.
    with open(map_name + ".sarsa", "wb") as f:
        pickle.dump(q, f)
def train_sarsa_real_life(map_name, goal_point, train, count, max_count):
    """Run one training episode from current_position toward goal_point,
    updating and persisting the Q-table for `map_name`.

    count / max_count -- position in the episode schedule; controls the
        exploration rate eps (fully random for the first two episodes,
        near-greedy on the last).
    train -- NOTE(review): accepted but never used in this function.

    Returns (path, crashed): the list of action indices taken and whether
    the episode ended in a crash.
    """
    # set date which is used in data output filenames
    dt = datetime.now()  # NOTE(review): never used in this function
    global current_position
    alpha = .6   # learning rate
    gamma = .99  # discount factor
    eps = .2     # default exploration probability

    q = retrieve_q(map_name)

    if (count < 2):
        eps = 1  # totally random walk to build a better model
    elif (count == max_count - 1):
        eps = .05  # near-greedy on the final episode

    timeout = 0  # step counter; episode is cut off at 5000 steps
    path = []
    # a is action
    # 0 is North
    # 1 is East
    # 2 is South
    # 3 is West
    # s is state and equals a point in the map
    s = to_str(current_position)
    # Choose A from S using policy derived from Q (e.g. e-greedy)
    a = e_greedy(eps, q, s)
    path.append(a)
    # Loop through episode
    crashed = False
    while timeout < 5000 and s != to_str(goal_point) and not crashed:
        # Take action A, observe R, S'
        r, s_prime, crashed = execute_rl(a, s)
        # Choose A' from S' using policy derived from Q (e.g. e-greedy)
        a_prime = e_greedy(eps, q, s_prime)

        # Lazily initialise Q rows for unseen states: all four actions at 0.
        if not (s in q):
            d = {0: 0, 1: 0, 2: 0, 3: 0}
            q[s] = d

        if not (s_prime in q):
            d = {0: 0, 1: 0, 2: 0, 3: 0}
            q[s_prime] = d

        # Q(S,A) <- Q(S,A) + alpha[R + gamma * E[Q(S',.)] - Q(S,A)]
        # NOTE(review): this is the Expected SARSA update (expectation of
        # Q(S',.) under the e-greedy policy), not classic SARSA, which would
        # use Q(S',A') directly; a_prime is only used to continue the episode.
        expectation = 0
        greedy_actions = 0  # how many actions tie for the max Q-value
        q_max = np.max([q[s_prime][action] for action in q.get(s_prime)])
        for actions in q.get(s_prime):
            if (q[s_prime][actions] == q_max):
                greedy_actions += 1
        ng_prob = eps / 4  # probability of picking each non-greedy action
        g_prob = (1 - eps) / greedy_actions + ng_prob  # per greedy action

        for actions in q.get(s_prime):
            if (q[s_prime][actions] == q_max):
                expectation += q[s_prime][actions] * g_prob
            else:
                expectation += q[s_prime][actions] * ng_prob

        q[s][a] = q[s][a] + alpha * (r + gamma * expectation - q[s][a])

        # S <- S'; A <- A';
        s = s_prime
        a = a_prime
        path.append(a)
        timeout += 1

    print(path)
    print("Crashed: " + str(crashed))
    write_q(map_name, q)
    return path, crashed
def to_str(s):
    """Serialise a pose-like object (anything with .x and .y) into the
    state-key string used by the Q-table, e.g. "X: 1 Y: 2"."""
    return f"X: {s.x} Y: {s.y}"
def decode_action(a):
    """Map an action index to its compass command string.

    0 -> "north", 1 -> "east", 2 -> "south"; anything else falls through
    to "west", matching the original if/elif chain's bare else.
    """
    return {0: "north", 1: "east", 2: "south"}.get(a, "west")
def e_greedy(eps, q, s):
    """Epsilon-greedy action selection for state s.

    Unexplored states get a uniformly random action; otherwise a random
    action with probability eps, else the greedy (argmax-Q) action.
    """
    # If we've never been here we don't know where to go, so it's random.
    if s not in q:
        return np.random.randint(0, 4)
    # Explore with probability eps...
    if np.random.random() < eps:
        return np.random.randint(0, 4)
    # ...otherwise exploit: positional argmax over the state's action
    # values in the inner dict's iteration order.
    return np.argmax([q[s][act] for act in q.get(s)])
def decode_string(s):
    """Parse a state-key string "X: <x> Y: <y>" back into integers.

    Returns (x, y).

    The original implementation indexed fixed character positions, so it
    only handled single-digit coordinates (with special cases for the
    minus signs). Splitting on whitespace gives identical results for all
    single-digit cases and additionally supports multi-digit values.
    """
    parts = s.split()  # ["X:", "<x>", "Y:", "<y>"]
    x = int(parts[1])
    y = int(parts[3])
    return x, y
def execute_rl(a, s):
    """Simulate taking action `a` from state `s` on the loaded global map.

    Moves current_position one cell (unless the target cell is occupied),
    records the observed transition in the `map` model dict, and returns
    (reward, s_prime, crashed).
    """
    global map
    global vel
    global current_position
    global finished
    prev = copy.deepcopy(current_position)  # NOTE(review): never used afterwards
    #send_command(decode_action(a))
    x, y = decode_string(s)

    # NOTE(review): these axis changes do not match decode_action's labels
    # (action 0 "north" moves +x, action 1 "east" moves -y) - presumably a
    # world-to-grid coordinate transform; confirm against map.csv layout.
    if (a == 0):
        x += 1
    elif (a == 1):
        y -= 1
    elif (a == 2):
        x -= 1
    elif (a == 3):
        y += 1

    new_position = Pose2D(x, y, 0)
    # Only move when the target cell is free; a blocked move leaves the
    # position unchanged (and earns the stay-still penalty below).
    if not (check_global(new_position)):
        current_position = new_position

    crashed = False  # NOTE(review): never set True, so the -10 branch below is dead code
    #print("Next")
    # return it's new location
    reward = -1  # base step cost
    s_prime = to_str(current_position)
    if (s == s_prime):
        reward = -5  # penalty for not moving (ran into an obstacle)
    if (crashed):
        reward = -10
    #print("Start: "+s+ " Finish: " + s_prime)
    #print(reward)
    if not (s in map):  # If Q is not initalized for our action set it to 0
        d = {}
        d[a] = {"state": s_prime, "reward": reward}
        map[s] = d
    elif not (a in map[s]):
        map[s][a] = {"state": s_prime, "reward": reward}
    else:
        # NOTE(review): identical to the elif branch above - the split is redundant.
        map[s][a] = {"state": s_prime, "reward": reward}
    return reward, s_prime, crashed
def check_global(point):
    """Return True when the occupancy-grid cell at `point` is non-zero
    (i.e. blocked).

    World coordinates are shifted by +5 and the y axis flipped to index
    the grid loaded from map.csv - presumably an 11x11 grid centred on
    the origin; confirm against the CSV.
    """
    global global_map
    col = point.x + 5
    row = 5 - point.y
    return bool(global_map[row][col])
# Guard the training run so importing this module does not trigger it.
if __name__ == "__main__":
    master_train("test1", goal_point, True)
    print("done")
0 commit comments