Skip to content

Commit

Permalink
Fix PrioritizedReplayBuffer add method (#202)
Browse files (browse the repository at this point in the history)
  • Loading branch information
younik authored Oct 22, 2024
1 parent c62b5b0 commit ac0832f
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/gfn/containers/replay_buffer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -143,7 +143,11 @@ def __init__(
self.cutoff_distance = cutoff_distance
self.p_norm_distance = p_norm_distance

- def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States]):
+ def _add_objs(
+ self,
+ training_objects: Transitions | Trajectories | tuple[States],
+ terminating_states: States | None = None
+ ):
"""Adds a training object to the buffer."""
# Adds the objects to the buffer.
self.training_objects.extend(training_objects)
Expand All @@ -155,15 +159,16 @@ def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States]

# Add the terminating states to the buffer.
if self.terminating_states is not None:
- assert self.terminating_states is not None
- self.terminating_states.extend(self.terminating_states)
+ assert terminating_states is not None
+ self.terminating_states.extend(terminating_states)

# Sort terminating states by logreward as well.
self.terminating_states = self.terminating_states[ix]
self.terminating_states = self.terminating_states[-self.capacity :]

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
+ terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects
Expand All @@ -173,7 +178,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):

# The buffer isn't full yet.
if len(self.training_objects) < self.capacity:
- self._add_objs(training_objects)
+ self._add_objs(training_objects, terminating_states)

# Our buffer is full and we will prioritize diverse, high reward additions.
else:
Expand Down

0 comments on commit ac0832f

Please sign in to comment.