@@ -1765,11 +1765,9 @@ class _MultiDataCollector(DataCollectorBase):
1765
1765
.. warning:: `policy_factory` is currently not compatible with multiprocessed data
1766
1766
collectors.
1767
1767
1768
- frames_per_batch (int, optional): A keyword-only argument representing the
1769
- total number of elements in a batch.
1770
- frames_per_batch_worker (Sequence[int], optional): A keyword-only argument representing the
1771
- number of elements in a batch for each worker. This argument is mutually exclusive with `frames_per_batch`.
1772
- If `frames_per_batch_worker` is specified, `frames_per_batch` is computed as the sum across all workers.
1768
+ frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
1769
+ total number of elements in a batch. If a sequence is provided, represents the number of elements in a
1770
+ batch per worker. Total number of elements in a batch is then the sum over the sequence.
1773
1771
total_frames (int, optional): A keyword-only argument representing the
1774
1772
total number of frames returned by the collector
1775
1773
during its lifespan. If the ``total_frames`` is not divisible by
@@ -1926,8 +1924,7 @@ def __init__(
1926
1924
policy_factory : Callable [[], Callable ]
1927
1925
| list [Callable [[], Callable ]]
1928
1926
| None = None ,
1929
- frames_per_batch : int | None = None ,
1930
- frames_per_batch_worker : Sequence [int ] | None = None ,
1927
+ frames_per_batch : int | Sequence [int ],
1931
1928
total_frames : int | None = - 1 ,
1932
1929
device : DEVICE_TYPING | Sequence [DEVICE_TYPING ] | None = None ,
1933
1930
storing_device : DEVICE_TYPING | Sequence [DEVICE_TYPING ] | None = None ,
@@ -1963,26 +1960,21 @@ def __init__(
1963
1960
self .closed = True
1964
1961
self .num_workers = len (create_env_fn )
1965
1962
1966
- if (frames_per_batch is None and frames_per_batch_worker is None ) or (
1967
- frames_per_batch is not None and frames_per_batch_worker is not None
1968
- ):
1969
- raise ValueError (
1970
- "`frames_per_batch` and `frames_per_batch_worker` are mutually exclusive and one need to be set."
1971
- f"Got { frames_per_batch = } and { frames_per_batch_worker = } "
1972
- )
1973
-
1974
- if frames_per_batch is None :
1975
- frames_per_batch = sum (frames_per_batch_worker )
1976
-
1977
1963
if (
1978
- frames_per_batch_worker is not None
1979
- and len (frames_per_batch_worker ) != self .num_workers
1964
+ isinstance ( frames_per_batch , Sequence )
1965
+ and len (frames_per_batch ) != self .num_workers
1980
1966
):
1981
1967
raise ValueError (
1982
- "If specified, `frames_per_batch_worker` should contain exactly one value per worker."
1983
- f"Got { len (frames_per_batch_worker )} values for { self .num_workers } workers."
1968
+ "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker."
1969
+ f"Got { len (frames_per_batch )} values for { self .num_workers } workers."
1984
1970
)
1985
- self ._frames_per_batch_worker = frames_per_batch_worker
1971
+
1972
+ self ._frames_per_batch = frames_per_batch
1973
+ total_frames_per_batch = (
1974
+ sum (frames_per_batch )
1975
+ if isinstance (frames_per_batch , Sequence )
1976
+ else frames_per_batch
1977
+ )
1986
1978
1987
1979
self .set_truncated = set_truncated
1988
1980
self .num_sub_threads = num_sub_threads
@@ -2101,11 +2093,11 @@ def __init__(
2101
2093
if total_frames is None or total_frames < 0 :
2102
2094
total_frames = float ("inf" )
2103
2095
else :
2104
- remainder = total_frames % frames_per_batch
2096
+ remainder = total_frames % total_frames_per_batch
2105
2097
if remainder != 0 and RL_WARNINGS :
2106
2098
warnings .warn (
2107
- f"total_frames ({ total_frames } ) is not exactly divisible by frames_per_batch ({ frames_per_batch } ). "
2108
- f"This means { frames_per_batch - remainder } additional frames will be collected. "
2099
+ f"total_frames ({ total_frames } ) is not exactly divisible by frames_per_batch ({ total_frames_per_batch } ). "
2100
+ f"This means { total_frames_per_batch - remainder } additional frames will be collected. "
2109
2101
"To silence this message, set the environment variable RL_WARNINGS to False."
2110
2102
)
2111
2103
self .total_frames = (
@@ -2116,7 +2108,8 @@ def __init__(
2116
2108
self .max_frames_per_traj = (
2117
2109
int (max_frames_per_traj ) if max_frames_per_traj is not None else 0
2118
2110
)
2119
- self .requested_frames_per_batch = int (frames_per_batch )
2111
+
2112
+ self .requested_frames_per_batch = total_frames_per_batch
2120
2113
self .reset_when_done = reset_when_done
2121
2114
if split_trajs is None :
2122
2115
split_trajs = False
@@ -2798,8 +2791,8 @@ def update_policy_weights_(
2798
2791
)
2799
2792
2800
2793
def frames_per_batch_worker (self , worker_idx : int | None ) -> int :
2801
- if worker_idx is not None and self ._frames_per_batch_worker is not None :
2802
- return self ._frames_per_batch_worker [worker_idx ]
2794
+ if worker_idx is not None and isinstance ( self ._frames_per_batch , Sequence ) :
2795
+ return self ._frames_per_batch [worker_idx ]
2803
2796
if self .requested_frames_per_batch % self .num_workers != 0 and RL_WARNINGS :
2804
2797
warnings .warn (
2805
2798
f"frames_per_batch { self .requested_frames_per_batch } is not exactly divisible by the number of collector workers { self .num_workers } ,"
@@ -2931,7 +2924,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
2931
2924
self ._frames += sum (
2932
2925
[
2933
2926
self .frames_per_batch_worker (worker_idx )
2934
- for worker_idx in range (len ( self .num_workers ) )
2927
+ for worker_idx in range (self .num_workers )
2935
2928
]
2936
2929
)
2937
2930
continue
0 commit comments