
Commit 47066f0

BoboBananas (Bhuvan Basireddy) and Bhuvan Basireddy authored
Integrating hydra with SAC (#198)
* Added TrainConfig dataclass to replace args
* Added dataclass RecorderConfig for recorder.py
* Added ReplayArgsConfig
* Added comments for ReplayArgsConfig, TrainerConfig, and RecorderConfig
* Added OffPolicyCollectorConfig and OnPolicyCollectorConfig dataclasses
* Added EnvConfig
* Added LossConfig and LossPPOConfig
* Added ContinuousModelConfig and DiscreteModelConfig
* Integrated hydra w/ ppo example
* Able to override parameters w/ yaml file provided through command line
* PPO example working w/ hydra
* Fixed styling issues
* Added hydra dependencies to setup.py
* Refactored args from argparser to cfg
* Fixed style issues
* Fixing more style issues
* Refactor input config file to overriding_cfg
* Removed import of DictConfig, now using str type hinting for DictConfig
* Integrated hydra into SAC
* Make hydra optional dependency in trainers.py
* change cfg comment
* Removed hydra in trainers.py
* Changing hydra-core version to >=1.1 version
* Modified tests affected by hydra change
* Fixing style issues in trainers.py
* Added hydra dependency to environment.yml
* Added generate_seeds import

Co-authored-by: Bhuvan Basireddy <bbreddy@devfair0832.h2.fair>
1 parent f1fedf3 commit 47066f0
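
The commit description mentions overriding parameters either directly on the command line or through a YAML file supplied on the command line; in the new main() that file path is read from cfg.config_file and merged over the Hydra defaults with OmegaConf.merge, so the YAML values win. A rough sketch of that workflow, where the file name and the chosen override values are illustrative and not taken from this commit (env_name, record_video and config_file are fields the diff actually reads):

# sac_overrides.yaml (hypothetical file)
env_name: Walker2d-v1
record_video: false

# Hydra-style overrides directly on the command line ...
python examples/sac/sac.py env_name=Walker2d-v1 record_video=false
# ... or the same values supplied through a YAML file
python examples/sac/sac.py config_file=sac_overrides.yaml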

File tree

2 files changed (+65, -71 lines)

examples/sac/sac.py
torchrl/trainers/helpers/losses.py

examples/sac/sac.py

Lines changed: 63 additions & 69 deletions
@@ -3,69 +3,59 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import dataclasses
 import uuid
 from datetime import datetime

-from torchrl.envs import ParallelEnv, EnvCreator
-from torchrl.envs.utils import set_exploration_mode
-from torchrl.record import VideoRecorder
-
-try:
-    import configargparse as argparse
-
-    _configargparse = True
-except ImportError:
-    import argparse
-
-    _configargparse = False
+import hydra
 import torch.cuda
+from hydra.core.config_store import ConfigStore
+from omegaconf import OmegaConf
+from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
+from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
+from torchrl.record import VideoRecorder
 from torchrl.trainers.helpers.collectors import (
     make_collector_offpolicy,
-    parser_collector_args_offpolicy,
+    OffPolicyCollectorConfig,
 )
 from torchrl.trainers.helpers.envs import (
     correct_for_frame_skip,
     get_stats_random_rollout,
     parallel_env_constructor,
-    parser_env_args,
     transformed_env_constructor,
+    EnvConfig,
 )
-from torchrl.trainers.helpers.losses import make_sac_loss, parser_loss_args
+from torchrl.trainers.helpers.losses import make_sac_loss, LossConfig
 from torchrl.trainers.helpers.models import (
     make_sac_model,
-    parser_model_args_continuous,
+    SACModelConfig,
 )
-from torchrl.trainers.helpers.recorder import parser_recorder_args
+from torchrl.trainers.helpers.recorder import RecorderConfig
 from torchrl.trainers.helpers.replay_buffer import (
     make_replay_buffer,
-    parser_replay_args,
+    ReplayArgsConfig,
 )
-from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args
-
-
-def make_args():
-    parser = argparse.ArgumentParser()
-    if _configargparse:
-        parser.add_argument(
-            "-c",
-            "--config",
-            required=True,
-            is_config_file=True,
-            help="config file path",
-        )
-    parser_trainer_args(parser)
-    parser_collector_args_offpolicy(parser)
-    parser_env_args(parser)
-    parser_loss_args(parser, algorithm="SAC")
-    parser_model_args_continuous(parser, "SAC")
-    parser_recorder_args(parser)
-    parser_replay_args(parser)
-    return parser
-
+from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
+
+config_fields = [
+    (config_field.name, config_field.type, config_field)
+    for config_cls in (
+        TrainerConfig,
+        OffPolicyCollectorConfig,
+        EnvConfig,
+        LossConfig,
+        SACModelConfig,
+        RecorderConfig,
+        ReplayArgsConfig,
+    )
+    for config_field in dataclasses.fields(config_cls)
+]

-parser = make_args()
+Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
+cs = ConfigStore.instance()
+cs.store(name="config", node=Config)

 DEFAULT_REWARD_SCALING = {
     "Hopper-v1": 5,
@@ -78,13 +68,18 @@ def make_args():
 }


-def main(args):
+@hydra.main(version_base=None, config_path=None, config_name="config")
+def main(cfg: "DictConfig"):
     from torch.utils.tensorboard import SummaryWriter

-    args = correct_for_frame_skip(args)
+    if cfg.config_file is not None:
+        overriding_cfg = OmegaConf.load(cfg.config_file)
+        cfg = OmegaConf.merge(cfg, overriding_cfg)
+
+    cfg = correct_for_frame_skip(cfg)

-    if not isinstance(args.reward_scaling, float):
-        args.reward_scaling = DEFAULT_REWARD_SCALING.get(args.env_name, 5.0)
+    if not isinstance(cfg.reward_scaling, float):
+        cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0)

     device = (
         torch.device("cpu")
@@ -95,47 +90,47 @@ def main(args):
     exp_name = "_".join(
         [
             "SAC",
-            args.exp_name,
+            cfg.exp_name,
             str(uuid.uuid4())[:8],
             datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
         ]
     )
     writer = SummaryWriter(f"sac_logging/{exp_name}")
-    video_tag = exp_name if args.record_video else ""
+    video_tag = exp_name if cfg.record_video else ""

     stats = None
-    if not args.vecnorm and args.norm_stats:
-        proof_env = transformed_env_constructor(args=args, use_env_creator=False)()
+    if not cfg.vecnorm and cfg.norm_stats:
+        proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
         stats = get_stats_random_rollout(
-            args, proof_env, key="next_pixels" if args.from_pixels else None
+            cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
         )
         # make sure proof_env is closed
         proof_env.close()
-    elif args.from_pixels:
+    elif cfg.from_pixels:
         stats = {"loc": 0.5, "scale": 0.5}
     proof_env = transformed_env_constructor(
-        args=args, use_env_creator=False, stats=stats
+        cfg=cfg, use_env_creator=False, stats=stats
     )()
     model = make_sac_model(
         proof_env,
-        args=args,
+        cfg=cfg,
         device=device,
     )
-    loss_module, target_net_updater = make_sac_loss(model, args)
+    loss_module, target_net_updater = make_sac_loss(model, cfg)

     actor_model_explore = model[0]
-    if args.ou_exploration:
+    if cfg.ou_exploration:
         actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
             actor_model_explore,
-            annealing_num_steps=args.annealing_frames,
-            sigma=args.ou_sigma,
-            theta=args.ou_theta,
+            annealing_num_steps=cfg.annealing_frames,
+            sigma=cfg.ou_sigma,
+            theta=cfg.ou_theta,
         ).to(device)
     if device == torch.device("cpu"):
         # mostly for debugging
         actor_model_explore.share_memory()

-    if args.gSDE:
+    if cfg.gSDE:
         with torch.no_grad(), set_exploration_mode("random"):
             # get dimensions to build the parallel env
             proof_td = actor_model_explore(proof_env.reset().to(device))
@@ -145,7 +140,7 @@ def main(args):
         action_dim_gsde, state_dim_gsde = None, None
     proof_env.close()
     create_env_fn = parallel_env_constructor(
-        args=args,
+        cfg=cfg,
         stats=stats,
         action_dim_gsde=action_dim_gsde,
         state_dim_gsde=state_dim_gsde,
@@ -154,25 +149,25 @@ def main(args):
     collector = make_collector_offpolicy(
         make_env=create_env_fn,
         actor_model_explore=actor_model_explore,
-        args=args,
+        cfg=cfg,
         # make_env_kwargs=[
         #     {"device": device} if device >= 0 else {}
         #     for device in args.env_rendering_devices
         # ],
     )

-    replay_buffer = make_replay_buffer(device, args)
+    replay_buffer = make_replay_buffer(device, cfg)

     recorder = transformed_env_constructor(
-        args,
+        cfg,
         video_tag=video_tag,
         norm_obs_only=True,
         stats=stats,
         writer=writer,
     )()

     # remove video recorder from recorder to have matching state_dict keys
-    if args.record_video:
+    if cfg.record_video:
         recorder_rm = TransformedEnv(recorder.env)
         for transform in recorder.transform:
             if not isinstance(transform, VideoRecorder):
@@ -202,7 +197,7 @@ def main(args):
         actor_model_explore,
         replay_buffer,
         writer,
-        args,
+        cfg,
     )

     def select_keys(batch):
@@ -219,13 +214,12 @@ def select_keys(batch):

     trainer.register_op("batch_process", select_keys)

-    final_seed = collector.set_seed(args.seed)
-    print(f"init seed: {args.seed}, final seed: {final_seed}")
+    final_seed = collector.set_seed(cfg.seed)
+    print(f"init seed: {cfg.seed}, final seed: {final_seed}")

     trainer.train()
     return (writer.log_dir, trainer._log_dict, trainer.state_dict())


 if __name__ == "__main__":
-    args = parser.parse_args()
-    main(args)
+    main()
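
The heart of the change is the module-level block that flattens the fields of several helper dataclasses (TrainerConfig, OffPolicyCollectorConfig, EnvConfig, LossConfig, SACModelConfig, RecorderConfig, ReplayArgsConfig) into a single structured config and registers it with Hydra's ConfigStore. Below is a minimal standalone sketch of that same pattern, with toy dataclasses standing in for the torchrl ones; nothing in it is torchrl code, and the field names are invented for illustration.

import dataclasses

import hydra
from hydra.core.config_store import ConfigStore


@dataclasses.dataclass
class ToyTrainerConfig:
    optim_steps_per_batch: int = 500
    lr: float = 3e-4


@dataclasses.dataclass
class ToyEnvConfig:
    env_name: str = "Hopper-v1"
    frame_skip: int = 1


# Flatten the fields of every source dataclass into (name, type, Field) tuples ...
config_fields = [
    (cfg_field.name, cfg_field.type, cfg_field)
    for config_cls in (ToyTrainerConfig, ToyEnvConfig)
    for cfg_field in dataclasses.fields(config_cls)
]

# ... build one flat "Config" dataclass out of them, and register it so that
# @hydra.main(config_name="config") can resolve it from the ConfigStore.
Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
cs = ConfigStore.instance()
cs.store(name="config", node=Config)


@hydra.main(config_path=None, config_name="config")
def main(cfg):
    # Every field of every source dataclass is now a top-level key, overridable
    # from the command line, e.g.: python toy_hydra.py env_name=Walker2d-v1 lr=1e-3
    print(cfg.env_name, cfg.lr)


if __name__ == "__main__":
    main()

Flattening the fields into one dataclass keeps every option addressable as a top-level key, which mirrors the flat namespace the old argparse-based parsers exposed.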

torchrl/trainers/helpers/losses.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
     "make_redq_loss",
 ]

-from typing import Optional, Tuple
+from typing import Optional, Tuple, Any

 from torchrl.modules import ActorValueOperator, ActorCriticOperator
 from torchrl.objectives import (
@@ -226,7 +226,7 @@ class LossConfig:
     # use two (or more!) different qvalue networks trained independently and choose the lowest value
     # predicted to predict the state action value. This can be disabled by using this flag.
     # REDQ uses an arbitrary number of Q-value functions to speed up learning in MF contexts.
-    target_entropy: float = None
+    target_entropy: Any = None
     # Target entropy for the policy distribution. Default is None (auto calculated as the `target_entropy = -action_dim`)

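
The only functional change here is the annotation of target_entropy. Hydra/OmegaConf validate structured-config fields against their annotations, and a plain float field cannot hold the None default, whereas Any can. A small self-contained reproduction of that constraint, written for illustration only (these dataclasses are not torchrl code):

from dataclasses import dataclass
from typing import Any

from omegaconf import OmegaConf
from omegaconf.errors import ValidationError


@dataclass
class BrokenLossConfig:
    target_entropy: float = None  # OmegaConf rejects None for a non-Optional float


@dataclass
class FixedLossConfig:
    target_entropy: Any = None  # what this commit switches to


try:
    OmegaConf.structured(BrokenLossConfig)
except ValidationError as err:
    print("rejected:", err)

print(OmegaConf.structured(FixedLossConfig))  # {'target_entropy': None}

An Optional[float] annotation would presumably also satisfy the validator; Any is the more permissive choice and is what the commit uses.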
