-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix the replay buffer overflow issue (#75)
overload issue in replay buffer, add buffer test + move baselines test to separate folder
- Loading branch information
1 parent
a04fadc
commit 82262e5
Showing
12 changed files
with
126 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
96 changes: 96 additions & 0 deletions
96
tests/core_test/neuroevolution_test/buffers_test/buffer_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
import pytest | ||
|
||
from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer | ||
|
||
|
||
def test_insert() -> None: | ||
"""Tests if an insert of a dummy transition results in a buffer of size 1""" | ||
observation_size = 2 | ||
action_size = 8 | ||
buffer_size = 10 | ||
|
||
# Initialize buffer | ||
dummy_transition = QDTransition.init_dummy( | ||
observation_dim=observation_size, | ||
action_dim=action_size, | ||
descriptor_dim=0, | ||
) | ||
replay_buffer = ReplayBuffer.init( | ||
buffer_size=buffer_size, transition=dummy_transition | ||
) | ||
|
||
replay_buffer = replay_buffer.insert(dummy_transition) | ||
pytest.assume(replay_buffer.current_size == 1) | ||
|
||
|
||
def test_insert_batch() -> None: | ||
"""Tests if inserting transitions such that we exceed the max size | ||
of the buffer leads to the desired behavior.""" | ||
observation_size = 2 | ||
action_size = 8 | ||
buffer_size = 5 | ||
|
||
# Initialize buffer | ||
dummy_transition = QDTransition.init_dummy( | ||
observation_dim=observation_size, | ||
action_dim=action_size, | ||
descriptor_dim=0, | ||
) | ||
replay_buffer = ReplayBuffer.init( | ||
buffer_size=buffer_size, transition=dummy_transition | ||
) | ||
|
||
simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition) | ||
simple_transition = simple_transition.replace(rewards=jnp.arange(3)) | ||
data = QDTransition.from_flatten(replay_buffer.data, dummy_transition) | ||
pytest.assume( | ||
jnp.array_equal(data.rewards, jnp.array([jnp.nan] * 5), equal_nan=True).all() | ||
) | ||
|
||
replay_buffer = replay_buffer.insert(simple_transition) | ||
data = QDTransition.from_flatten(replay_buffer.data, dummy_transition) | ||
pytest.assume( | ||
jnp.array_equal( | ||
data.rewards, jnp.array([0, 1, 2, jnp.nan, jnp.nan]), equal_nan=True | ||
).all() | ||
) | ||
|
||
simple_transition_2 = simple_transition.replace(rewards=jnp.arange(3, 6)) | ||
replay_buffer = replay_buffer.insert(simple_transition_2) | ||
data = QDTransition.from_flatten(replay_buffer.data, dummy_transition) | ||
pytest.assume( | ||
jnp.array_equal(data.rewards, jnp.array([1, 2, 3, 4, 5]), equal_nan=True).all() | ||
) | ||
|
||
|
||
def test_sample() -> None: | ||
""" | ||
Tests if sampled transitions have valid shape. | ||
""" | ||
observation_size = 2 | ||
action_size = 8 | ||
buffer_size = 5 | ||
|
||
# Initialize buffer | ||
dummy_transition = QDTransition.init_dummy( | ||
observation_dim=observation_size, | ||
action_dim=action_size, | ||
descriptor_dim=0, | ||
) | ||
replay_buffer = ReplayBuffer.init( | ||
buffer_size=buffer_size, transition=dummy_transition | ||
) | ||
|
||
simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition) | ||
simple_transition = simple_transition.replace(rewards=jnp.arange(3)) | ||
|
||
replay_buffer = replay_buffer.insert(simple_transition) | ||
random_key = jax.random.PRNGKey(0) | ||
|
||
samples, random_key = replay_buffer.sample(random_key, 3) | ||
|
||
samples_shapes = jax.tree_map(lambda x: x.shape, samples) | ||
transition_shapes = jax.tree_map(lambda x: x.shape, simple_transition) | ||
pytest.assume((samples_shapes == transition_shapes)) |