Skip to content

Commit

Permalink
Fix the replay buffer overflow issue (#75)
Browse files Browse the repository at this point in the history
overload issue in replay buffer, add buffer test + move baselines test to separate folder
  • Loading branch information
felixchalumeau authored Sep 6, 2022
1 parent a04fadc commit 82262e5
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 8 deletions.
38 changes: 30 additions & 8 deletions qdax/core/neuroevolution/buffers/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,20 +343,42 @@ def insert(self, transitions: Transition) -> ReplayBuffer:
num_transitions = flattened_transitions.shape[0]
max_replay_size = self.buffer_size

new_current_position = self.current_position + num_transitions
new_current_size = jnp.minimum(
self.current_size + num_transitions, max_replay_size
)
# Make sure update is not larger than the maximum replay size.
if num_transitions > max_replay_size:
raise ValueError(
"Trying to insert a batch of samples larger than the maximum replay "
f"size. num_samples: {num_transitions}, "
f"max replay size {max_replay_size}"
)

# get current position
position = self.current_position

# check if there is an overlap
roll = jnp.minimum(0, max_replay_size - position - num_transitions)

# roll the data to avoid overlap
data = jnp.roll(self.data, roll, axis=0)

# update the position accordingly
new_position = position + roll

# replace old data by the new one
new_data = jax.lax.dynamic_update_slice_in_dim(
self.data,
data,
flattened_transitions,
start_index=self.current_position % max_replay_size,
start_index=new_position,
axis=0,
)

# update the position and the size
new_position = (new_position + num_transitions) % max_replay_size
new_size = jnp.minimum(self.current_size + num_transitions, max_replay_size)

# update the replay buffer
replay_buffer = self.replace(
current_position=new_current_position,
current_size=new_current_size,
current_position=new_position,
current_size=new_size,
data=new_data,
)

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
96 changes: 96 additions & 0 deletions tests/core_test/neuroevolution_test/buffers_test/buffer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import jax
import jax.numpy as jnp
import pytest

from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer


def test_insert() -> None:
"""Tests if an insert of a dummy transition results in a buffer of size 1"""
observation_size = 2
action_size = 8
buffer_size = 10

# Initialize buffer
dummy_transition = QDTransition.init_dummy(
observation_dim=observation_size,
action_dim=action_size,
descriptor_dim=0,
)
replay_buffer = ReplayBuffer.init(
buffer_size=buffer_size, transition=dummy_transition
)

replay_buffer = replay_buffer.insert(dummy_transition)
pytest.assume(replay_buffer.current_size == 1)


def test_insert_batch() -> None:
"""Tests if inserting transitions such that we exceed the max size
of the buffer leads to the desired behavior."""
observation_size = 2
action_size = 8
buffer_size = 5

# Initialize buffer
dummy_transition = QDTransition.init_dummy(
observation_dim=observation_size,
action_dim=action_size,
descriptor_dim=0,
)
replay_buffer = ReplayBuffer.init(
buffer_size=buffer_size, transition=dummy_transition
)

simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition)
simple_transition = simple_transition.replace(rewards=jnp.arange(3))
data = QDTransition.from_flatten(replay_buffer.data, dummy_transition)
pytest.assume(
jnp.array_equal(data.rewards, jnp.array([jnp.nan] * 5), equal_nan=True).all()
)

replay_buffer = replay_buffer.insert(simple_transition)
data = QDTransition.from_flatten(replay_buffer.data, dummy_transition)
pytest.assume(
jnp.array_equal(
data.rewards, jnp.array([0, 1, 2, jnp.nan, jnp.nan]), equal_nan=True
).all()
)

simple_transition_2 = simple_transition.replace(rewards=jnp.arange(3, 6))
replay_buffer = replay_buffer.insert(simple_transition_2)
data = QDTransition.from_flatten(replay_buffer.data, dummy_transition)
pytest.assume(
jnp.array_equal(data.rewards, jnp.array([1, 2, 3, 4, 5]), equal_nan=True).all()
)


def test_sample() -> None:
"""
Tests if sampled transitions have valid shape.
"""
observation_size = 2
action_size = 8
buffer_size = 5

# Initialize buffer
dummy_transition = QDTransition.init_dummy(
observation_dim=observation_size,
action_dim=action_size,
descriptor_dim=0,
)
replay_buffer = ReplayBuffer.init(
buffer_size=buffer_size, transition=dummy_transition
)

simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition)
simple_transition = simple_transition.replace(rewards=jnp.arange(3))

replay_buffer = replay_buffer.insert(simple_transition)
random_key = jax.random.PRNGKey(0)

samples, random_key = replay_buffer.sample(random_key, 3)

samples_shapes = jax.tree_map(lambda x: x.shape, samples)
transition_shapes = jax.tree_map(lambda x: x.shape, simple_transition)
pytest.assume((samples_shapes == transition_shapes))

0 comments on commit 82262e5

Please sign in to comment.