
Episode reward logged with SAC not correct #726

@nicoguertler

Description

Describe the bug
When using SAC with a tensorboard log, the episode reward (shown as "episode_reward" in tensorboard) is not calculated correctly. The running sum is reset to zero before the last step of an episode instead of after it, so the reward of the terminal step is attributed to the following episode. The correct behaviour would be to reset the sum to zero only after the last step of an episode and then start summing up the rewards of the new episode.
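
For comparison, here is a minimal sketch of the intended accumulation order (an illustration only, not code from stable-baselines): the terminal step's reward is added to the running sum before the reset.

# Minimal sketch of the intended accumulation order (illustration only,
# not stable-baselines code).
transitions = [(0.0, False), (0.0, False), (1.0, True),   # episode 1, return 1.0
               (0.5, False), (0.5, True)]                  # episode 2, return 1.0

episode_reward = 0.0
for reward, done in transitions:
    episode_reward += reward                      # count the current step's reward first
    if done:
        print("episode_reward:", episode_reward)  # prints 1.0 for both episodes
        episode_reward = 0.0                      # reset only after the terminal step was counted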

Code example
This custom environment returns a random number as the reward in the last time step of an episode and one minus this random number in the first time step. All other time steps give zero reward, so the rewards in an episode always add up to 1. In tensorboard, however, "episode_reward" fluctuates around 1 due to the behaviour described above; the correct behaviour would be an episode_reward that is always equal to 1.

import numpy as np

import gym
from gym import spaces

from stable_baselines import SAC


class CustomEnv(gym.Env):
    """Environment whose rewards always sum to exactly 1 within each episode."""

    metadata = {'render.modes': ['human']}

    def __init__(self):
        super(CustomEnv, self).__init__()

        self.N = 10  # episode length

        self.action_space = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)

        self.observation_space = spaces.Discrete(self.N)
        self.state = 0

    def step(self, action):
        if self.state == self.N - 1:
            # Last step: return the random number drawn at reset
            done = True
            reward = self.ep_random_number
        elif self.state == 0:
            # First step: return one minus that random number
            done = False
            reward = 1.0 - self.ep_random_number
        else:
            # Intermediate steps give zero reward
            done = False
            reward = 0.0
        self.state += 1
        info = {}
        return np.array(self.state), reward, done, info

    def reset(self):
        self.state = 0
        self.ep_random_number = np.random.rand()
        return np.array(self.state)

    def render(self, mode='human'):
        pass

    def close(self):
        pass


if __name__ == "__main__":
    env = CustomEnv()
    model = SAC("MlpPolicy", env, verbose=1, tensorboard_log="./tensorboard")
    model.learn(total_timesteps=500)
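
As an additional sanity check (not part of the original report), the environment can be rolled out with random actions and the returns accumulated by hand; every episode return is exactly 1, independently of the logger.

# Manual rollout (added for illustration): sum the rewards per episode by hand,
# assuming the CustomEnv class defined above is in scope.
env = CustomEnv()
for episode in range(3):
    obs = env.reset()
    episode_return, done = 0.0, False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
        episode_return += reward
    print("episode", episode, "return:", episode_return)  # 1.0 up to floating-point rounding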

System Info

  • Repository cloned from GitHub

Additional context
I think the problem is related to this part of the learn method of SAC:

if writer is not None:
    # Write reward per episode to tensorboard
    ep_reward = np.array([reward]).reshape((1, -1))
    ep_done = np.array([done]).reshape((1, -1))
    tf_util.total_episode_reward_logger(self.episode_reward, ep_reward,
                                        ep_done, writer, self.num_timesteps)

which uses this logger function:

def total_episode_reward_logger(rew_acc, rewards, masks, writer, steps):
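
If that function resets the per-environment accumulator on the step where done is set, before the reward of that step has been added, the terminal reward is counted towards the next episode. The toy loop below (purely an illustration of this hypothesis, not the actual body of total_episode_reward_logger) reproduces the fluctuation around 1 described above: each logged value becomes roughly 1 plus the previous episode's terminal reward minus the current one.

import numpy as np

# Toy illustration of the suspected off-by-one (hypothetical, not the library code):
# the accumulator is reset on the terminal step *before* that step's reward is added,
# so the terminal reward leaks into the next episode's sum.
rewards = np.array([0.3, 0.0, 0.7, 0.7, 0.0, 0.3])          # two episodes, each summing to 1.0
dones = np.array([False, False, True, False, False, True])

acc = 0.0
for reward, done in zip(rewards, dones):
    if done:
        print("logged episode_reward:", acc)  # prints 0.3 and 1.4 instead of 1.0 and 1.0
        acc = 0.0                             # reset before the terminal reward is counted
    acc += reward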

Labels

bug (Something isn't working)
