Skip to content

Commit ac46c37

Browse files
authored
Fix partial minibatch computation in GAIL dataset. (#724)
* Fix partial minibatch computation in GAIL dataset. * Updated changelog. * Added name to bottom of changelog.
1 parent a4efff0 commit ac46c37

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

docs/misc/changelog.rst

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ Breaking Changes:
2222

2323
- Algorithms no longer import from each other, and ``common`` does not import from algorithms.
2424
- ``a2c/utils.py`` removed and split into other files:
25-
26-
- common/tf_util.py: ``sample``, ``calc_entropy``, ``mse``, ``avg_norm``, ``total_episode_reward_logger``,
27-
``q_explained_variance``, ``gradient_add``, ``avg_norm``, ``check_shape``,
25+
26+
- common/tf_util.py: ``sample``, ``calc_entropy``, ``mse``, ``avg_norm``, ``total_episode_reward_logger``,
27+
``q_explained_variance``, ``gradient_add``, ``avg_norm``, ``check_shape``,
2828
``seq_to_batch``, ``batch_to_seq``.
2929
- common/tf_layers.py: ``conv``, ``linear``, ``lstm``, ``_ln``, ``lnlstm``, ``conv_to_fc``, ``ortho_init``.
3030
- a2c/a2c.py: ``discount_with_dones``.
3131
- acer/acer_simple.py: ``get_by_index``, ``EpisodeStats``.
3232
- common/schedules.py: ``constant``, ``linear_schedule``, ``middle_drop``, ``double_linear_con``, ``double_middle_drop``,
3333
``SCHEDULES``, ``Scheduler``.
34-
34+
3535
- ``trpo_mpi/utils.py`` functions moved (``traj_segment_generator`` moved to ``common/runners.py``, ``flatten_lists`` to ``common/misc_util.py``).
3636
- ``ppo2/ppo2.py`` functions moved (``safe_mean`` to ``common/math_util.py``, ``constfn`` and ``get_schedule_fn`` to ``common/schedules.py``).
3737
- ``sac/policies.py`` function ``mlp`` moved to ``common/tf_layers.py``.
@@ -69,6 +69,7 @@ Bug Fixes:
6969
- Fixed a bug in ``BaseRLModel`` when seeding vectorized environments. (@NeoExtended)
7070
- Fixed ``num_timesteps`` computation to be consistent between algorithms (updated after ``env.step()``)
7171
Only ``TRPO`` and ``PPO1`` update it differently (after synchronization) because they rely on MPI
72+
- Fixed partial minibatch computation in ExpertDataset (@richardwu)
7273

7374
Deprecations:
7475
^^^^^^^^^^^^^
@@ -652,4 +653,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
652653
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
653654
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
654655
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
655-
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta
656+
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu

stable_baselines/gail/dataset/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(self, indices, observations, actions, batch_size, n_workers=1,
204204
self.n_minibatches = len(indices) // batch_size
205205
# Add a partial minibatch, for instance
206206
# when there is not enough samples
207-
if partial_minibatch and len(indices) / batch_size > 0:
207+
if partial_minibatch and len(indices) % batch_size > 0:
208208
self.n_minibatches += 1
209209
self.batch_size = batch_size
210210
self.observations = observations

0 commit comments

Comments
 (0)