Skip to content

Commit 6ed1326

Browse files
committed
Merging kwargs under single frames_per_batch
1 parent 0e10633 commit 6ed1326

File tree

2 files changed

+25
-43
lines changed

2 files changed

+25
-43
lines changed

test/test_collector.py

Lines changed: 3 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1591,7 +1591,7 @@ def env_fn():
15911591
collector = MultiaSyncDataCollector(
15921592
create_env_fn=[env_fn for _ in range(num_workers)],
15931593
policy=policy,
1594-
frames_per_batch_worker=frames_per_batch_worker,
1594+
frames_per_batch=frames_per_batch_worker,
15951595
max_frames_per_traj=1000,
15961596
total_frames=frames_per_batch * 100,
15971597
)
@@ -1631,23 +1631,12 @@ def env_fn():
16311631

16321632
with pytest.raises(
16331633
ValueError,
1634-
match="`frames_per_batch` and `frames_per_batch_worker` are mutually exclusive and one need to be set.",
1634+
match="If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker.",
16351635
):
16361636
collector = MultiSyncDataCollector(
16371637
create_env_fn=[env_fn for _ in range(num_workers)],
16381638
policy=policy,
1639-
max_frames_per_traj=1000,
1640-
total_frames=frames_per_batch * 100,
1641-
)
1642-
1643-
with pytest.raises(
1644-
ValueError,
1645-
match="If specified, `frames_per_batch_worker` should contain exactly one value per worker.",
1646-
):
1647-
collector = MultiSyncDataCollector(
1648-
create_env_fn=[env_fn for _ in range(num_workers)],
1649-
policy=policy,
1650-
frames_per_batch_worker=frames_per_batch_worker[:-1],
1639+
frames_per_batch=frames_per_batch_worker[:-1],
16511640
max_frames_per_traj=1000,
16521641
total_frames=frames_per_batch * 100,
16531642
)

torchrl/collectors/collectors.py

Lines changed: 22 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1765,11 +1765,9 @@ class _MultiDataCollector(DataCollectorBase):
17651765
.. warning:: `policy_factory` is currently not compatible with multiprocessed data
17661766
collectors.
17671767
1768-
frames_per_batch (int, optional): A keyword-only argument representing the
1769-
total number of elements in a batch.
1770-
frames_per_batch_worker (Sequence[int], optional): A keyword-only argument representing the
1771-
number of elements in a batch for each worker. This argument is mutually exclusive with `frames_per_batch`.
1772-
If `frames_per_batch_worker` is specified, `frames_per_batch` is computed as the sum across all workers.
1768+
frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
1769+
total number of elements in a batch. If a sequence is provided, it represents the number of elements in a
1770+
batch per worker. Total number of elements in a batch is then the sum over the sequence.
17731771
total_frames (int, optional): A keyword-only argument representing the
17741772
total number of frames returned by the collector
17751773
during its lifespan. If the ``total_frames`` is not divisible by
@@ -1926,8 +1924,7 @@ def __init__(
19261924
policy_factory: Callable[[], Callable]
19271925
| list[Callable[[], Callable]]
19281926
| None = None,
1929-
frames_per_batch: int | None = None,
1930-
frames_per_batch_worker: Sequence[int] | None = None,
1927+
frames_per_batch: int | Sequence[int],
19311928
total_frames: int | None = -1,
19321929
device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
19331930
storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
@@ -1963,26 +1960,21 @@ def __init__(
19631960
self.closed = True
19641961
self.num_workers = len(create_env_fn)
19651962

1966-
if (frames_per_batch is None and frames_per_batch_worker is None) or (
1967-
frames_per_batch is not None and frames_per_batch_worker is not None
1968-
):
1969-
raise ValueError(
1970-
"`frames_per_batch` and `frames_per_batch_worker` are mutually exclusive and one need to be set."
1971-
f"Got {frames_per_batch=} and {frames_per_batch_worker=}"
1972-
)
1973-
1974-
if frames_per_batch is None:
1975-
frames_per_batch = sum(frames_per_batch_worker)
1976-
19771963
if (
1978-
frames_per_batch_worker is not None
1979-
and len(frames_per_batch_worker) != self.num_workers
1964+
isinstance(frames_per_batch, Sequence)
1965+
and len(frames_per_batch) != self.num_workers
19801966
):
19811967
raise ValueError(
1982-
"If specified, `frames_per_batch_worker` should contain exactly one value per worker."
1983-
f"Got {len(frames_per_batch_worker)} values for {self.num_workers} workers."
1968+
"If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker."
1969+
f"Got {len(frames_per_batch)} values for {self.num_workers} workers."
19841970
)
1985-
self._frames_per_batch_worker = frames_per_batch_worker
1971+
1972+
self._frames_per_batch = frames_per_batch
1973+
total_frames_per_batch = (
1974+
sum(frames_per_batch)
1975+
if isinstance(frames_per_batch, Sequence)
1976+
else frames_per_batch
1977+
)
19861978

19871979
self.set_truncated = set_truncated
19881980
self.num_sub_threads = num_sub_threads
@@ -2101,11 +2093,11 @@ def __init__(
21012093
if total_frames is None or total_frames < 0:
21022094
total_frames = float("inf")
21032095
else:
2104-
remainder = total_frames % frames_per_batch
2096+
remainder = total_frames % total_frames_per_batch
21052097
if remainder != 0 and RL_WARNINGS:
21062098
warnings.warn(
2107-
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
2108-
f"This means {frames_per_batch - remainder} additional frames will be collected. "
2099+
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). "
2100+
f"This means {total_frames_per_batch - remainder} additional frames will be collected. "
21092101
"To silence this message, set the environment variable RL_WARNINGS to False."
21102102
)
21112103
self.total_frames = (
@@ -2116,7 +2108,8 @@ def __init__(
21162108
self.max_frames_per_traj = (
21172109
int(max_frames_per_traj) if max_frames_per_traj is not None else 0
21182110
)
2119-
self.requested_frames_per_batch = int(frames_per_batch)
2111+
2112+
self.requested_frames_per_batch = total_frames_per_batch
21202113
self.reset_when_done = reset_when_done
21212114
if split_trajs is None:
21222115
split_trajs = False
@@ -2798,8 +2791,8 @@ def update_policy_weights_(
27982791
)
27992792

28002793
def frames_per_batch_worker(self, worker_idx: int | None) -> int:
2801-
if worker_idx is not None and self._frames_per_batch_worker is not None:
2802-
return self._frames_per_batch_worker[worker_idx]
2794+
if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
2795+
return self._frames_per_batch[worker_idx]
28032796
if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
28042797
warnings.warn(
28052798
f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"

0 commit comments

Comments
 (0)