3
3
# This source code is licensed under the MIT license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
+ import dataclasses
6
7
import uuid
7
8
from datetime import datetime
8
9
9
- from torchrl .envs import ParallelEnv , EnvCreator
10
- from torchrl .record import VideoRecorder
11
-
12
- try :
13
- import configargparse as argparse
14
-
15
- _configargparse = True
16
- except ImportError :
17
- import argparse
18
-
19
- _configargparse = False
10
+ import hydra
20
11
import torch .cuda
12
+ from hydra .core .config_store import ConfigStore
13
+ from torchrl .envs import ParallelEnv , EnvCreator
21
14
from torchrl .envs .transforms import RewardScaling , TransformedEnv
22
15
from torchrl .modules import EGreedyWrapper
16
+ from torchrl .record import VideoRecorder
23
17
from torchrl .trainers .helpers .collectors import (
24
18
make_collector_offpolicy ,
25
- parser_collector_args_offpolicy ,
19
+ OffPolicyCollectorConfig ,
26
20
)
27
21
from torchrl .trainers .helpers .envs import (
28
22
correct_for_frame_skip ,
29
23
get_stats_random_rollout ,
30
24
parallel_env_constructor ,
31
- parser_env_args ,
32
25
transformed_env_constructor ,
26
+ EnvConfig ,
33
27
)
34
- from torchrl .trainers .helpers .losses import make_dqn_loss , parser_loss_args
28
+ from torchrl .trainers .helpers .losses import make_dqn_loss , LossConfig
35
29
from torchrl .trainers .helpers .models import (
36
30
make_dqn_actor ,
37
- parser_model_args_discrete ,
31
+ DiscreteModelConfig ,
38
32
)
39
- from torchrl .trainers .helpers .recorder import parser_recorder_args
33
+ from torchrl .trainers .helpers .recorder import RecorderConfig
40
34
from torchrl .trainers .helpers .replay_buffer import (
41
35
make_replay_buffer ,
42
- parser_replay_args ,
36
+ ReplayArgsConfig ,
43
37
)
44
- from torchrl .trainers .helpers .trainers import make_trainer , parser_trainer_args
45
-
46
-
47
- def make_args ():
48
- parser = argparse .ArgumentParser ()
49
- if _configargparse :
50
- parser .add_argument (
51
- "-c" ,
52
- "--config" ,
53
- required = True ,
54
- is_config_file = True ,
55
- help = "config file path" ,
56
- )
57
- parser_trainer_args (parser )
58
- parser_collector_args_offpolicy (parser )
59
- parser_env_args (parser )
60
- parser_loss_args (parser , algorithm = "DQN" )
61
- parser_model_args_discrete (parser )
62
- parser_recorder_args (parser )
63
- parser_replay_args (parser )
64
- return parser
65
-
38
+ from torchrl .trainers .helpers .trainers import make_trainer , TrainerConfig
39
+
40
+
41
def _collect_config_fields():
    """Gather ``(name, type, Field)`` triples from every helper config dataclass.

    The trainer/collector/env/loss/model/recorder/replay-buffer helpers each
    declare their options as a dataclass; flattening them into one list lets a
    single Hydra config expose every option of the DQN example script.
    """
    helper_configs = (
        TrainerConfig,
        OffPolicyCollectorConfig,
        EnvConfig,
        LossConfig,
        DiscreteModelConfig,
        RecorderConfig,
        ReplayArgsConfig,
    )
    collected = []
    for helper_cls in helper_configs:
        for fld in dataclasses.fields(helper_cls):
            collected.append((fld.name, fld.type, fld))
    return collected


config_fields = _collect_config_fields()
# Build one flat ``Config`` dataclass from the merged fields and register it
# with Hydra's ConfigStore so ``@hydra.main(config_name="config")`` can find it.
Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
66
57
67
- parser = make_args ()
68
58
59
+ @hydra .main (version_base = None , config_path = None , config_name = "config" )
60
+ def main (cfg : "DictConfig" ):
69
61
70
- def main (args ):
71
62
from torch .utils .tensorboard import SummaryWriter
72
63
73
- args = correct_for_frame_skip (args )
64
+ cfg = correct_for_frame_skip (cfg )
74
65
75
- if not isinstance (args .reward_scaling , float ):
76
- args .reward_scaling = 1.0
66
+ if not isinstance (cfg .reward_scaling , float ):
67
+ cfg .reward_scaling = 1.0
77
68
78
69
device = (
79
70
torch .device ("cpu" )
@@ -84,41 +75,42 @@ def main(args):
84
75
exp_name = "_" .join (
85
76
[
86
77
"DQN" ,
87
- args .exp_name ,
78
+ cfg .exp_name ,
88
79
str (uuid .uuid4 ())[:8 ],
89
80
datetime .now ().strftime ("%y_%m_%d-%H_%M_%S" ),
90
81
]
91
82
)
92
83
writer = SummaryWriter (f"dqn_logging/{ exp_name } " )
93
- video_tag = exp_name if args .record_video else ""
84
+ video_tag = exp_name if cfg .record_video else ""
94
85
95
86
stats = None
96
- if not args .vecnorm and args .norm_stats :
97
- proof_env = transformed_env_constructor (args = args , use_env_creator = False )()
87
+ if not cfg .vecnorm and cfg .norm_stats :
88
+ proof_env = transformed_env_constructor (cfg = cfg , use_env_creator = False )()
98
89
stats = get_stats_random_rollout (
99
- args , proof_env , key = "next_pixels" if args .from_pixels else None
90
+ cfg , proof_env , key = "next_pixels" if cfg .from_pixels else None
100
91
)
101
92
# make sure proof_env is closed
102
93
proof_env .close ()
103
- elif args .from_pixels :
94
+ elif cfg .from_pixels :
104
95
stats = {"loc" : 0.5 , "scale" : 0.5 }
105
96
proof_env = transformed_env_constructor (
106
- args = args , use_env_creator = False , stats = stats
97
+ cfg = cfg , use_env_creator = False , stats = stats
107
98
)()
108
99
model = make_dqn_actor (
109
100
proof_environment = proof_env ,
110
- args = args ,
101
+ cfg = cfg ,
111
102
device = device ,
112
103
)
113
104
114
- loss_module , target_net_updater = make_dqn_loss (model , args )
115
- model_explore = EGreedyWrapper (model , annealing_num_steps = args .annealing_frames ).to (
105
+ loss_module , target_net_updater = make_dqn_loss (model , cfg )
106
+ model_explore = EGreedyWrapper (model , annealing_num_steps = cfg .annealing_frames ).to (
116
107
device
117
108
)
109
+
118
110
action_dim_gsde , state_dim_gsde = None , None
119
111
proof_env .close ()
120
112
create_env_fn = parallel_env_constructor (
121
- args = args ,
113
+ cfg = cfg ,
122
114
stats = stats ,
123
115
action_dim_gsde = action_dim_gsde ,
124
116
state_dim_gsde = state_dim_gsde ,
@@ -127,26 +119,26 @@ def main(args):
127
119
collector = make_collector_offpolicy (
128
120
make_env = create_env_fn ,
129
121
actor_model_explore = model_explore ,
130
- args = args ,
122
+ cfg = cfg ,
131
123
# make_env_kwargs=[
132
124
# {"device": device} if device >= 0 else {}
133
125
# for device in args.env_rendering_devices
134
126
# ],
135
127
)
136
128
137
- replay_buffer = make_replay_buffer (device , args )
129
+ replay_buffer = make_replay_buffer (device , cfg )
138
130
139
131
recorder = transformed_env_constructor (
140
- args ,
132
+ cfg ,
141
133
video_tag = video_tag ,
142
134
norm_obs_only = True ,
143
135
stats = stats ,
144
136
writer = writer ,
145
137
)()
146
138
147
139
# remove video recorder from recorder to have matching state_dict keys
148
- if args .record_video :
149
- recorder_rm = TransformedEnv (recorder .env )
140
+ if cfg .record_video :
141
+ recorder_rm = TransformedEnv (recorder .base_env )
150
142
for transform in recorder .transform :
151
143
if not isinstance (transform , VideoRecorder ):
152
144
recorder_rm .append_transform (transform )
@@ -171,10 +163,10 @@ def main(args):
171
163
loss_module ,
172
164
recorder ,
173
165
target_net_updater ,
174
- model_explore ,
166
+ model ,
175
167
replay_buffer ,
176
168
writer ,
177
- args ,
169
+ cfg ,
178
170
)
179
171
180
172
def select_keys (batch ):
@@ -191,13 +183,12 @@ def select_keys(batch):
191
183
192
184
trainer .register_op ("batch_process" , select_keys )
193
185
194
- final_seed = collector .set_seed (args .seed )
195
- print (f"init seed: { args .seed } , final seed: { final_seed } " )
186
+ final_seed = collector .set_seed (cfg .seed )
187
+ print (f"init seed: { cfg .seed } , final seed: { final_seed } " )
196
188
197
189
trainer .train ()
198
190
return (writer .log_dir , trainer ._log_dict , trainer .state_dict ())
199
191
200
192
201
193
# Script entry point: Hydra (via the decorator on ``main``) parses the CLI
# overrides and injects the composed config, so no arguments are passed here.
if __name__ == "__main__":
    main()