3
3
# This source code is licensed under the MIT license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
+ import dataclasses
6
7
import uuid
7
8
from datetime import datetime
8
9
9
- from torchrl .envs import ParallelEnv , EnvCreator
10
- from torchrl .envs .utils import set_exploration_mode
11
- from torchrl .record import VideoRecorder
12
-
13
- try :
14
- import configargparse as argparse
15
-
16
- _configargparse = True
17
- except ImportError :
18
- import argparse
19
-
20
- _configargparse = False
21
-
10
+ import hydra
22
11
import torch .cuda
12
+ from hydra .core .config_store import ConfigStore
13
+ from torchrl .envs import ParallelEnv , EnvCreator
23
14
from torchrl .envs .transforms import RewardScaling , TransformedEnv
15
+ from torchrl .envs .utils import set_exploration_mode
24
16
from torchrl .modules import OrnsteinUhlenbeckProcessWrapper
17
+ from torchrl .record import VideoRecorder
25
18
from torchrl .trainers .helpers .collectors import (
26
19
make_collector_offpolicy ,
27
- parser_collector_args_offpolicy ,
20
+ OffPolicyCollectorConfig ,
28
21
)
29
22
from torchrl .trainers .helpers .envs import (
30
23
correct_for_frame_skip ,
31
24
get_stats_random_rollout ,
32
25
parallel_env_constructor ,
33
- parser_env_args ,
34
26
transformed_env_constructor ,
27
+ EnvConfig ,
35
28
)
36
- from torchrl .trainers .helpers .losses import make_ddpg_loss , parser_loss_args
29
+ from torchrl .trainers .helpers .losses import make_ddpg_loss , LossConfig
37
30
from torchrl .trainers .helpers .models import (
38
31
make_ddpg_actor ,
39
- parser_model_args_continuous ,
32
+ DDPGModelConfig ,
40
33
)
41
- from torchrl .trainers .helpers .recorder import parser_recorder_args
34
+ from torchrl .trainers .helpers .recorder import RecorderConfig
42
35
from torchrl .trainers .helpers .replay_buffer import (
43
36
make_replay_buffer ,
44
- parser_replay_args ,
37
+ ReplayArgsConfig ,
45
38
)
46
- from torchrl .trainers .helpers .trainers import make_trainer , parser_trainer_args
47
-
48
-
49
- def make_args ():
50
- parser = argparse .ArgumentParser ()
51
- if _configargparse :
52
- parser .add_argument (
53
- "-c" ,
54
- "--config" ,
55
- required = True ,
56
- is_config_file = True ,
57
- help = "config file path" ,
58
- )
59
- parser_trainer_args (parser )
60
- parser_collector_args_offpolicy (parser )
61
- parser_env_args (parser )
62
- parser_loss_args (parser , algorithm = "DDPG" )
63
- parser_model_args_continuous (parser , "DDPG" )
64
- parser_recorder_args (parser )
65
- parser_replay_args (parser )
66
- return parser
67
-
68
-
69
- parser = make_args ()
39
+ from torchrl .trainers .helpers .trainers import make_trainer , TrainerConfig
40
+
41
# Compose a single Hydra "config" node by merging the dataclass fields of
# every per-component helper config (trainer, collector, env, loss, model,
# recorder, replay buffer) into one flat generated dataclass.
_CONFIG_COMPONENTS = (
    TrainerConfig,
    OffPolicyCollectorConfig,
    EnvConfig,
    LossConfig,
    DDPGModelConfig,
    RecorderConfig,
    ReplayArgsConfig,
)

config_fields = []
for _component_cls in _CONFIG_COMPONENTS:
    for _field in dataclasses.fields(_component_cls):
        # (name, type, Field) triples are the format make_dataclass expects.
        config_fields.append((_field.name, _field.type, _field))

Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
cs = ConfigStore.instance()
# Register the merged schema under the name referenced by @hydra.main below.
cs.store(name="config", node=Config)
70
57
71
58
DEFAULT_REWARD_SCALING = {
72
59
"Hopper-v1" : 5 ,
@@ -79,13 +66,14 @@ def make_args():
79
66
}
80
67
81
68
82
- def main (args ):
69
+ @hydra .main (version_base = None , config_path = None , config_name = "config" )
70
+ def main (cfg : "DictConfig" ):
83
71
from torch .utils .tensorboard import SummaryWriter
84
72
85
- args = correct_for_frame_skip (args )
73
+ cfg = correct_for_frame_skip (cfg )
86
74
87
- if not isinstance (args .reward_scaling , float ):
88
- args .reward_scaling = DEFAULT_REWARD_SCALING .get (args .env_name , 5.0 )
75
+ if not isinstance (cfg .reward_scaling , float ):
76
+ cfg .reward_scaling = DEFAULT_REWARD_SCALING .get (cfg .env_name , 5.0 )
89
77
90
78
device = (
91
79
torch .device ("cpu" )
@@ -96,50 +84,50 @@ def main(args):
96
84
exp_name = "_" .join (
97
85
[
98
86
"DDPG" ,
99
- args .exp_name ,
87
+ cfg .exp_name ,
100
88
str (uuid .uuid4 ())[:8 ],
101
89
datetime .now ().strftime ("%y_%m_%d-%H_%M_%S" ),
102
90
]
103
91
)
104
92
writer = SummaryWriter (f"ddpg_logging/{ exp_name } " )
105
- video_tag = exp_name if args .record_video else ""
93
+ video_tag = exp_name if cfg .record_video else ""
106
94
107
95
stats = None
108
- if not args .vecnorm and args .norm_stats :
109
- proof_env = transformed_env_constructor (args = args , use_env_creator = False )()
96
+ if not cfg .vecnorm and cfg .norm_stats :
97
+ proof_env = transformed_env_constructor (cfg = cfg , use_env_creator = False )()
110
98
stats = get_stats_random_rollout (
111
- args , proof_env , key = "next_pixels" if args .from_pixels else None
99
+ cfg , proof_env , key = "next_pixels" if cfg .from_pixels else None
112
100
)
113
101
# make sure proof_env is closed
114
102
proof_env .close ()
115
- elif args .from_pixels :
103
+ elif cfg .from_pixels :
116
104
stats = {"loc" : 0.5 , "scale" : 0.5 }
117
105
proof_env = transformed_env_constructor (
118
- args = args , use_env_creator = False , stats = stats
106
+ cfg = cfg , use_env_creator = False , stats = stats
119
107
)()
120
108
121
109
model = make_ddpg_actor (
122
110
proof_env ,
123
- args = args ,
111
+ cfg = cfg ,
124
112
device = device ,
125
113
)
126
- loss_module , target_net_updater = make_ddpg_loss (model , args )
114
+ loss_module , target_net_updater = make_ddpg_loss (model , cfg )
127
115
128
116
actor_model_explore = model [0 ]
129
- if args .ou_exploration :
130
- if args .gSDE :
117
+ if cfg .ou_exploration :
118
+ if cfg .gSDE :
131
119
raise RuntimeError ("gSDE and ou_exploration are incompatible" )
132
120
actor_model_explore = OrnsteinUhlenbeckProcessWrapper (
133
121
actor_model_explore ,
134
- annealing_num_steps = args .annealing_frames ,
135
- sigma = args .ou_sigma ,
136
- theta = args .ou_theta ,
122
+ annealing_num_steps = cfg .annealing_frames ,
123
+ sigma = cfg .ou_sigma ,
124
+ theta = cfg .ou_theta ,
137
125
).to (device )
138
126
if device == torch .device ("cpu" ):
139
127
# mostly for debugging
140
128
actor_model_explore .share_memory ()
141
129
142
- if args .gSDE :
130
+ if cfg .gSDE :
143
131
with torch .no_grad (), set_exploration_mode ("random" ):
144
132
# get dimensions to build the parallel env
145
133
proof_td = actor_model_explore (proof_env .reset ().to (device ))
@@ -150,7 +138,7 @@ def main(args):
150
138
151
139
proof_env .close ()
152
140
create_env_fn = parallel_env_constructor (
153
- args = args ,
141
+ cfg = cfg ,
154
142
stats = stats ,
155
143
action_dim_gsde = action_dim_gsde ,
156
144
state_dim_gsde = state_dim_gsde ,
@@ -159,17 +147,17 @@ def main(args):
159
147
collector = make_collector_offpolicy (
160
148
make_env = create_env_fn ,
161
149
actor_model_explore = actor_model_explore ,
162
- args = args ,
150
+ cfg = cfg ,
163
151
# make_env_kwargs=[
164
152
# {"device": device} if device >= 0 else {}
165
153
# for device in args.env_rendering_devices
166
154
# ],
167
155
)
168
156
169
- replay_buffer = make_replay_buffer (device , args )
157
+ replay_buffer = make_replay_buffer (device , cfg )
170
158
171
159
recorder = transformed_env_constructor (
172
- args ,
160
+ cfg ,
173
161
video_tag = video_tag ,
174
162
norm_obs_only = True ,
175
163
stats = stats ,
@@ -178,7 +166,7 @@ def main(args):
178
166
)()
179
167
180
168
# remove video recorder from recorder to have matching state_dict keys
181
- if args .record_video :
169
+ if cfg .record_video :
182
170
recorder_rm = TransformedEnv (recorder .base_env )
183
171
for transform in recorder .transform :
184
172
if not isinstance (transform , VideoRecorder ):
@@ -208,7 +196,7 @@ def main(args):
208
196
actor_model_explore ,
209
197
replay_buffer ,
210
198
writer ,
211
- args ,
199
+ cfg ,
212
200
)
213
201
214
202
def select_keys (batch ):
@@ -225,13 +213,12 @@ def select_keys(batch):
225
213
226
214
trainer .register_op ("batch_process" , select_keys )
227
215
228
- final_seed = collector .set_seed (args .seed )
229
- print (f"init seed: { args .seed } , final seed: { final_seed } " )
216
+ final_seed = collector .set_seed (cfg .seed )
217
+ print (f"init seed: { cfg .seed } , final seed: { final_seed } " )
230
218
231
219
trainer .train ()
232
220
return (writer .log_dir , trainer ._log_dict , trainer .state_dict ())
233
221
234
222
235
223
# Script entry point. No arguments are passed: the @hydra.main decorator on
# ``main`` builds ``cfg`` from the "config" node registered in the ConfigStore.
if __name__ == "__main__":
    main()