Skip to content

Commit 18e2a41

Browse files
authored
Integrating hydra with DQN (#201)
1 parent 796c3bc commit 18e2a41

File tree

10 files changed

+87
-111
lines changed

10 files changed

+87
-111
lines changed

examples/ddpg/ddpg.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import hydra
1111
import torch.cuda
1212
from hydra.core.config_store import ConfigStore
13-
from omegaconf import OmegaConf
1413
from torchrl.envs import ParallelEnv, EnvCreator
1514
from torchrl.envs.transforms import RewardScaling, TransformedEnv
1615
from torchrl.envs.utils import set_exploration_mode
@@ -71,9 +70,6 @@
7170
def main(cfg: "DictConfig"):
7271
from torch.utils.tensorboard import SummaryWriter
7372

74-
if cfg.config_file is not None:
75-
overriding_cfg = OmegaConf.load(cfg.config_file)
76-
cfg = OmegaConf.merge(cfg, overriding_cfg)
7773
cfg = correct_for_frame_skip(cfg)
7874

7975
if not isinstance(cfg.reward_scaling, float):
@@ -171,7 +167,7 @@ def main(cfg: "DictConfig"):
171167

172168
# remove video recorder from recorder to have matching state_dict keys
173169
if cfg.record_video:
174-
recorder_rm = TransformedEnv(recorder.env)
170+
recorder_rm = TransformedEnv(recorder.base_env)
175171
for transform in recorder.transform:
176172
if not isinstance(transform, VideoRecorder):
177173
recorder_rm.append_transform(transform)

examples/dqn/dqn.py

Lines changed: 57 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -3,77 +3,68 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import dataclasses
67
import uuid
78
from datetime import datetime
89

9-
from torchrl.envs import ParallelEnv, EnvCreator
10-
from torchrl.record import VideoRecorder
11-
12-
try:
13-
import configargparse as argparse
14-
15-
_configargparse = True
16-
except ImportError:
17-
import argparse
18-
19-
_configargparse = False
10+
import hydra
2011
import torch.cuda
12+
from hydra.core.config_store import ConfigStore
13+
from torchrl.envs import ParallelEnv, EnvCreator
2114
from torchrl.envs.transforms import RewardScaling, TransformedEnv
2215
from torchrl.modules import EGreedyWrapper
16+
from torchrl.record import VideoRecorder
2317
from torchrl.trainers.helpers.collectors import (
2418
make_collector_offpolicy,
25-
parser_collector_args_offpolicy,
19+
OffPolicyCollectorConfig,
2620
)
2721
from torchrl.trainers.helpers.envs import (
2822
correct_for_frame_skip,
2923
get_stats_random_rollout,
3024
parallel_env_constructor,
31-
parser_env_args,
3225
transformed_env_constructor,
26+
EnvConfig,
3327
)
34-
from torchrl.trainers.helpers.losses import make_dqn_loss, parser_loss_args
28+
from torchrl.trainers.helpers.losses import make_dqn_loss, LossConfig
3529
from torchrl.trainers.helpers.models import (
3630
make_dqn_actor,
37-
parser_model_args_discrete,
31+
DiscreteModelConfig,
3832
)
39-
from torchrl.trainers.helpers.recorder import parser_recorder_args
33+
from torchrl.trainers.helpers.recorder import RecorderConfig
4034
from torchrl.trainers.helpers.replay_buffer import (
4135
make_replay_buffer,
42-
parser_replay_args,
36+
ReplayArgsConfig,
4337
)
44-
from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args
45-
46-
47-
def make_args():
48-
parser = argparse.ArgumentParser()
49-
if _configargparse:
50-
parser.add_argument(
51-
"-c",
52-
"--config",
53-
required=True,
54-
is_config_file=True,
55-
help="config file path",
56-
)
57-
parser_trainer_args(parser)
58-
parser_collector_args_offpolicy(parser)
59-
parser_env_args(parser)
60-
parser_loss_args(parser, algorithm="DQN")
61-
parser_model_args_discrete(parser)
62-
parser_recorder_args(parser)
63-
parser_replay_args(parser)
64-
return parser
65-
38+
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
39+
40+
41+
config_fields = [
42+
(config_field.name, config_field.type, config_field)
43+
for config_cls in (
44+
TrainerConfig,
45+
OffPolicyCollectorConfig,
46+
EnvConfig,
47+
LossConfig,
48+
DiscreteModelConfig,
49+
RecorderConfig,
50+
ReplayArgsConfig,
51+
)
52+
for config_field in dataclasses.fields(config_cls)
53+
]
54+
Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
55+
cs = ConfigStore.instance()
56+
cs.store(name="config", node=Config)
6657

67-
parser = make_args()
6858

59+
@hydra.main(version_base=None, config_path=None, config_name="config")
60+
def main(cfg: "DictConfig"):
6961

70-
def main(args):
7162
from torch.utils.tensorboard import SummaryWriter
7263

73-
args = correct_for_frame_skip(args)
64+
cfg = correct_for_frame_skip(cfg)
7465

75-
if not isinstance(args.reward_scaling, float):
76-
args.reward_scaling = 1.0
66+
if not isinstance(cfg.reward_scaling, float):
67+
cfg.reward_scaling = 1.0
7768

7869
device = (
7970
torch.device("cpu")
@@ -84,41 +75,42 @@ def main(args):
8475
exp_name = "_".join(
8576
[
8677
"DQN",
87-
args.exp_name,
78+
cfg.exp_name,
8879
str(uuid.uuid4())[:8],
8980
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
9081
]
9182
)
9283
writer = SummaryWriter(f"dqn_logging/{exp_name}")
93-
video_tag = exp_name if args.record_video else ""
84+
video_tag = exp_name if cfg.record_video else ""
9485

9586
stats = None
96-
if not args.vecnorm and args.norm_stats:
97-
proof_env = transformed_env_constructor(args=args, use_env_creator=False)()
87+
if not cfg.vecnorm and cfg.norm_stats:
88+
proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
9889
stats = get_stats_random_rollout(
99-
args, proof_env, key="next_pixels" if args.from_pixels else None
90+
cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
10091
)
10192
# make sure proof_env is closed
10293
proof_env.close()
103-
elif args.from_pixels:
94+
elif cfg.from_pixels:
10495
stats = {"loc": 0.5, "scale": 0.5}
10596
proof_env = transformed_env_constructor(
106-
args=args, use_env_creator=False, stats=stats
97+
cfg=cfg, use_env_creator=False, stats=stats
10798
)()
10899
model = make_dqn_actor(
109100
proof_environment=proof_env,
110-
args=args,
101+
cfg=cfg,
111102
device=device,
112103
)
113104

114-
loss_module, target_net_updater = make_dqn_loss(model, args)
115-
model_explore = EGreedyWrapper(model, annealing_num_steps=args.annealing_frames).to(
105+
loss_module, target_net_updater = make_dqn_loss(model, cfg)
106+
model_explore = EGreedyWrapper(model, annealing_num_steps=cfg.annealing_frames).to(
116107
device
117108
)
109+
118110
action_dim_gsde, state_dim_gsde = None, None
119111
proof_env.close()
120112
create_env_fn = parallel_env_constructor(
121-
args=args,
113+
cfg=cfg,
122114
stats=stats,
123115
action_dim_gsde=action_dim_gsde,
124116
state_dim_gsde=state_dim_gsde,
@@ -127,26 +119,26 @@ def main(args):
127119
collector = make_collector_offpolicy(
128120
make_env=create_env_fn,
129121
actor_model_explore=model_explore,
130-
args=args,
122+
cfg=cfg,
131123
# make_env_kwargs=[
132124
# {"device": device} if device >= 0 else {}
133125
# for device in args.env_rendering_devices
134126
# ],
135127
)
136128

137-
replay_buffer = make_replay_buffer(device, args)
129+
replay_buffer = make_replay_buffer(device, cfg)
138130

139131
recorder = transformed_env_constructor(
140-
args,
132+
cfg,
141133
video_tag=video_tag,
142134
norm_obs_only=True,
143135
stats=stats,
144136
writer=writer,
145137
)()
146138

147139
# remove video recorder from recorder to have matching state_dict keys
148-
if args.record_video:
149-
recorder_rm = TransformedEnv(recorder.env)
140+
if cfg.record_video:
141+
recorder_rm = TransformedEnv(recorder.base_env)
150142
for transform in recorder.transform:
151143
if not isinstance(transform, VideoRecorder):
152144
recorder_rm.append_transform(transform)
@@ -171,10 +163,10 @@ def main(args):
171163
loss_module,
172164
recorder,
173165
target_net_updater,
174-
model_explore,
166+
model,
175167
replay_buffer,
176168
writer,
177-
args,
169+
cfg,
178170
)
179171

180172
def select_keys(batch):
@@ -191,13 +183,12 @@ def select_keys(batch):
191183

192184
trainer.register_op("batch_process", select_keys)
193185

194-
final_seed = collector.set_seed(args.seed)
195-
print(f"init seed: {args.seed}, final seed: {final_seed}")
186+
final_seed = collector.set_seed(cfg.seed)
187+
print(f"init seed: {cfg.seed}, final seed: {final_seed}")
196188

197189
trainer.train()
198190
return (writer.log_dir, trainer._log_dict, trainer.state_dict())
199191

200192

201193
if __name__ == "__main__":
202-
args = parser.parse_args()
203-
main(args)
194+
main()

examples/ppo/configs/humanoid.yaml

Lines changed: 0 additions & 13 deletions
This file was deleted.

examples/ppo/ppo.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import hydra
1111
import torch.cuda
1212
from hydra.core.config_store import ConfigStore
13-
from omegaconf import OmegaConf
1413
from torchrl.envs import ParallelEnv, EnvCreator
1514
from torchrl.envs.transforms import RewardScaling, TransformedEnv
1615
from torchrl.envs.utils import set_exploration_mode
@@ -56,10 +55,6 @@
5655
def main(cfg: "DictConfig"):
5756
from torch.utils.tensorboard import SummaryWriter
5857

59-
if cfg.config_file is not None:
60-
overriding_cfg = OmegaConf.load(cfg.config_file)
61-
cfg = OmegaConf.merge(cfg, overriding_cfg)
62-
6358
cfg = correct_for_frame_skip(cfg)
6459

6560
if not isinstance(cfg.reward_scaling, float):
@@ -142,7 +137,7 @@ def main(cfg: "DictConfig"):
142137

143138
# remove video recorder from recorder to have matching state_dict keys
144139
if cfg.record_video:
145-
recorder_rm = TransformedEnv(recorder.env)
140+
recorder_rm = TransformedEnv(recorder.base_env)
146141
for transform in recorder.transform:
147142
if not isinstance(transform, VideoRecorder):
148143
recorder_rm.append_transform(transform)

examples/redq/redq.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import hydra
1111
import torch.cuda
1212
from hydra.core.config_store import ConfigStore
13-
from omegaconf import OmegaConf
1413
from torchrl.envs import ParallelEnv, EnvCreator
1514
from torchrl.envs.transforms import RewardScaling, TransformedEnv
1615
from torchrl.envs.utils import set_exploration_mode
@@ -72,9 +71,6 @@
7271
def main(cfg: "DictConfig"):
7372
from torch.utils.tensorboard import SummaryWriter # avoid loading on each process
7473

75-
if cfg.config_file is not None:
76-
overriding_cfg = OmegaConf.load(cfg.config_file)
77-
cfg = OmegaConf.merge(cfg, overriding_cfg)
7874
cfg = correct_for_frame_skip(cfg)
7975

8076
if not isinstance(cfg.reward_scaling, float):
@@ -171,7 +167,7 @@ def main(cfg: "DictConfig"):
171167

172168
# remove video recorder from recorder to have matching state_dict keys
173169
if cfg.record_video:
174-
recorder_rm = TransformedEnv(recorder.env)
170+
recorder_rm = TransformedEnv(recorder.base_env)
175171
for transform in recorder.transform:
176172
if not isinstance(transform, VideoRecorder):
177173
recorder_rm.append_transform(transform)

examples/sac/sac.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import hydra
1111
import torch.cuda
1212
from hydra.core.config_store import ConfigStore
13-
from omegaconf import OmegaConf
1413
from torchrl.envs import ParallelEnv, EnvCreator
1514
from torchrl.envs.transforms import RewardScaling, TransformedEnv
1615
from torchrl.envs.utils import set_exploration_mode
@@ -72,10 +71,6 @@
7271
def main(cfg: "DictConfig"):
7372
from torch.utils.tensorboard import SummaryWriter
7473

75-
if cfg.config_file is not None:
76-
overriding_cfg = OmegaConf.load(cfg.config_file)
77-
cfg = OmegaConf.merge(cfg, overriding_cfg)
78-
7974
cfg = correct_for_frame_skip(cfg)
8075

8176
if not isinstance(cfg.reward_scaling, float):
@@ -168,7 +163,7 @@ def main(cfg: "DictConfig"):
168163

169164
# remove video recorder from recorder to have matching state_dict keys
170165
if cfg.record_video:
171-
recorder_rm = TransformedEnv(recorder.env)
166+
recorder_rm = TransformedEnv(recorder.base_env)
172167
for transform in recorder.transform:
173168
if not isinstance(transform, VideoRecorder):
174169
recorder_rm.append_transform(transform)

torchrl/data/tensor_specs.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1003,7 +1003,11 @@ def project(self, val: _TensorDict) -> _TensorDict:
10031003

10041004
def rand(self, shape=torch.Size([])):
10051005
return TensorDict(
1006-
{key: value.rand(shape) for key, value in self._specs.items()},
1006+
{
1007+
key: value.rand(shape)
1008+
for key, value in self._specs.items()
1009+
if value is not None
1010+
},
10071011
batch_size=shape,
10081012
)
10091013

0 commit comments

Comments
 (0)