Commit 87c4345

changed obs from state input
1 parent 7686169 commit 87c4345

2 files changed (+35 −15)

configs/mpc.yaml (+5 −5)

@@ -6,7 +6,7 @@ defaults:
 # Task Settings
 task: SimplePointBot_goal
 env: SimplePointBot
-obs_type: pixels # [states, pixels]
+obs_type: states # [states, pixels]
 frame_stack: 1
 action_repeat: 1
 discount: 1.0
@@ -18,8 +18,8 @@ log_dir: ./logs
 # Module Settings
 # Encoder
 enc_checkpoint: ../../../models/spb/vae.pth
-d_latent: 32
-d_obs: [3, 64, 64]
+d_latent: 2
+d_obs: [2]
 enc_init_iters: 100000
 enc_kl_multiplier: 1e-6
 enc_data_aug: false
@@ -97,8 +97,8 @@ constr_hidden_size: 200
 constr_lr: 1e-4

 # Replay Buffer
-data_dirs: datasets/pixels/SimplePointBot/diayn/buffer
-data_counts: 1000
+data_dirs: datasets/states/SimplePointBot/controller/SimplePointBot
+data_counts: 200
 buffer_size: 35000

 # Misc Settings
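
Switching to state observations touches three coupled settings: obs_type, the encoder dimensions (d_latent: 2 and d_obs: [2] instead of 32 and [3, 64, 64]), and the replay-buffer data directory and count. A minimal consistency check, assuming the config is plain YAML loadable with OmegaConf and that make_env from train_mpc.py is importable (the repo's actual config loading may differ):

# Hypothetical sanity check, not part of this commit: confirm that d_obs in
# configs/mpc.yaml matches the observation the environment actually returns
# for the chosen obs_type.
import numpy as np
from omegaconf import OmegaConf

from train_mpc import make_env  # assumes train_mpc.py is importable as a module

cfg = OmegaConf.load('configs/mpc.yaml')
env = make_env(cfg)
obs = np.array(env.reset())
assert tuple(obs.shape) == tuple(cfg.d_obs), (obs.shape, list(cfg.d_obs))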

train_mpc.py (+30 −10)

@@ -23,7 +23,6 @@
 from libraries.latentsafesets.policy import CEMSafeSetPolicy
 from libraries.latentsafesets.utils import utils
 from libraries.latentsafesets.utils import plot_utils as pu
-from libraries.latentsafesets.utils.arg_parser import parse_args
 from libraries.latentsafesets.rl_trainers import MPCTrainer
 from libraries.safe import SimplePointBot as SPB
 from gym.wrappers import FrameStack
@@ -32,15 +31,16 @@
 def make_env(cfg):
     # create env
     if cfg.obs_type=='pixels':
-        env = SPB(from_pixels=cfg.obs_type)
+        env = SPB(from_pixels=True)
     elif cfg.obs_type=='states':
-        env = SPB(from_pixels=cfg.obs_type)
+        env = SPB(from_pixels=False)
     else:
         print(f'obs_type: {cfg.obs_type} is not valid should be pixels or states')
     if cfg.frame_stack > 1:
         env = FrameStack(env, cfg.frame_stack)
     return env

+
 class Workspace:
     def __init__(self, cfg):
         self.work_dir = Path.cwd()
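
Before this fix both branches passed the obs_type string itself to from_pixels; if that flag is consumed as a boolean, a non-empty string such as 'states' is truthy, so both branches would have built a pixel environment. A small sketch that exercises the corrected branches, using SimpleNamespace purely as a stand-in for the real config object:

# Illustrative only: drive make_env with a minimal stand-in config and print
# the first observation's shape for each obs_type.
import numpy as np
from types import SimpleNamespace

from train_mpc import make_env  # assumes train_mpc.py is importable as a module

for obs_type in ('pixels', 'states'):
    cfg = SimpleNamespace(obs_type=obs_type, frame_stack=1)
    env = make_env(cfg)
    obs = np.array(env.reset())
    print(obs_type, obs.shape)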
@@ -72,7 +72,10 @@ def __init__(self, cfg):
         self.goal_indicator = modules['gi']
         self.constraint_function = modules['constr']

-        self.replay_buffer = utils.load_replay_buffer(cfg, self.encoder)
+        if cfg.obs_type == 'pixels':
+            self.replay_buffer = utils.load_replay_buffer(cfg, self.encoder)
+        else:
+            self.replay_buffer = utils.load_replay_buffer(cfg, None)
         self.trainer = MPCTrainer(self.train_env, cfg, modules)
         self.trainer.initial_train(self.replay_buffer)
         print('Creating Policy')
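
The two branches differ only in whether the VAE encoder is handed to the replay-buffer loader. An equivalent compact form, resting on the same assumption the commit already makes, namely that load_replay_buffer accepts None when raw-state observations need no encoding:

# Equivalent compact form of the branch added above.
encoder = self.encoder if cfg.obs_type == 'pixels' else None
self.replay_buffer = utils.load_replay_buffer(cfg, encoder)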
@@ -110,26 +113,42 @@ def train(self):
             done = False

             # Maintain ground truth info for plotting purposes
-            movie_traj = [{'obs': obs.reshape((-1, 3, 64, 64))[0]}]
+            if self.cfg.obs_type == 'pixels':
+                movie_traj = [{'obs': obs.reshape((-1, 3, 64, 64))[0]}]
+            else:
+                image = self.train_env._state_to_image(obs)
+                movie_traj = [{'obs': image}]
             traj_rews = []
             constr_viol = False
             succ = False
             for idz in trange(self.horizon):
+                # TODO: Check if this is needed when working from state inputs
                 action = self.policy.act(obs / 255)
                 next_obs, reward, done, info = self.train_env.step(action)
                 next_obs = np.array(next_obs)
-                movie_traj.append({'obs': next_obs.reshape((-1, 3, 64, 64))[0]})
+                if self.cfg.obs_type == 'pixels':
+                    movie_traj.append({'obs': next_obs.reshape((-1, 3, 64, 64))[0]})
+                else:
+                    image = self.train_env._state_to_image(obs)
+                    movie_traj.append({'obs': image})
                 traj_rews.append(reward)

                 constr = info['constraint']
-                transition = {'obs': obs, 'action': action, 'reward': reward,
-                              'next_obs': next_obs, 'done': done,
-                              'constraint': constr, 'safe_set': 0, 'on_policy': 1, 'discount': self.discount}
+                if self.cfg.obs_type == 'pixels':
+                    transition = {'obs': obs, 'action': action, 'reward': reward,
+                                  'next_obs': next_obs, 'done': done,
+                                  'constraint': constr, 'safe_set': 0,
+                                  'on_policy': 1, 'discount': self.discount}
+                else:
+                    transition = {'obs': obs, 'action': action, 'reward': reward,
+                                  'next_obs': next_obs, 'done': done,
+                                  'constraint': constr, 'safe_set': 0,
+                                  'on_policy': 1}
                 transitions.append(transition)
                 obs = next_obs
                 constr_viol = constr_viol or info['constraint']
                 succ = succ or reward == 0
-            if done:
+                if done:
                     break
             transitions[-1]['done'] = 1
             traj_reward = sum(traj_rews)
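
Two details in this hunk may deserve a follow-up: the else branch inside the step loop converts obs rather than next_obs, so the saved frame trails the transition by one step, and the committed TODO already questions whether the obs / 255 normalization makes sense for raw states. A hypothetical helper (not part of the repo) that would give the rollout loop a single code path for building frames:

# Hypothetical helper, not in this commit: produce a plottable frame for either
# observation type so movie_traj is built the same way in both branches.
def obs_to_frame(env, obs, obs_type):
    if obs_type == 'pixels':
        return obs.reshape((-1, 3, 64, 64))[0]
    # For raw states, fall back to the env's state-to-image rendering,
    # the same _state_to_image call the commit introduces.
    return env._state_to_image(obs)

# Usage sketch inside the step loop:
#     movie_traj.append({'obs': obs_to_frame(self.train_env, next_obs, self.cfg.obs_type)})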
@@ -154,6 +173,7 @@ def train(self):

                 rtg = rtg + transition['reward']

+
             self.replay_buffer.store_transitions(transitions)
             update_rewards.append(traj_reward)
