Integrating hydra with DQN #201

Status: Merged (44 commits, merged Jun 22, 2022)

Commits
b035c90  init (vmoens, Jun 8, 2022)
f0432f5  Added TrainConfig dataclass to replace args (Jun 1, 2022)
c67ea95  Added dataclass RecorderConfig for recorder.py (Jun 1, 2022)
6a29d09  Added ReplayArgsConfig (Jun 1, 2022)
c784e3c  Added comments for ReplayArgsConfig, TrainerConfig, and RecorderConfig (Jun 2, 2022)
00b11a0  Added OffPolicyCollectorConfig and OnPolicyCollectorConfig dataclasses (Jun 2, 2022)
60a2ed2  Added EnvConfig (Jun 2, 2022)
c5a174e  Added LossConfig and LossPPOConfig (Jun 3, 2022)
eab1650  Added ContinuousModelConfig and DiscreteModelConfig (Jun 3, 2022)
6d5d28f  Integrated hydra w/ ppo example (Jun 6, 2022)
ac4ffc9  Able to override parameters w/ yaml file provided through command line (Jun 7, 2022)
8b87b3a  PPO example working w/ hydra (Jun 7, 2022)
23a4d66  Fixed styling issues (Jun 7, 2022)
062ade0  Added hydra dependencies to setup.py (Jun 9, 2022)
93d2869  Refactored args from argparser to cfg (Jun 9, 2022)
eeb3c81  Fixed style issues (Jun 9, 2022)
0fa9d4f  Fixing more style issues (Jun 9, 2022)
05b650c  Refactor input config file to overriding_cfg (Jun 9, 2022)
c78c495  Removed import of DictConfig, now using str type hinting for DictConfig (Jun 9, 2022)
9bfacd4  Integrated hydra into SAC (Jun 9, 2022)
5375822  Integrated hydra into DDPG (Jun 9, 2022)
4488ca9  Integrated hydra into REDQ (Jun 9, 2022)
c246b6a  Merge remote-tracking branch 'upstream/bugfix_noopsreset' into hydra_… (Jun 9, 2022)
ecee59c  Integrated hydra into DQN (Jun 9, 2022)
e23dd5e  Commented out config file merging (Jun 10, 2022)
3e4a58a  Make hydra optional dependency in trainers.py (Jun 11, 2022)
5da617b  change cfg comment (Jun 11, 2022)
8094add  Removed hydra in trainers.py (Jun 14, 2022)
f986e2d  Changing hydra-core version to >=1.1 version (Jun 14, 2022)
71e96ec  Modified tests affected by hydra change (Jun 14, 2022)
d51f6f0  Merge branch 'hydra_integration' into hydra_dev_ppo (BoboBananas, Jun 14, 2022)
00cf44f  Fixing style issues in trainers.py (Jun 14, 2022)
069a0cf  Added hydra dependency to environment.yml (Jun 14, 2022)
f9156ad  Added generate_seeds import (Jun 14, 2022)
3dfa976  Merging changes from ppo branch (Jun 15, 2022)
29c5cd1  Fixing style issues (Jun 15, 2022)
edb48bf  Removing ppo yaml from git (Jun 15, 2022)
5209337  Merge branch 'hydra_integration' into hydra_dev_dqn (BoboBananas, Jun 16, 2022)
8804ed1  Delete humanoid.yaml (BoboBananas, Jun 16, 2022)
2e30435  Fixing style issues (Jun 16, 2022)
8e43e49  Merge remote-tracking branch 'origin/hydra_integration' into hydra_de… (vmoens, Jun 16, 2022)
03b8c78  Fixed dqn example, removed epsilon greedy (Jun 20, 2022)
522f0bc  BugFix: generating random values from CompositeSpec (#218) (vmoens, Jun 21, 2022)
be2312a  Merged changes from main to fix dqn example (Jun 22, 2022)

6 changes: 1 addition & 5 deletions examples/ddpg/ddpg.py
@@ -10,7 +10,6 @@
 import hydra
 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from omegaconf import OmegaConf
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import set_exploration_mode
@@ -71,9 +70,6 @@
 def main(cfg: "DictConfig"):
     from torch.utils.tensorboard import SummaryWriter

-    if cfg.config_file is not None:
-        overriding_cfg = OmegaConf.load(cfg.config_file)
-        cfg = OmegaConf.merge(cfg, overriding_cfg)
     cfg = correct_for_frame_skip(cfg)

     if not isinstance(cfg.reward_scaling, float):
@@ -171,7 +167,7 @@ def main(cfg: "DictConfig"):

     # remove video recorder from recorder to have matching state_dict keys
     if cfg.record_video:
-        recorder_rm = TransformedEnv(recorder.env)
+        recorder_rm = TransformedEnv(recorder.base_env)
         for transform in recorder.transform:
             if not isinstance(transform, VideoRecorder):
                 recorder_rm.append_transform(transform)
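
A note on the deleted block: it merged a YAML file supplied on the command line into the parsed config by hand. With Hydra this is redundant, since overrides arrive through the CLI and Hydra/OmegaConf merge them internally. For reference, a minimal sketch of the merge semantics the old code relied on (the keys here are illustrative, not the exact TorchRL schema):

from omegaconf import OmegaConf

# Base config, as Hydra would build it from the registered dataclasses.
base = OmegaConf.create({"reward_scaling": 1.0, "frame_skip": 1, "record_video": False})

# Stand-in for the deleted `OmegaConf.load(cfg.config_file)`.
overriding_cfg = OmegaConf.create({"reward_scaling": 5.0})

# merge() is right-biased: on conflict, values from overriding_cfg win.
cfg = OmegaConf.merge(base, overriding_cfg)
assert cfg.reward_scaling == 5.0 and cfg.frame_skip == 1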
123 changes: 57 additions & 66 deletions examples/dqn/dqn.py
@@ -3,77 +3,68 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import dataclasses
 import uuid
 from datetime import datetime

-from torchrl.envs import ParallelEnv, EnvCreator
-from torchrl.record import VideoRecorder
-
-try:
-    import configargparse as argparse
-
-    _configargparse = True
-except ImportError:
-    import argparse
-
-    _configargparse = False
+import hydra
 import torch.cuda
+from hydra.core.config_store import ConfigStore
+from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.modules import EGreedyWrapper
+from torchrl.record import VideoRecorder
 from torchrl.trainers.helpers.collectors import (
     make_collector_offpolicy,
-    parser_collector_args_offpolicy,
+    OffPolicyCollectorConfig,
 )
 from torchrl.trainers.helpers.envs import (
     correct_for_frame_skip,
     get_stats_random_rollout,
     parallel_env_constructor,
-    parser_env_args,
     transformed_env_constructor,
+    EnvConfig,
 )
-from torchrl.trainers.helpers.losses import make_dqn_loss, parser_loss_args
+from torchrl.trainers.helpers.losses import make_dqn_loss, LossConfig
 from torchrl.trainers.helpers.models import (
     make_dqn_actor,
-    parser_model_args_discrete,
+    DiscreteModelConfig,
 )
-from torchrl.trainers.helpers.recorder import parser_recorder_args
+from torchrl.trainers.helpers.recorder import RecorderConfig
 from torchrl.trainers.helpers.replay_buffer import (
     make_replay_buffer,
-    parser_replay_args,
+    ReplayArgsConfig,
 )
-from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args
+from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig


-def make_args():
-    parser = argparse.ArgumentParser()
-    if _configargparse:
-        parser.add_argument(
-            "-c",
-            "--config",
-            required=True,
-            is_config_file=True,
-            help="config file path",
-        )
-    parser_trainer_args(parser)
-    parser_collector_args_offpolicy(parser)
-    parser_env_args(parser)
-    parser_loss_args(parser, algorithm="DQN")
-    parser_model_args_discrete(parser)
-    parser_recorder_args(parser)
-    parser_replay_args(parser)
-    return parser
-
-
-parser = make_args()
+config_fields = [
+    (config_field.name, config_field.type, config_field)
+    for config_cls in (
+        TrainerConfig,
+        OffPolicyCollectorConfig,
+        EnvConfig,
+        LossConfig,
+        DiscreteModelConfig,
+        RecorderConfig,
+        ReplayArgsConfig,
+    )
+    for config_field in dataclasses.fields(config_cls)
+]
+Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
+cs = ConfigStore.instance()
+cs.store(name="config", node=Config)


-def main(args):
+@hydra.main(version_base=None, config_path=None, config_name="config")
+def main(cfg: "DictConfig"):
     from torch.utils.tensorboard import SummaryWriter

-    args = correct_for_frame_skip(args)
+    cfg = correct_for_frame_skip(cfg)

-    if not isinstance(args.reward_scaling, float):
-        args.reward_scaling = 1.0
+    if not isinstance(cfg.reward_scaling, float):
+        cfg.reward_scaling = 1.0

     device = (
         torch.device("cpu")
@@ -84,41 +75,42 @@ def main(args):
     exp_name = "_".join(
         [
             "DQN",
-            args.exp_name,
+            cfg.exp_name,
             str(uuid.uuid4())[:8],
             datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
         ]
     )
     writer = SummaryWriter(f"dqn_logging/{exp_name}")
-    video_tag = exp_name if args.record_video else ""
+    video_tag = exp_name if cfg.record_video else ""

     stats = None
-    if not args.vecnorm and args.norm_stats:
-        proof_env = transformed_env_constructor(args=args, use_env_creator=False)()
+    if not cfg.vecnorm and cfg.norm_stats:
+        proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
         stats = get_stats_random_rollout(
-            args, proof_env, key="next_pixels" if args.from_pixels else None
+            cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
         )
         # make sure proof_env is closed
         proof_env.close()
-    elif args.from_pixels:
+    elif cfg.from_pixels:
         stats = {"loc": 0.5, "scale": 0.5}
     proof_env = transformed_env_constructor(
-        args=args, use_env_creator=False, stats=stats
+        cfg=cfg, use_env_creator=False, stats=stats
     )()
     model = make_dqn_actor(
         proof_environment=proof_env,
-        args=args,
+        cfg=cfg,
         device=device,
     )

-    loss_module, target_net_updater = make_dqn_loss(model, args)
-    model_explore = EGreedyWrapper(model, annealing_num_steps=args.annealing_frames).to(
+    loss_module, target_net_updater = make_dqn_loss(model, cfg)
+    model_explore = EGreedyWrapper(model, annealing_num_steps=cfg.annealing_frames).to(
         device
     )

     action_dim_gsde, state_dim_gsde = None, None
     proof_env.close()
     create_env_fn = parallel_env_constructor(
-        args=args,
+        cfg=cfg,
         stats=stats,
         action_dim_gsde=action_dim_gsde,
         state_dim_gsde=state_dim_gsde,
@@ -127,26 +119,26 @@
     collector = make_collector_offpolicy(
         make_env=create_env_fn,
         actor_model_explore=model_explore,

[Review comment (Collaborator), on actor_model_explore=model_explore: let's keep model_explore]

-        args=args,
+        cfg=cfg,
         # make_env_kwargs=[
         #     {"device": device} if device >= 0 else {}
         #     for device in args.env_rendering_devices
         # ],
     )

-    replay_buffer = make_replay_buffer(device, args)
+    replay_buffer = make_replay_buffer(device, cfg)

     recorder = transformed_env_constructor(
-        args,
+        cfg,
         video_tag=video_tag,
         norm_obs_only=True,
         stats=stats,
         writer=writer,
     )()

     # remove video recorder from recorder to have matching state_dict keys
-    if args.record_video:
-        recorder_rm = TransformedEnv(recorder.env)
+    if cfg.record_video:
+        recorder_rm = TransformedEnv(recorder.base_env)
         for transform in recorder.transform:
             if not isinstance(transform, VideoRecorder):
                 recorder_rm.append_transform(transform)
@@ -171,10 +163,10 @@
         loss_module,
         recorder,
         target_net_updater,
-        model_explore,
+        model,
         replay_buffer,
         writer,
-        args,
+        cfg,
     )

     def select_keys(batch):
@@ -191,13 +183,12 @@ def select_keys(batch):

     trainer.register_op("batch_process", select_keys)

-    final_seed = collector.set_seed(args.seed)
-    print(f"init seed: {args.seed}, final seed: {final_seed}")
+    final_seed = collector.set_seed(cfg.seed)
+    print(f"init seed: {cfg.seed}, final seed: {final_seed}")

     trainer.train()
     return (writer.log_dir, trainer._log_dict, trainer.state_dict())


 if __name__ == "__main__":
-    args = parser.parse_args()
-    main(args)
+    main()
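
Taken together, the new top of dqn.py replaces the configargparse parser with Hydra structured configs: the fields of each helper dataclass are flattened into one Config schema, registered in the ConfigStore, and exposed through @hydra.main. A self-contained sketch of the same pattern — the dataclasses below are toy stand-ins for TorchRL's config classes, and a hydra-core version that accepts version_base (>= 1.2) is assumed:

import dataclasses

import hydra
from hydra.core.config_store import ConfigStore


@dataclasses.dataclass
class EnvConfig:  # toy stand-in for torchrl's EnvConfig
    env_name: str = "PongNoFrameskip-v4"
    frame_skip: int = 4


@dataclasses.dataclass
class TrainerConfig:  # toy stand-in for torchrl's TrainerConfig
    seed: int = 42
    total_frames: int = 1_000_000


# Flatten the fields of all config dataclasses into one schema, exactly as
# dqn.py does above with its seven helper dataclasses.
config_fields = [
    (f.name, f.type, f)
    for config_cls in (EnvConfig, TrainerConfig)
    for f in dataclasses.fields(config_cls)
]
Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
ConfigStore.instance().store(name="config", node=Config)


@hydra.main(version_base=None, config_path=None, config_name="config")
def main(cfg):
    # Any field can now be overridden from the CLI, e.g.:
    #   python dqn_sketch.py frame_skip=2 seed=0
    print(cfg.env_name, cfg.frame_skip, cfg.seed)


if __name__ == "__main__":
    main()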
13 changes: 0 additions & 13 deletions examples/ppo/configs/humanoid.yaml

This file was deleted.

7 changes: 1 addition & 6 deletions examples/ppo/ppo.py
@@ -10,7 +10,6 @@
 import hydra

[Review comment (Collaborator): I'm not sure where those changes in PPO come from in this PR.]

 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from omegaconf import OmegaConf
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import set_exploration_mode
@@ -56,10 +55,6 @@
 def main(cfg: "DictConfig"):
     from torch.utils.tensorboard import SummaryWriter

-    if cfg.config_file is not None:
-        overriding_cfg = OmegaConf.load(cfg.config_file)
-        cfg = OmegaConf.merge(cfg, overriding_cfg)
-
     cfg = correct_for_frame_skip(cfg)

     if not isinstance(cfg.reward_scaling, float):
@@ -142,7 +137,7 @@ def main(cfg: "DictConfig"):

     # remove video recorder from recorder to have matching state_dict keys
     if cfg.record_video:
-        recorder_rm = TransformedEnv(recorder.env)
+        recorder_rm = TransformedEnv(recorder.base_env)
         for transform in recorder.transform:
             if not isinstance(transform, VideoRecorder):
                 recorder_rm.append_transform(transform)
6 changes: 1 addition & 5 deletions examples/redq/redq.py
@@ -10,7 +10,6 @@
 import hydra
 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from omegaconf import OmegaConf
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import set_exploration_mode
@@ -72,9 +71,6 @@
 def main(cfg: "DictConfig"):
     from torch.utils.tensorboard import SummaryWriter  # avoid loading on each process

-    if cfg.config_file is not None:
-        overriding_cfg = OmegaConf.load(cfg.config_file)
-        cfg = OmegaConf.merge(cfg, overriding_cfg)
     cfg = correct_for_frame_skip(cfg)

     if not isinstance(cfg.reward_scaling, float):
@@ -171,7 +167,7 @@ def main(cfg: "DictConfig"):

     # remove video recorder from recorder to have matching state_dict keys
     if cfg.record_video:
-        recorder_rm = TransformedEnv(recorder.env)
+        recorder_rm = TransformedEnv(recorder.base_env)
         for transform in recorder.transform:
             if not isinstance(transform, VideoRecorder):
                 recorder_rm.append_transform(transform)
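
Every example calls correct_for_frame_skip(cfg) right after building the config. The idea, sketched below with an illustrative stand-in (not TorchRL's actual implementation or field list): when each env step repeats an action frame_skip times, frame-denominated budgets must be divided by frame_skip so the intended number of environment interactions is preserved.

def correct_for_frame_skip(cfg):
    # Illustrative stand-in for torchrl's helper; the field names are assumed.
    # With frame_skip=k, one env step consumes k simulator frames, so any
    # frame-denominated budget is divided by k.
    k = max(cfg.get("frame_skip", 1), 1)
    for key in ("total_frames", "annealing_frames", "frames_per_batch"):
        if key in cfg:
            cfg[key] = cfg[key] // k
    return cfg


cfg = {"frame_skip": 4, "total_frames": 1_000_000, "annealing_frames": 1_000_000}
print(correct_for_frame_skip(cfg))
# {'frame_skip': 4, 'total_frames': 250000, 'annealing_frames': 250000}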
7 changes: 1 addition & 6 deletions examples/sac/sac.py
@@ -10,7 +10,6 @@
 import hydra
 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from omegaconf import OmegaConf
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import set_exploration_mode
@@ -72,10 +71,6 @@
 def main(cfg: "DictConfig"):
     from torch.utils.tensorboard import SummaryWriter

-    if cfg.config_file is not None:
-        overriding_cfg = OmegaConf.load(cfg.config_file)
-        cfg = OmegaConf.merge(cfg, overriding_cfg)
-
     cfg = correct_for_frame_skip(cfg)

     if not isinstance(cfg.reward_scaling, float):
@@ -168,7 +163,7 @@ def main(cfg: "DictConfig"):

     # remove video recorder from recorder to have matching state_dict keys
     if cfg.record_video:
-        recorder_rm = TransformedEnv(recorder.env)
+        recorder_rm = TransformedEnv(recorder.base_env)
         for transform in recorder.transform:
             if not isinstance(transform, VideoRecorder):
                 recorder_rm.append_transform(transform)
6 changes: 5 additions & 1 deletion torchrl/data/tensor_specs.py
@@ -1003,7 +1003,11 @@ def project(self, val: _TensorDict) -> _TensorDict:

     def rand(self, shape=torch.Size([])):
         return TensorDict(
-            {key: value.rand(shape) for key, value in self._specs.items()},
+            {
+                key: value.rand(shape)
+                for key, value in self._specs.items()
+                if value is not None
+            },
             batch_size=shape,
         )
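
The one-line fix above guards CompositeSpec.rand() against sub-specs that are None (a composite spec can declare a key without committing to its spec yet). A toy reproduction of the failure mode, with a stand-in class rather than the real torchrl specs:

import torch


class ToySpec:
    # Stand-in for a torchrl tensor spec: knows how to sample a random value.
    def __init__(self, shape):
        self.shape = torch.Size(shape)

    def rand(self, batch_shape=torch.Size([])):
        return torch.rand(torch.Size(batch_shape) + self.shape)


specs = {"observation": ToySpec((3,)), "action": None}  # "action" not set yet

# Before the fix, {k: v.rand(shape) for k, v in specs.items()} raised
# AttributeError: 'NoneType' object has no attribute 'rand'.
shape = torch.Size([2])
sample = {k: v.rand(shape) for k, v in specs.items() if v is not None}
print(sample["observation"].shape)  # torch.Size([2, 3])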