Skip to content

Commit 1e9941b

Browse files
committed
Re-introducing recorder_log_keys (#225)
1 parent d3e022c commit 1e9941b

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
_TIMEOUT = 1.0
4242
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory
43-
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", 10))
43+
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", 1000))
4444

4545

4646
class RandomPolicy:

torchrl/trainers/helpers/recorder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from dataclasses import dataclass
6+
from dataclasses import dataclass, field
7+
from typing import Any
78

89

910
@dataclass
@@ -19,3 +20,5 @@ class RecorderConfig:
1920
# number of batch collections in between two collections of validation rollouts. Default=1000.
2021
record_frames: int = 1000
2122
# number of steps in validation rollouts. " "Default=1000.
23+
recorder_log_keys: Any = field(default_factory=lambda: ["reward"])
24+
# Keys to log in the recorder

torchrl/trainers/helpers/trainers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def make_trainer(
248248
policy_exploration=policy_exploration,
249249
recorder=recorder,
250250
record_interval=cfg.record_interval,
251+
log_keys=cfg.recorder_log_keys,
251252
)
252253
trainer.register_op(
253254
"post_steps_log",
@@ -262,7 +263,7 @@ def make_trainer(
262263
record_interval=cfg.record_interval,
263264
exploration_mode="random",
264265
suffix="exploration",
265-
out_key="r_evaluation_exploration",
266+
out_keys={"reward": "r_evaluation_exploration"},
266267
)
267268
trainer.register_op(
268269
"post_steps_log",

0 commit comments

Comments
 (0)