3
3
# This source code is licensed under the MIT license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
+ import dataclasses
6
7
import uuid
7
8
from datetime import datetime
8
9
9
- from torchrl .envs import ParallelEnv , EnvCreator
10
- from torchrl .envs .utils import set_exploration_mode
11
- from torchrl .record import VideoRecorder
12
-
13
- try :
14
- import configargparse as argparse
15
-
16
- _configargparse = True
17
- except ImportError :
18
- import argparse
19
-
20
- _configargparse = False
10
+ import hydra
21
11
import torch .cuda
12
+ from hydra .core .config_store import ConfigStore
13
+ from omegaconf import OmegaConf
14
+ from torchrl .envs import ParallelEnv , EnvCreator
22
15
from torchrl .envs .transforms import RewardScaling , TransformedEnv
16
+ from torchrl .envs .utils import set_exploration_mode
23
17
from torchrl .modules import OrnsteinUhlenbeckProcessWrapper
18
+ from torchrl .record import VideoRecorder
24
19
from torchrl .trainers .helpers .collectors import (
25
20
make_collector_offpolicy ,
26
- parser_collector_args_offpolicy ,
21
+ OffPolicyCollectorConfig ,
27
22
)
28
23
from torchrl .trainers .helpers .envs import (
29
24
correct_for_frame_skip ,
30
25
get_stats_random_rollout ,
31
26
parallel_env_constructor ,
32
- parser_env_args ,
33
27
transformed_env_constructor ,
28
+ EnvConfig ,
34
29
)
35
- from torchrl .trainers .helpers .losses import make_sac_loss , parser_loss_args
30
+ from torchrl .trainers .helpers .losses import make_sac_loss , LossConfig
36
31
from torchrl .trainers .helpers .models import (
37
32
make_sac_model ,
38
- parser_model_args_continuous ,
33
+ SACModelConfig ,
39
34
)
40
- from torchrl .trainers .helpers .recorder import parser_recorder_args
35
+ from torchrl .trainers .helpers .recorder import RecorderConfig
41
36
from torchrl .trainers .helpers .replay_buffer import (
42
37
make_replay_buffer ,
43
- parser_replay_args ,
38
+ ReplayArgsConfig ,
44
39
)
45
- from torchrl .trainers .helpers .trainers import make_trainer , parser_trainer_args
46
-
47
-
48
- def make_args ():
49
- parser = argparse .ArgumentParser ()
50
- if _configargparse :
51
- parser .add_argument (
52
- "-c" ,
53
- "--config" ,
54
- required = True ,
55
- is_config_file = True ,
56
- help = "config file path" ,
57
- )
58
- parser_trainer_args (parser )
59
- parser_collector_args_offpolicy (parser )
60
- parser_env_args (parser )
61
- parser_loss_args (parser , algorithm = "SAC" )
62
- parser_model_args_continuous (parser , "SAC" )
63
- parser_recorder_args (parser )
64
- parser_replay_args (parser )
65
- return parser
66
-
40
+ from torchrl .trainers .helpers .trainers import make_trainer , TrainerConfig
41
+
42
+ config_fields = [
43
+ (config_field .name , config_field .type , config_field )
44
+ for config_cls in (
45
+ TrainerConfig ,
46
+ OffPolicyCollectorConfig ,
47
+ EnvConfig ,
48
+ LossConfig ,
49
+ SACModelConfig ,
50
+ RecorderConfig ,
51
+ ReplayArgsConfig ,
52
+ )
53
+ for config_field in dataclasses .fields (config_cls )
54
+ ]
67
55
68
- parser = make_args ()
56
+ Config = dataclasses .make_dataclass (cls_name = "Config" , fields = config_fields )
57
+ cs = ConfigStore .instance ()
58
+ cs .store (name = "config" , node = Config )
69
59
70
60
DEFAULT_REWARD_SCALING = {
71
61
"Hopper-v1" : 5 ,
@@ -78,13 +68,18 @@ def make_args():
78
68
}
79
69
80
70
81
- def main (args ):
71
+ @hydra .main (version_base = None , config_path = None , config_name = "config" )
72
+ def main (cfg : "DictConfig" ):
82
73
from torch .utils .tensorboard import SummaryWriter
83
74
84
- args = correct_for_frame_skip (args )
75
+ if cfg .config_file is not None :
76
+ overriding_cfg = OmegaConf .load (cfg .config_file )
77
+ cfg = OmegaConf .merge (cfg , overriding_cfg )
78
+
79
+ cfg = correct_for_frame_skip (cfg )
85
80
86
- if not isinstance (args .reward_scaling , float ):
87
- args .reward_scaling = DEFAULT_REWARD_SCALING .get (args .env_name , 5.0 )
81
+ if not isinstance (cfg .reward_scaling , float ):
82
+ cfg .reward_scaling = DEFAULT_REWARD_SCALING .get (cfg .env_name , 5.0 )
88
83
89
84
device = (
90
85
torch .device ("cpu" )
@@ -95,47 +90,47 @@ def main(args):
95
90
exp_name = "_" .join (
96
91
[
97
92
"SAC" ,
98
- args .exp_name ,
93
+ cfg .exp_name ,
99
94
str (uuid .uuid4 ())[:8 ],
100
95
datetime .now ().strftime ("%y_%m_%d-%H_%M_%S" ),
101
96
]
102
97
)
103
98
writer = SummaryWriter (f"sac_logging/{ exp_name } " )
104
- video_tag = exp_name if args .record_video else ""
99
+ video_tag = exp_name if cfg .record_video else ""
105
100
106
101
stats = None
107
- if not args .vecnorm and args .norm_stats :
108
- proof_env = transformed_env_constructor (args = args , use_env_creator = False )()
102
+ if not cfg .vecnorm and cfg .norm_stats :
103
+ proof_env = transformed_env_constructor (cfg = cfg , use_env_creator = False )()
109
104
stats = get_stats_random_rollout (
110
- args , proof_env , key = "next_pixels" if args .from_pixels else None
105
+ cfg , proof_env , key = "next_pixels" if cfg .from_pixels else None
111
106
)
112
107
# make sure proof_env is closed
113
108
proof_env .close ()
114
- elif args .from_pixels :
109
+ elif cfg .from_pixels :
115
110
stats = {"loc" : 0.5 , "scale" : 0.5 }
116
111
proof_env = transformed_env_constructor (
117
- args = args , use_env_creator = False , stats = stats
112
+ cfg = cfg , use_env_creator = False , stats = stats
118
113
)()
119
114
model = make_sac_model (
120
115
proof_env ,
121
- args = args ,
116
+ cfg = cfg ,
122
117
device = device ,
123
118
)
124
- loss_module , target_net_updater = make_sac_loss (model , args )
119
+ loss_module , target_net_updater = make_sac_loss (model , cfg )
125
120
126
121
actor_model_explore = model [0 ]
127
- if args .ou_exploration :
122
+ if cfg .ou_exploration :
128
123
actor_model_explore = OrnsteinUhlenbeckProcessWrapper (
129
124
actor_model_explore ,
130
- annealing_num_steps = args .annealing_frames ,
131
- sigma = args .ou_sigma ,
132
- theta = args .ou_theta ,
125
+ annealing_num_steps = cfg .annealing_frames ,
126
+ sigma = cfg .ou_sigma ,
127
+ theta = cfg .ou_theta ,
133
128
).to (device )
134
129
if device == torch .device ("cpu" ):
135
130
# mostly for debugging
136
131
actor_model_explore .share_memory ()
137
132
138
- if args .gSDE :
133
+ if cfg .gSDE :
139
134
with torch .no_grad (), set_exploration_mode ("random" ):
140
135
# get dimensions to build the parallel env
141
136
proof_td = actor_model_explore (proof_env .reset ().to (device ))
@@ -145,7 +140,7 @@ def main(args):
145
140
action_dim_gsde , state_dim_gsde = None , None
146
141
proof_env .close ()
147
142
create_env_fn = parallel_env_constructor (
148
- args = args ,
143
+ cfg = cfg ,
149
144
stats = stats ,
150
145
action_dim_gsde = action_dim_gsde ,
151
146
state_dim_gsde = state_dim_gsde ,
@@ -154,25 +149,25 @@ def main(args):
154
149
collector = make_collector_offpolicy (
155
150
make_env = create_env_fn ,
156
151
actor_model_explore = actor_model_explore ,
157
- args = args ,
152
+ cfg = cfg ,
158
153
# make_env_kwargs=[
159
154
# {"device": device} if device >= 0 else {}
160
155
# for device in args.env_rendering_devices
161
156
# ],
162
157
)
163
158
164
- replay_buffer = make_replay_buffer (device , args )
159
+ replay_buffer = make_replay_buffer (device , cfg )
165
160
166
161
recorder = transformed_env_constructor (
167
- args ,
162
+ cfg ,
168
163
video_tag = video_tag ,
169
164
norm_obs_only = True ,
170
165
stats = stats ,
171
166
writer = writer ,
172
167
)()
173
168
174
169
# remove video recorder from recorder to have matching state_dict keys
175
- if args .record_video :
170
+ if cfg .record_video :
176
171
recorder_rm = TransformedEnv (recorder .env )
177
172
for transform in recorder .transform :
178
173
if not isinstance (transform , VideoRecorder ):
@@ -202,7 +197,7 @@ def main(args):
202
197
actor_model_explore ,
203
198
replay_buffer ,
204
199
writer ,
205
- args ,
200
+ cfg ,
206
201
)
207
202
208
203
def select_keys (batch ):
@@ -219,13 +214,12 @@ def select_keys(batch):
219
214
220
215
trainer .register_op ("batch_process" , select_keys )
221
216
222
- final_seed = collector .set_seed (args .seed )
223
- print (f"init seed: { args .seed } , final seed: { final_seed } " )
217
+ final_seed = collector .set_seed (cfg .seed )
218
+ print (f"init seed: { cfg .seed } , final seed: { final_seed } " )
224
219
225
220
trainer .train ()
226
221
return (writer .log_dir , trainer ._log_dict , trainer .state_dict ())
227
222
228
223
229
224
if __name__ == "__main__" :
230
- args = parser .parse_args ()
231
- main (args )
225
+ main ()
0 commit comments