AOS55/issue9 #12

Merged: 64 commits, Nov 23, 2022

Commits

ee7a9c9
Procedure to collect data from teacher
AOS55 Sep 30, 2022
3d9080e
sample over multiple seeds
AOS55 Sep 30, 2022
7686169
push sampling to datasets directory when done and option selected in …
AOS55 Sep 30, 2022
87c4345
changed obs from state input
AOS55 Sep 30, 2022
77a94f3
corrected for state observation
AOS55 Sep 30, 2022
cae00b7
load snapshot correctly
AOS55 Oct 3, 2022
39194a5
check observation is initiated correctly
AOS55 Oct 3, 2022
fcf11ac
add procedure to sample from constraints and goal_states
AOS55 Oct 3, 2022
d759e28
added skill_dim parameter
AOS55 Oct 3, 2022
c8b9db0
added pixel train param
AOS55 Oct 3, 2022
243fb4a
added configs for sampling
AOS55 Oct 3, 2022
8658285
removed extra line break
AOS55 Oct 3, 2022
fd77b64
divide by 255 to scale obs correctly
AOS55 Oct 3, 2022
3aaaadd
instantiate safe_set loading correctly
AOS55 Oct 3, 2022
3dc1a46
change to procedure used to upload_data
AOS55 Oct 3, 2022
d96f309
Merge branch 'AOS55/issue9' of github.com:AOS55/url-suite into AOS55/…
AOS55 Oct 3, 2022
c90dade
Safe Learning
AOS55 Oct 3, 2022
54d292a
added random_start option to safe environments
AOS55 Oct 7, 2022
eeb7aa6
corrected transform dict length for replay buffer
AOS55 Oct 7, 2022
1f0774d
added random_start
AOS55 Oct 7, 2022
121365f
added separate model dir if random for now
AOS55 Oct 7, 2022
527ba94
added separate dir for restart for now
AOS55 Oct 7, 2022
b7e4ab9
changed to 150 data_counts
AOS55 Oct 25, 2022
5d0e845
refactored to reflect new prioritized sampling
AOS55 Oct 25, 2022
c7f000c
printed pretrained_agent name
AOS55 Oct 25, 2022
198d9af
added optimistic forgetting rule
AOS55 Oct 25, 2022
5bf7d0b
sampling_batch updated to use smm and prioritized_sampling
AOS55 Oct 25, 2022
d9d68f1
increased number of samples to 150
AOS55 Oct 25, 2022
88886c9
increased num_updates to 500
AOS55 Oct 25, 2022
522f899
set random_start to false for pretraining
AOS55 Oct 25, 2022
9caf0f8
corrected based on new prioritized_sampling approach
AOS55 Oct 25, 2022
7aa5182
changed the number of skill dimensions to reflect z
AOS55 Oct 25, 2022
d3fca9e
added method to view goal indicator
AOS55 Oct 25, 2022
7950f4f
changed to view loss plotter
AOS55 Oct 25, 2022
7cc92c4
assert and store transitions
AOS55 Oct 25, 2022
324d99d
convert prior to correct shape in replay_buffer storage
AOS55 Oct 25, 2022
b2d1518
fixed diagram orientation
AOS55 Oct 25, 2022
a19aab8
added svb to environment types
AOS55 Oct 25, 2022
a29bf56
using simple_velocity_bot
AOS55 Oct 26, 2022
c2bfb72
correct for state representation
AOS55 Oct 26, 2022
0341ca5
add goal_behaviour
AOS55 Oct 28, 2022
573a262
added protocol for saving if ep length not 100
AOS55 Oct 31, 2022
ef8983f
working on ant environment
AOS55 Oct 31, 2022
c07c727
sampling prioritized replay
AOS55 Oct 31, 2022
eba2b4e
skill_dim
AOS55 Oct 31, 2022
199f422
remove print
AOS55 Oct 31, 2022
66b3eeb
change to sh script
AOS55 Oct 31, 2022
913b863
added log to gitignore
AOS55 Oct 31, 2022
31edba6
pass custom skill dim through smm.yaml
AOS55 Oct 31, 2022
81a7021
added .out to gitignore
AOS55 Oct 31, 2022
f2b9d46
Merge pull request #11 from AOS55/issue9-detach
AOS55 Oct 31, 2022
1160b3a
removed line that shouldn't be there
AOS55 Nov 1, 2022
f28e793
added reward function
AOS55 Nov 1, 2022
4ec240a
pulled out smm coefficients
AOS55 Nov 1, 2022
fea7dae
change p_reward when available
AOS55 Nov 6, 2022
78312e1
logging losses
AOS55 Nov 6, 2022
8e2849d
adding longer training
AOS55 Nov 6, 2022
9d2d788
added plot param and more training
AOS55 Nov 8, 2022
44b1ed9
fixed range of log_p_star
AOS55 Nov 8, 2022
74e7c6f
added smm params for pass through and log params
AOS55 Nov 8, 2022
36d8577
add pretrain ent_coef sweep
AOS55 Nov 8, 2022
44d77aa
reduced reward required, prioritized_sampling
AOS55 Nov 23, 2022
5df60b0
sampling batch over snapshot and skill dim
AOS55 Nov 23, 2022
50e6ab0
added procedure to rename files
AOS55 Nov 23, 2022
4 changes: 3 additions & 1 deletion .gitignore
@@ -10,4 +10,6 @@ data/
datasets/
libraries/gym
*outputs
models/
models/
*.log
*.out
59 changes: 49 additions & 10 deletions agents/unsupervised_learning/smm.py
@@ -147,17 +147,19 @@ def __init__(self, z_dim, sp_lr, vae_lr, vae_beta, state_ent_coef,
self.update_encoder = update_encoder

kwargs["meta_dim"] = self.z_dim
#TODO: Fix this!
self.obs_type = kwargs["obs_type"]
super().__init__(**kwargs)
# self.obs_dim is now the real obs_dim (or repr_dim) + z_dim
self.smm = SMM(self.obs_dim - z_dim,
z_dim,
hidden_dim=kwargs['hidden_dim'],
vae_beta=vae_beta,
device=kwargs['device']).to(kwargs['device'])
self.pred_optimizer = torch.optim.Adam(
self.smm.z_pred_net.parameters(), lr=sp_lr)
self.vae_optimizer = torch.optim.Adam(self.smm.vae.parameters(),
lr=vae_lr)

self.goal = (150, 75) # TODO: Fix as part of config
self.pred_optimizer = torch.optim.Adam(self.smm.z_pred_net.parameters(), lr=sp_lr)
self.vae_optimizer = torch.optim.Adam(self.smm.vae.parameters(), lr=vae_lr)

self.smm.train()

@@ -236,6 +238,21 @@ def update_pred(self, obs, z):

return metrics, h_z_s

def get_goal_p_star(self, agent_pos):
x_dist = agent_pos[:, 0] - self.goal[0]
y_dist = agent_pos[:, 1] - self.goal[1]
x_dist = x_dist.cpu().detach().numpy()
y_dist = y_dist.cpu().detach().numpy()
dist = np.linalg.norm((x_dist, y_dist), axis=0)
def _prior_distro(dist):
if dist > 1.0:
p_star = 1/dist
else:
p_star = 1.0
return p_star
p_star = np.array(list(map(_prior_distro, dist)), dtype=np.float32)
return p_star

def update(self, replay_iter, step):
metrics = dict()
if step % self.update_every_steps != 0:
@@ -244,7 +261,6 @@ def update(self, replay_iter, step):

obs, action, extr_reward, discount, next_obs, z = utils.to_torch(
batch, self.device)

obs = self.aug_and_encode(obs)
with torch.no_grad():
next_obs = self.aug_and_encode(next_obs)
@@ -258,14 +274,37 @@
h_z = np.log(self.z_dim) # One-hot z encoding
h_z *= torch.ones_like(extr_reward).to(self.device)

pred_log_ratios = self.state_ent_coef * h_s_z.detach(
) # p^*(s) is ignored, as state space dimension is inaccessible from pixel input
intr_reward = pred_log_ratios + self.latent_ent_coef * h_z + self.latent_cond_ent_coef * h_z_s.detach(
)
reward = intr_reward
pred_log_ratios = self.state_ent_coef * h_s_z.detach()

if self.obs_type=='pixels':
# p^*(s) is ignored, as state space dimension is inaccessible from pixel input
intr_reward = pred_log_ratios + self.latent_ent_coef * h_z + self.latent_cond_ent_coef * h_z_s.detach()
reward = intr_reward
else:
# p^*(s) is based on the goal hitting time
# TODO: Assumes obs is just (x, y) at front
p_star = self.get_goal_p_star(obs)
log_p_star = np.log(p_star)
log_p_star = torch.tensor(log_p_star).to(self.device)
# TODO: Check signs in this intrinsic reward function, maybe ask author
# intr_reward = log_p_star + pred_log_ratios + self.latent_ent_coef * h_z + self.latent_cond_ent_coef * h_z_s.detach()
intr_reward = log_p_star + pred_log_ratios + self.latent_ent_coef * h_z + self.latent_cond_ent_coef * h_z_s.detach()
# print(f'intr_reward: {intr_reward[0]} = p*: {100 * log_p_star[0]} + rho_pi: {pred_log_ratios[0]} +h(z): {self.latent_ent_coef * h_z[0]} + h(z|s): {self.latent_cond_ent_coef * h_z_s.detach()[0]}')
reward = intr_reward
else:
reward = extr_reward

if self.obs_type=='states' and self.reward_free:
# add reward free to states motivation
metrics['intr_reward'] = intr_reward.mean().item()
metrics['log_p_star'] = log_p_star.mean().item()
metrics['pred_log_ratios'] = pred_log_ratios.mean().item()
metrics['latent_ent_coef'] = (self.latent_ent_coef * h_z).mean().item()
metrics['latent_cond_ent_coef'] = (self.latent_cond_ent_coef * h_z_s.detach()).mean().item()
# add loss values
metrics['loss_vae'] = vae_metrics['loss_vae']
metrics['loss_pred'] = pred_metrics['loss_pred']

if self.use_tb or self.use_wandb:
metrics.update(vae_metrics)
metrics.update(pred_metrics)
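
The hunks above change SMMAgent so that, for state-based observations, the intrinsic reward includes a goal prior log p*(s) computed by get_goal_p_star. As a reading aid, here is a minimal self-contained sketch of that "states" branch; the goal coordinates and coefficient values are illustrative placeholders rather than the repository's configured settings.

```python
# Sketch only (not the repository's code): the state-space intrinsic reward
# assembled in SMMAgent.update() above, with placeholder goal and coefficients.
import numpy as np
import torch

GOAL = (150.0, 75.0)  # placeholder goal position, cf. self.goal in __init__

def goal_p_star(agent_pos: torch.Tensor) -> np.ndarray:
    """p*(s) = 1/dist to the goal, clipped to 1 within unit distance (as in get_goal_p_star)."""
    diff = agent_pos[:, :2].detach().cpu().numpy() - np.asarray(GOAL)
    dist = np.linalg.norm(diff, axis=1)
    return np.minimum(1.0, 1.0 / np.maximum(dist, 1.0)).astype(np.float32)

def intrinsic_reward(obs, h_s_z, h_z, h_z_s,
                     state_coef=1.0, latent_coef=1.0, latent_cond_coef=1.0):
    """r_intr = log p*(s) + a_s * h(s|z) + a_z * h(z) + a_{z|s} * h(z|s)."""
    log_p_star = torch.from_numpy(np.log(goal_p_star(obs)))
    return log_p_star + state_coef * h_s_z + latent_coef * h_z + latent_cond_coef * h_z_s
```

With all three coefficients at 1.0 (the defaults added to pretrain.yaml later in this diff), this matches the sum assembled in the else-branch above, the same terms the commented-out print statement was logging.
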
137 changes: 137 additions & 0 deletions collect_controlled_data.py
@@ -0,0 +1,137 @@
from libraries.latentsafesets.utils.arg_parser import parse_args
from libraries.latentsafesets.utils import utils
from libraries.latentsafesets.utils import plot_utils as pu

from pathlib import Path

import torch
import pprint
import hydra
import logging
import os
import numpy as np

from libraries.safe import SimplePointBot as SPB
from libraries.safe import SimpleVelocityBot as SVB
from libraries.safe import bottleneck_nav as BottleNeck
from libraries.latentsafesets.utils.teacher import ConstraintTeacher, SimplePointBotTeacher, SimpleVelocityBotTeacher, SimpleVelocityBotConstraintTeacher, BottleNeckTeacher, BottleNeckConstraintTeacher
log = logging.getLogger("collect")
from utils.env_constructor import make

ENV = {
'SimplePointBot' : SPB,
'SimpleVelocityBot' : SVB,
'BottleNeck' : BottleNeck
}

ENV_TEACHERS = {
'SimplePointBot' : [
SimplePointBotTeacher, ConstraintTeacher
],
'SimpleVelocityBot' : [
SimpleVelocityBotTeacher, SimpleVelocityBotConstraintTeacher
],
'BottleNeck' : [
BottleNeckTeacher, BottleNeckConstraintTeacher
]
}

DATA_DIRS = {
'SimplePointBot' : [
'SimplePointBot', 'SimplePointBot'
],
'SimpleVelocityBot' : [
'SimpleVelocityBot', 'SimpleVelocityBotConstraint'
],
'BottleNeck' : [
'BottleNeck', 'BottleNeckConstraints'
]
}

DATA_COUNTS = {
'SimplePointBot' : [
150, 150
],
'SimpleVelocityBot' : [
100, 100
],
'BottleNeck' : [
100, 100
]
}


class Workspace:

def __init__(self, cfg):
self.work_dir = Path.cwd()
self.logdir = cfg.log_dir
print(f'workspace: {self.work_dir}')
self.cfg = cfg
self.device = torch.device(cfg.device)
self.env = ENV[self.cfg.env]
if self.cfg.obs_type == 'pixels':
self.sample_env = self.env(from_pixels=True)
else:
self.sample_env = self.env(from_pixels=False)

def sample_demo_data(self):
teachers = ENV_TEACHERS[self.cfg.env]
data_dirs = DATA_DIRS[self.cfg.env]
data_counts = DATA_COUNTS[self.cfg.env]

idc = 0
for teacher, data_dir, count in list(zip(teachers, data_dirs, data_counts)):
self.generate_teacher_demo_data(data_dir, teacher, count, count_start=idc)
idc += count

def generate_teacher_demo_data(self, data_dir, teacher, count, count_start=0, noisy=False):
demo_dir = os.path.join(self.work_dir, data_dir)
if not os.path.exists(demo_dir):
os.makedirs(demo_dir)
# else:
# raise RuntimeError(f'Directory {demo_dir} already exists!')
teacher = teacher(self.sample_env, noisy=noisy)
demonstrations = []
for idc in range(count):
idc += count_start
traj = teacher.generate_trajectory()
reward = sum([frame['reward'] for frame in traj])
print(f'Trajectory {idc}, Reward {reward}')
demonstrations.append(traj)
self.save_trajectory(traj, demo_dir, idc)
# if idc < 50 and self.logdir is not None:
# pu.make_movie(traj, os.path.join(self.logdir, f'{data_dir}_{idc}.gif'))
return demonstrations

@staticmethod
def save_trajectory(traj, demo_dir, idc):
observation = []
action = []
reward = []
safe_set = []
constraint = []
on_policy = []
rtg = []
done = []
for trajectory in traj:
observation.append(trajectory['obs'])
action.append(trajectory['action'])
reward.append(trajectory['reward'])
safe_set.append(trajectory['safe_set'])
on_policy.append(trajectory['on_policy'])
constraint.append(trajectory['constraint'])
rtg.append(trajectory['rtg'])
done.append(trajectory['done'])
file_name = os.path.join(demo_dir, f'episode_{idc}_100')
np.savez_compressed(file_name, observation=observation, action=action, constraint=constraint, reward=reward,
safe_set=safe_set, on_policy=on_policy, rtg=rtg, done=done)

@hydra.main(config_path='configs/.', config_name='mpc')
def main(cfg):
from collect_controlled_data import Workspace as W
workspace = W(cfg)
workspace.sample_demo_data()

if __name__=='__main__':
main()
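
Each trajectory is written out by save_trajectory with np.savez_compressed, so the collected demonstrations can be inspected directly with np.load. A small sketch, assuming a SimplePointBot run; the path and episode index are hypothetical:

```python
# Sketch: reading back one episode written by Workspace.save_trajectory above.
# Directory and episode index are hypothetical; savez_compressed appends ".npz".
import numpy as np

data = np.load("SimplePointBot/episode_0_100.npz")
print(sorted(data.keys()))   # action, constraint, done, observation, on_policy, reward, rtg, safe_set
obs = data["observation"]    # one entry per frame of the trajectory
rtg = data["rtg"]            # reward-to-go stored alongside each frame
print(obs.shape, rtg.shape)
```
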
8 changes: 4 additions & 4 deletions configs/agent/smm.yaml
@@ -3,7 +3,7 @@ _target_: agents.unsupervised_learning.SMMAgent
name: smm

# z params
z_dim: 4 # default in codebase is 4
z_dim: ${skill_dim} # default in codebase is 4

# z discriminator params
sp_lr: 1e-3
@@ -13,9 +13,9 @@ vae_lr: 1e-2
vae_beta: 0.5

# reward params
state_ent_coef: 1.0
latent_ent_coef: 1.0
latent_cond_ent_coef: 1.0
state_ent_coef: ${state_ent_coef}
latent_ent_coef: ${latent_ent_coef}
latent_cond_ent_coef: ${latent_cond_ent_coef}

# DDPG params
reward_free: ${reward_free}
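
The ${skill_dim} and ${..._ent_coef} entries are Hydra/OmegaConf interpolations: they resolve against keys defined at the top level of the composed config (skill_dim and the smm reward params added to pretrain.yaml later in this diff), which makes them overridable from the command line without touching the agent config. A toy illustration of how such a reference resolves; the keys and values here are made up:

```python
# Toy OmegaConf interpolation example; not the project's actual config tree.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "skill_dim": 10,                     # top-level value, as set in pretrain.yaml
    "agent": {"z_dim": "${skill_dim}"},  # agent entry refers back to it
})
print(cfg.agent.z_dim)  # -> 10, resolved when the value is accessed
```
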
30 changes: 30 additions & 0 deletions configs/controlled_data.yaml
@@ -0,0 +1,30 @@
defaults:
- agent: ddpg
- override hydra/launcher: submitit_local


# env settings
env: SimplePointBot
obs_type: states
num_samples: 150
frame_stack: 1
action_repeat: 1
seed: 1

# experiment
experiment: exp

hydra:
run:
dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${teacher}
sweep:
dir: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${teacher}_${experiment}
subdir: ${hydra.job.num}
launcher:
timeout_min: 4300
cpus_per_task: 10
gpus_per_node: 1
tasks_per_node: 1
mem_gb: 160
nodes: 1
submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${teacher}_${experiment}/.slurm
1 change: 1 addition & 0 deletions configs/finetune.yaml
@@ -9,6 +9,7 @@ task: walker_stand
obs_type: states # [states, pixels]
frame_stack: 1 # only works if obs_type=pixels
action_repeat: 1 # set to 2 for pixels
skill_dim: 10
discount: 0.99
# train settings
num_train_frames: 2000010
16 changes: 8 additions & 8 deletions configs/mpc.yaml
@@ -4,22 +4,22 @@ defaults:


# Task Settings
task: SimplePointBot_goal
env: SimplePointBot
obs_type: pixels # [states, pixels]
task: SimpleVelocityBot_goal
env: SimpleVelocityBot
obs_type: states # [states, pixels]
frame_stack: 1
action_repeat: 1
discount: 1.0
num_updates: 25
num_updates: 500
log_freq: 100
plot_freq: 500
log_dir: ./logs

# Module Settings
# Encoder
enc_checkpoint: ../../../models/spb/vae.pth
d_latent: 32
d_obs: [3, 64, 64]
d_latent: 2
d_obs: [2]
enc_init_iters: 100000
enc_kl_multiplier: 1e-6
enc_data_aug: false
@@ -97,8 +97,8 @@ constr_hidden_size: 200
constr_lr: 1e-4

# Replay Buffer
data_dirs: datasets/pixels/SimplePointBot/diayn/buffer
data_counts: 1000
data_dirs: datasets/states/SimpleVelocityBot/controller/prioritized_sampling_1200
data_counts: 600
buffer_size: 35000

# Misc Settings
16 changes: 12 additions & 4 deletions configs/pretrain.yaml
@@ -9,16 +9,24 @@ domain: walker # primal task will be inferred in runtime
obs_type: states # [states, pixels]
frame_stack: 1 # only works if obs_type=pixels
action_repeat: 1 # set to 2 for pixels
skill_dim: 10

# smm reward params
state_ent_coef: 1.0
latent_ent_coef: 1.0
latent_cond_ent_coef: 1.0

skill_dim: 51
discount: 0.99
random_start: false
plot: false
# train settings
num_train_frames: 2000010
num_train_frames: 16000100
num_seed_frames: 4000
# eval
eval_every_frames: 10000
eval_every_frames: 100000
num_eval_episodes: 10
# snapshot
snapshots: [10000, 50000, 100000, 500000, 1000000, 1500000, 2000000]
snapshots: [10000, 50000, 100000, 500000, 1000000, 1500000, 2000000, 3000000, 4000000, 5000000, 6000000, 7000000, 8000000, 9000000, 10000000, 11000000, 12000000, 13000000, 14000000, 15000000, 16000000]
snapshot_dir: ../../../data/models/${obs_type}/${domain}/${agent.name}/${skill_dim}/${seed}
# replay buffer
replay_buffer_size: 1000000