Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the replay buffer overflow issue #75

Merged
merged 2 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions qdax/core/neuroevolution/buffers/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,20 +343,42 @@ def insert(self, transitions: Transition) -> ReplayBuffer:
num_transitions = flattened_transitions.shape[0]
max_replay_size = self.buffer_size

new_current_position = self.current_position + num_transitions
new_current_size = jnp.minimum(
self.current_size + num_transitions, max_replay_size
)
# Make sure update is not larger than the maximum replay size.
if num_transitions > max_replay_size:
raise ValueError(
"Trying to insert a batch of samples larger than the maximum replay "
f"size. num_samples: {num_transitions}, "
f"max replay size {max_replay_size}"
)

# get current position
position = self.current_position

# check if there is an overlap
roll = jnp.minimum(0, max_replay_size - position - num_transitions)

# roll the data to avoid overlap
data = jnp.roll(self.data, roll, axis=0)

# update the position accordingly
new_position = position + roll

# replace old data by the new one
new_data = jax.lax.dynamic_update_slice_in_dim(
self.data,
data,
flattened_transitions,
start_index=self.current_position % max_replay_size,
start_index=new_position,
axis=0,
)

# update the position and the size
new_position = (new_position + num_transitions) % max_replay_size
new_size = jnp.minimum(self.current_size + num_transitions, max_replay_size)

# update the replay buffer
replay_buffer = self.replace(
current_position=new_current_position,
current_size=new_current_size,
current_position=new_position,
current_size=new_size,
data=new_data,
)

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
96 changes: 96 additions & 0 deletions tests/core_test/neuroevolution_test/buffers_test/buffer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import jax
import jax.numpy as jnp
import pytest

from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer


def test_insert() -> None:
    """Check that inserting one dummy transition gives a buffer of size 1."""
    # Dimensions of the dummy transition and capacity of the buffer
    obs_dim, act_dim, capacity = 2, 8, 10

    # Build the buffer from a dummy transition used as a template
    template = QDTransition.init_dummy(
        observation_dim=obs_dim, action_dim=act_dim, descriptor_dim=0
    )
    buffer = ReplayBuffer.init(buffer_size=capacity, transition=template)

    # A single insertion must bring the reported size to exactly one
    buffer_after_insert = buffer.insert(template)
    pytest.assume(buffer_after_insert.current_size == 1)


def test_insert_batch() -> None:
    """Tests if inserting transitions such that we exceed the max size
    of the buffer leads to the desired behavior.

    Two batches of 3 transitions are inserted into a buffer of size 5. The
    rewards (0..2, then 3..5) act as markers: after the second insertion the
    oldest entry (reward 0) must have been dropped and the five remaining
    entries must be stored in insertion order.
    """
    observation_size = 2
    action_size = 8
    buffer_size = 5
    batch_size = 3

    # Initialize buffer
    dummy_transition = QDTransition.init_dummy(
        observation_dim=observation_size,
        action_dim=action_size,
        descriptor_dim=0,
    )
    replay_buffer = ReplayBuffer.init(
        buffer_size=buffer_size, transition=dummy_transition
    )

    # Build a batch of 3 transitions whose rewards identify each entry.
    # `jax.tree_util.tree_map` is used instead of the `jax.tree_map` alias,
    # which is deprecated and removed in recent JAX releases.
    simple_transition = jax.tree_util.tree_map(
        lambda x: x.repeat(batch_size, axis=0), dummy_transition
    )
    simple_transition = simple_transition.replace(rewards=jnp.arange(batch_size))

    # Before any insertion, every reward slot of the buffer holds NaN
    data = QDTransition.from_flatten(replay_buffer.data, dummy_transition)
    pytest.assume(
        jnp.array_equal(
            data.rewards, jnp.array([jnp.nan] * buffer_size), equal_nan=True
        )
    )

    # First insertion fits entirely: the first three slots are filled
    replay_buffer = replay_buffer.insert(simple_transition)
    data = QDTransition.from_flatten(replay_buffer.data, dummy_transition)
    pytest.assume(
        jnp.array_equal(
            data.rewards, jnp.array([0, 1, 2, jnp.nan, jnp.nan]), equal_nan=True
        )
    )
    pytest.assume(replay_buffer.current_size == batch_size)

    # Second insertion overflows the capacity by one: the oldest entry is
    # discarded and the new ones are appended after the surviving entries
    simple_transition_2 = simple_transition.replace(
        rewards=jnp.arange(batch_size, 2 * batch_size)
    )
    replay_buffer = replay_buffer.insert(simple_transition_2)
    data = QDTransition.from_flatten(replay_buffer.data, dummy_transition)
    pytest.assume(
        jnp.array_equal(data.rewards, jnp.array([1, 2, 3, 4, 5]), equal_nan=True)
    )
    # The size must saturate at the buffer capacity
    pytest.assume(replay_buffer.current_size == buffer_size)


def test_sample() -> None:
    """
    Tests if sampled transitions have valid shape.

    A batch of 3 transitions is inserted into a buffer of size 5, then 3
    transitions are sampled; every field of the sample must have the same
    shape as the corresponding field of the inserted batch.
    """
    observation_size = 2
    action_size = 8
    buffer_size = 5
    batch_size = 3

    # Initialize buffer
    dummy_transition = QDTransition.init_dummy(
        observation_dim=observation_size,
        action_dim=action_size,
        descriptor_dim=0,
    )
    replay_buffer = ReplayBuffer.init(
        buffer_size=buffer_size, transition=dummy_transition
    )

    # `jax.tree_util.tree_map` is used instead of the `jax.tree_map` alias,
    # which is deprecated and removed in recent JAX releases.
    simple_transition = jax.tree_util.tree_map(
        lambda x: x.repeat(batch_size, axis=0), dummy_transition
    )
    simple_transition = simple_transition.replace(rewards=jnp.arange(batch_size))

    replay_buffer = replay_buffer.insert(simple_transition)
    random_key = jax.random.PRNGKey(0)

    # Sample as many transitions as were inserted so shapes are comparable
    samples, random_key = replay_buffer.sample(random_key, batch_size)

    # Compare the per-field shapes of the sample and of the inserted batch
    samples_shapes = jax.tree_util.tree_map(lambda x: x.shape, samples)
    transition_shapes = jax.tree_util.tree_map(lambda x: x.shape, simple_transition)
    pytest.assume(samples_shapes == transition_shapes)