
Commit 336b981

Hydra integration (#202)
1 parent b162180 commit 336b981


18 files changed: +1359, -1726 lines


.circleci/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -24,4 +24,5 @@ dependencies:
 - scipy
 - dm_control
 - mujoco_py
+- hydra-core
 - pyrender

.circleci/unittest/linux_optdeps/scripts/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -14,3 +14,4 @@ dependencies:
 - expecttest
 - pyyaml
 - scipy
+- hydra-core

.circleci/unittest/linux_stable/scripts/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -25,4 +25,5 @@ dependencies:
 - scipy
 - dm_control
 - mujoco_py
+- hydra-core
 - pyrender
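
Each of the three CI conda environments gains an unpinned hydra-core dependency; it supplies the hydra.main entry point and the ConfigStore registry that the refactored DDPG example below relies on. A minimal sanity check, assuming one of the environments above has been created and activated, might look like:

# Hypothetical smoke test for the new CI dependency.
import hydra
from hydra.core.config_store import ConfigStore

print(hydra.__version__)       # unpinned, so whichever recent hydra-core release conda resolves
print(ConfigStore.instance())  # the singleton registry the updated ddpg.py stores its config in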

examples/ddpg/ddpg.py

Lines changed: 59 additions & 72 deletions

@@ -3,70 +3,57 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import dataclasses
 import uuid
 from datetime import datetime
 
-from torchrl.envs import ParallelEnv, EnvCreator
-from torchrl.envs.utils import set_exploration_mode
-from torchrl.record import VideoRecorder
-
-try:
-    import configargparse as argparse
-
-    _configargparse = True
-except ImportError:
-    import argparse
-
-    _configargparse = False
-
+import hydra
 import torch.cuda
+from hydra.core.config_store import ConfigStore
+from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
+from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
+from torchrl.record import VideoRecorder
 from torchrl.trainers.helpers.collectors import (
     make_collector_offpolicy,
-    parser_collector_args_offpolicy,
+    OffPolicyCollectorConfig,
 )
 from torchrl.trainers.helpers.envs import (
     correct_for_frame_skip,
     get_stats_random_rollout,
     parallel_env_constructor,
-    parser_env_args,
     transformed_env_constructor,
+    EnvConfig,
 )
-from torchrl.trainers.helpers.losses import make_ddpg_loss, parser_loss_args
+from torchrl.trainers.helpers.losses import make_ddpg_loss, LossConfig
 from torchrl.trainers.helpers.models import (
     make_ddpg_actor,
-    parser_model_args_continuous,
+    DDPGModelConfig,
 )
-from torchrl.trainers.helpers.recorder import parser_recorder_args
+from torchrl.trainers.helpers.recorder import RecorderConfig
 from torchrl.trainers.helpers.replay_buffer import (
     make_replay_buffer,
-    parser_replay_args,
+    ReplayArgsConfig,
 )
-from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args
-
-
-def make_args():
-    parser = argparse.ArgumentParser()
-    if _configargparse:
-        parser.add_argument(
-            "-c",
-            "--config",
-            required=True,
-            is_config_file=True,
-            help="config file path",
-        )
-    parser_trainer_args(parser)
-    parser_collector_args_offpolicy(parser)
-    parser_env_args(parser)
-    parser_loss_args(parser, algorithm="DDPG")
-    parser_model_args_continuous(parser, "DDPG")
-    parser_recorder_args(parser)
-    parser_replay_args(parser)
-    return parser
-
-
-parser = make_args()
+from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
+
+config_fields = [
+    (config_field.name, config_field.type, config_field)
+    for config_cls in (
+        TrainerConfig,
+        OffPolicyCollectorConfig,
+        EnvConfig,
+        LossConfig,
+        DDPGModelConfig,
+        RecorderConfig,
+        ReplayArgsConfig,
+    )
+    for config_field in dataclasses.fields(config_cls)
+]
+Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
+cs = ConfigStore.instance()
+cs.store(name="config", node=Config)
 
 DEFAULT_REWARD_SCALING = {
     "Hopper-v1": 5,
@@ -79,13 +66,14 @@ def make_args():
 }
 
 
-def main(args):
+@hydra.main(version_base=None, config_path=None, config_name="config")
+def main(cfg: "DictConfig"):
     from torch.utils.tensorboard import SummaryWriter
 
-    args = correct_for_frame_skip(args)
+    cfg = correct_for_frame_skip(cfg)
 
-    if not isinstance(args.reward_scaling, float):
-        args.reward_scaling = DEFAULT_REWARD_SCALING.get(args.env_name, 5.0)
+    if not isinstance(cfg.reward_scaling, float):
+        cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0)
 
     device = (
         torch.device("cpu")
@@ -96,50 +84,50 @@ def main(args):
     exp_name = "_".join(
         [
             "DDPG",
-            args.exp_name,
+            cfg.exp_name,
             str(uuid.uuid4())[:8],
             datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
         ]
     )
     writer = SummaryWriter(f"ddpg_logging/{exp_name}")
-    video_tag = exp_name if args.record_video else ""
+    video_tag = exp_name if cfg.record_video else ""
 
     stats = None
-    if not args.vecnorm and args.norm_stats:
-        proof_env = transformed_env_constructor(args=args, use_env_creator=False)()
+    if not cfg.vecnorm and cfg.norm_stats:
+        proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
         stats = get_stats_random_rollout(
-            args, proof_env, key="next_pixels" if args.from_pixels else None
+            cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
         )
         # make sure proof_env is closed
        proof_env.close()
-    elif args.from_pixels:
+    elif cfg.from_pixels:
        stats = {"loc": 0.5, "scale": 0.5}
     proof_env = transformed_env_constructor(
-        args=args, use_env_creator=False, stats=stats
+        cfg=cfg, use_env_creator=False, stats=stats
     )()
 
     model = make_ddpg_actor(
         proof_env,
-        args=args,
+        cfg=cfg,
         device=device,
     )
-    loss_module, target_net_updater = make_ddpg_loss(model, args)
+    loss_module, target_net_updater = make_ddpg_loss(model, cfg)
 
     actor_model_explore = model[0]
-    if args.ou_exploration:
-        if args.gSDE:
+    if cfg.ou_exploration:
+        if cfg.gSDE:
            raise RuntimeError("gSDE and ou_exploration are incompatible")
        actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
            actor_model_explore,
-            annealing_num_steps=args.annealing_frames,
-            sigma=args.ou_sigma,
-            theta=args.ou_theta,
+            annealing_num_steps=cfg.annealing_frames,
+            sigma=cfg.ou_sigma,
+            theta=cfg.ou_theta,
        ).to(device)
     if device == torch.device("cpu"):
         # mostly for debugging
         actor_model_explore.share_memory()
 
-    if args.gSDE:
+    if cfg.gSDE:
         with torch.no_grad(), set_exploration_mode("random"):
             # get dimensions to build the parallel env
             proof_td = actor_model_explore(proof_env.reset().to(device))
@@ -150,7 +138,7 @@ def main(args):
 
     proof_env.close()
     create_env_fn = parallel_env_constructor(
-        args=args,
+        cfg=cfg,
         stats=stats,
         action_dim_gsde=action_dim_gsde,
         state_dim_gsde=state_dim_gsde,
@@ -159,17 +147,17 @@ def main(args):
     collector = make_collector_offpolicy(
         make_env=create_env_fn,
         actor_model_explore=actor_model_explore,
-        args=args,
+        cfg=cfg,
         # make_env_kwargs=[
         #     {"device": device} if device >= 0 else {}
         #     for device in args.env_rendering_devices
         # ],
     )
 
-    replay_buffer = make_replay_buffer(device, args)
+    replay_buffer = make_replay_buffer(device, cfg)
 
     recorder = transformed_env_constructor(
-        args,
+        cfg,
         video_tag=video_tag,
         norm_obs_only=True,
         stats=stats,
@@ -178,7 +166,7 @@ def main(args):
     )()
 
     # remove video recorder from recorder to have matching state_dict keys
-    if args.record_video:
+    if cfg.record_video:
         recorder_rm = TransformedEnv(recorder.base_env)
         for transform in recorder.transform:
             if not isinstance(transform, VideoRecorder):
@@ -208,7 +196,7 @@ def main(args):
         actor_model_explore,
         replay_buffer,
         writer,
-        args,
+        cfg,
     )
 
     def select_keys(batch):
@@ -225,13 +213,12 @@ def select_keys(batch):
 
     trainer.register_op("batch_process", select_keys)
 
-    final_seed = collector.set_seed(args.seed)
-    print(f"init seed: {args.seed}, final seed: {final_seed}")
+    final_seed = collector.set_seed(cfg.seed)
+    print(f"init seed: {cfg.seed}, final seed: {final_seed}")
 
     trainer.train()
     return (writer.log_dir, trainer._log_dict, trainer.state_dict())
 
 
 if __name__ == "__main__":
-    args = parser.parse_args()
-    main(args)
+    main()
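
The refactor replaces the configargparse/argparse helpers with Hydra structured configs: each parser_* helper becomes a config dataclass, the dataclasses are merged into one Config node via dataclasses.make_dataclass, and that node is registered in the ConfigStore so @hydra.main can resolve config_name="config" without any YAML file on disk. A self-contained sketch of the same pattern follows, with illustrative config classes and field names standing in for the torchrl helpers, and assuming a hydra-core release recent enough to accept the version_base argument:

import dataclasses

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf


@dataclasses.dataclass
class EnvConfig:
    env_name: str = "Hopper-v1"
    frame_skip: int = 1


@dataclasses.dataclass
class TrainerConfig:
    total_frames: int = 1_000_000
    seed: int = 0


# Merge every field of every config dataclass into a single flat Config node,
# the same trick used in the diff above.
config_fields = [
    (config_field.name, config_field.type, config_field)
    for config_cls in (EnvConfig, TrainerConfig)
    for config_field in dataclasses.fields(config_cls)
]
Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)

cs = ConfigStore.instance()
cs.store(name="config", node=Config)


@hydra.main(version_base=None, config_path=None, config_name="config")
def main(cfg: DictConfig) -> None:
    # Every merged field is now a command-line override,
    # e.g. `python sketch.py env_name=Walker2d-v1 seed=42`.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()

In place of the old -c/--config flag, any field can be overridden directly on the command line, and Hydra records the resolved configuration in its run output directory.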

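A side effect of the switch worth noting: main() still mutates its configuration in place (cfg = correct_for_frame_skip(cfg), cfg.reward_scaling = ...), which works because the OmegaConf object Hydra passes in supports the same attribute-style reads and writes the old argparse Namespace did. A tiny illustration, assuming omegaconf (installed alongside hydra-core) and an untyped stand-in config whose field names mirror the diff:

from omegaconf import OmegaConf

# Untyped stand-in for the DictConfig handed to main().
cfg = OmegaConf.create({"env_name": "Hopper-v1", "reward_scaling": None})
if not isinstance(cfg.reward_scaling, float):
    cfg.reward_scaling = {"Hopper-v1": 5}.get(cfg.env_name, 5.0)
print(cfg.reward_scaling)  # -> 5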