Skip to content

Commit

Permalink
None as default value for env in HerReplayBuffer.sample + DQN
Browse files Browse the repository at this point in the history
… batch size typing fix (#790)

* `env` to `None` by default in `HerReplayBuffer.sample` (#788)

* Fix DQN batch_size typing

* Fix changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
  • Loading branch information
3 people authored Feb 24, 2022
1 parent 13fcb12 commit db5366f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ Bug Fixes:
with very long keys.)
- Routing all the ``nn.Module`` calls through implicit rather than explict forward as per pytorch guidelines (@manuel-delverme)
- Fixed a bug in ``VecNormalize`` where error occurs when ``norm_obs`` is set to False for environment with dictionary observation (@buoyancy99)
- Set default ``env`` argument to ``None`` in ``HerReplayBuffer.sample`` (@qgallouedec)
- Fix ``batch_size`` typing in ``DQN`` (@qgallouedec)
- Fixed sample normalization in ``DictReplayBuffer`` (@qgallouedec)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -88,7 +91,6 @@ Bug Fixes:
- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
- Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
- The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
- Fixed sample normalization in ``DictReplayBuffer`` (@qgallouedec)
- Fixed a bug in ``VecFrameStack`` with channel first image envs, where the terminal observation would be wrongly created.

Deprecations:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 50000,
batch_size: Optional[int] = 32,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 4,
Expand Down
6 changes: 1 addition & 5 deletions stable_baselines3/her/her_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,7 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
"""
raise NotImplementedError()

def sample(
self,
batch_size: int,
env: Optional[VecNormalize],
) -> DictReplayBufferSamples:
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
"""
Sample function for online sampling of HER transition,
this replaces the "regular" replay buffer ``sample()``
Expand Down

0 comments on commit db5366f

Please sign in to comment.