from libraries.latentsafesets.policy import CEMSafeSetPolicy
from libraries.latentsafesets.utils import utils
from libraries.latentsafesets.utils import plot_utils as pu
-from libraries.latentsafesets.utils.arg_parser import parse_args
from libraries.latentsafesets.rl_trainers import MPCTrainer
from libraries.safe import SimplePointBot as SPB
from gym.wrappers import FrameStack


def make_env(cfg):
    # create env
    if cfg.obs_type == 'pixels':
-        env = SPB(from_pixels=cfg.obs_type)
+        env = SPB(from_pixels=True)
    elif cfg.obs_type == 'states':
-        env = SPB(from_pixels=cfg.obs_type)
+        env = SPB(from_pixels=False)
    else:
        print(f'obs_type: {cfg.obs_type} is not valid, should be pixels or states')
    if cfg.frame_stack > 1:
        env = FrameStack(env, cfg.frame_stack)
    return env

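For readers skimming the diff: when `cfg.frame_stack > 1`, gym's `FrameStack` wrapper keeps the most recent `cfg.frame_stack` observations and exposes them as one stacked array. A minimal, gym-free sketch of the idea in plain numpy (shapes chosen to match the 3x64x64 pixel observations used below; illustrative only, not the gym implementation):

import numpy as np
from collections import deque

k = 4                                          # stand-in for cfg.frame_stack
frames = deque(maxlen=k)                       # holds the k most recent frames
first = np.zeros((3, 64, 64), dtype=np.uint8)  # one 64x64 RGB observation
for _ in range(k):
    frames.append(first)                       # on reset, the stack starts as k copies
stacked = np.stack(frames)                     # shape (4, 3, 64, 64): what the policy sees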
+
class Workspace:
    def __init__(self, cfg):
        self.work_dir = Path.cwd()
@@ -72,7 +72,10 @@ def __init__(self, cfg):
        self.goal_indicator = modules['gi']
        self.constraint_function = modules['constr']

-        self.replay_buffer = utils.load_replay_buffer(cfg, self.encoder)
+        if cfg.obs_type == 'pixels':
+            self.replay_buffer = utils.load_replay_buffer(cfg, self.encoder)
+        else:
+            self.replay_buffer = utils.load_replay_buffer(cfg, None)
        self.trainer = MPCTrainer(self.train_env, cfg, modules)
        self.trainer.initial_train(self.replay_buffer)
        print('Creating Policy')
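The new branch only changes which encoder, if any, accompanies the replay buffer: with pixel observations the learned encoder is passed through, with raw states nothing is. A hedged sketch of that assumed contract; `prepare_obs` and the flattening lambda below are hypothetical stand-ins, not the actual `utils.load_replay_buffer` internals:

import numpy as np

def prepare_obs(obs, encoder=None):
    # With an encoder, embed the observation; with None, keep the state as-is.
    return encoder(obs) if encoder is not None else obs

latent = prepare_obs(np.zeros((3, 64, 64)), encoder=lambda x: x.reshape(-1))
state = prepare_obs(np.array([1.0, 2.0]))   # low-dimensional state, unchanged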
@@ -110,26 +113,42 @@ def train(self):
            done = False

            # Maintain ground truth info for plotting purposes
-            movie_traj = [{'obs': obs.reshape((-1, 3, 64, 64))[0]}]
+            if self.cfg.obs_type == 'pixels':
+                movie_traj = [{'obs': obs.reshape((-1, 3, 64, 64))[0]}]
+            else:
+                image = self.train_env._state_to_image(obs)
+                movie_traj = [{'obs': image}]
            traj_rews = []
            constr_viol = False
            succ = False
            for idz in trange(self.horizon):
+                # TODO: Check if this is needed when working from state inputs
                action = self.policy.act(obs / 255)
                next_obs, reward, done, info = self.train_env.step(action)
                next_obs = np.array(next_obs)
-                movie_traj.append({'obs': next_obs.reshape((-1, 3, 64, 64))[0]})
+                if self.cfg.obs_type == 'pixels':
+                    movie_traj.append({'obs': next_obs.reshape((-1, 3, 64, 64))[0]})
+                else:
+                    image = self.train_env._state_to_image(next_obs)
+                    movie_traj.append({'obs': image})
                traj_rews.append(reward)

                constr = info['constraint']
-                transition = {'obs': obs, 'action': action, 'reward': reward,
-                              'next_obs': next_obs, 'done': done,
-                              'constraint': constr, 'safe_set': 0, 'on_policy': 1, 'discount': self.discount}
+                if self.cfg.obs_type == 'pixels':
+                    transition = {'obs': obs, 'action': action, 'reward': reward,
+                                  'next_obs': next_obs, 'done': done,
+                                  'constraint': constr, 'safe_set': 0,
+                                  'on_policy': 1, 'discount': self.discount}
+                else:
+                    transition = {'obs': obs, 'action': action, 'reward': reward,
+                                  'next_obs': next_obs, 'done': done,
+                                  'constraint': constr, 'safe_set': 0,
+                                  'on_policy': 1}
                transitions.append(transition)
                obs = next_obs
                constr_viol = constr_viol or info['constraint']
                succ = succ or reward == 0
-                if done:
+                if done:
                    break
            transitions[-1]['done'] = 1
            traj_reward = sum(traj_rews)
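`_state_to_image` keeps the movie logging alive for state observations, but its body is not part of this diff. A hypothetical sketch of what such a rasterizer could look like for a 2-D point robot; the name, signature, and canvas layout are assumptions, not SimplePointBot's actual helper:

import numpy as np

def state_to_image(state, size=64):
    # Paint the robot's (x, y) position as a white pixel on a blank RGB canvas.
    img = np.zeros((3, size, size), dtype=np.uint8)
    x, y = np.clip(np.round(np.asarray(state[:2])).astype(int), 0, size - 1)
    img[:, y, x] = 255
    return img

Relatedly, the TODO in the loop above flags that the `obs / 255` scaling only makes sense for pixel observations; state inputs would presumably need to bypass it.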
@@ -154,6 +173,7 @@ def train(self):

                rtg = rtg + transition['reward']

+
            self.replay_buffer.store_transitions(transitions)
            update_rewards.append(traj_reward)
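For reference, the `rtg` variable above appears to accumulate reward-to-go, and the `discount` stored with pixel transitions is where a discount factor would enter. A standalone sketch with made-up rewards:

# Reward-to-go computed backward: rtg[t] = r[t] + discount * rtg[t+1].
rewards = [-1.0, -1.0, 0.0]   # illustrative values, not from a real run
discount = 0.99               # plays the role of self.discount

rtg = 0.0
rtgs = []
for r in reversed(rewards):
    rtg = r + discount * rtg
    rtgs.append(rtg)
rtgs.reverse()                # rtgs[0] is the discounted return of the episode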