DDPG implementation fails to learn well on at least five MuJoCo-v2 envs for all three noise types. I report steps to reproduce and learning curve plots [and show that PPO2 seems to work fine]. #938

Open
@DanielTakeshi

Dear @pzhokhov @matthiasplappert @christopherhesse et al.,

Thank you for providing an implementation of DDPG. However, I have been unable to get it to learn well on the standard MuJoCo environments by running the command provided in the README (or related commands). Here are the steps to reproduce. I apologize for the length of the post, but I want to show exactly what I tried, both to reduce ambiguity and to preempt the argument that the problem is simply bad hyperparameters.

First, here's the machine I am using with relevant versions of software:

  • Ubuntu 18.04
  • MuJoCo 2.0
  • A clean Python 3.6.7 virtualenv with all requirements installed via pip: TensorFlow 1.13, gym 0.12.1, and mujoco-py 2.0.2.2. Everything appears to be installed correctly, with no errors.
  • baselines master branch at commit ba2b017

Next, here is the set of commands I ran. I'm splitting them into three groups based on the three types of noise we can inject into our policy.

Group 1: Parameter Noise

I first decided to take the default command provided in the README because I assumed that hyperparameters here have been tuned to save users the time and compute needed for expensive hyperparameter sweeps.

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6

I use my plotting code to get plots. Here it is:

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('seaborn-darkgrid')
import argparse
import csv
import os
import numpy as np
from os.path import join

# matplotlib
titlesize = 33
xsize = 30
ysize = 30
ticksize = 25
legendsize = 25
error_region_alpha = 0.25


def smoothed(x, w):
    """Smooth x by averaging over sliding windows of w, assuming sufficient length.
    """
    if len(x) <= w:
        return x
    smooth = []
    for i in range(1, w):
        smooth.append( np.mean(x[0:i]) )
    for i in range(w, len(x)+1):
        smooth.append( np.mean(x[i-w:i]) )
    assert len(x) == len(smooth), "lengths: {}, {}".format(len(x), len(smooth))
    return np.array(smooth)


def _get_stuff_from_monitor(mon):
    """Get stuff from `monitor` log files.

    Monitor files are named `0.envidx.monitor.csv` and have one line for each
    episode that finished in that CPU 'core', with the reward, length (number
    of steps) and the time (in seconds). The lengths are not cumulative, but
    time is cumulative.
    """
    scores = []
    steps  = []
    times  = []
    with open(mon, 'r') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        line_count = 0
        for csv_row in csv_reader:
            # First two lines don't contain interesting stuff.
            if line_count == 0 or line_count == 1:
                line_count += 1
                continue
            scores.append(float(csv_row[0]))
            steps.append(int(csv_row[1]))
            times.append(float(csv_row[2]))
            line_count += 1
    print("finished: {}".format(mon))
    return scores, steps, times


def plot(args):
    """Load monitor curves and the progress csv file. And plot from those.
    """
    nrows, ncols = 1, 2
    fig, ax = plt.subplots(nrows, ncols, squeeze=False, sharey=True, figsize=(11*ncols,7*nrows))
    title = args.title

    # Global statistics across all monitors
    scores_all = []
    steps_all = []
    times_all = []
    total_train_steps = 0
    train_hours = 0

    monitors = sorted(
        [x for x in os.listdir(args.path) if 'monitor.csv' in x and '.swp' not in x]
    )
    progfile = join(args.path,'progress.csv')

    # First row, info from all the monitors, i.e., number of CPUs.
    for env_idx,mon in enumerate(monitors):
        monitor_path = join(args.path, mon)
        scores, steps, times = _get_stuff_from_monitor(monitor_path)

        # Now process to see as a function of episodes and training steps, etc.
        num_episodes = len(scores)
        tr_episodes = np.arange(num_episodes)
        tr_steps = np.cumsum(steps)
        tr_times = np.array(times) / 60.0 # get it in minutes

        # Plot for individual monitors.
        envlabel = 'env {}'.format(env_idx)
        sm_10 = smoothed(scores, w=10)
        ax[0,0].plot(tr_steps, sm_10, label=envlabel+'; avg {:.1f} last {:.1f}'.format(
                np.mean(sm_10), sm_10[-1]))
        sm_100 = smoothed(scores, w=100)
        ax[0,1].plot(tr_times, sm_100, label=envlabel+'; avg {:.1f} last {:.1f}'.format(
                np.mean(sm_100), sm_100[-1]))

        # Handle global stuff.
        total_train_steps += tr_steps[-1]
        train_hours = max(train_hours, tr_times[-1] / 60.0)

    # Bells and whistles
    for row in range(nrows):
        for col in range(ncols):
            ax[row,col].set_ylabel("Scores", fontsize=ysize)
            ax[row,col].tick_params(axis='x', labelsize=ticksize)
            ax[row,col].tick_params(axis='y', labelsize=ticksize)
            leg = ax[row,col].legend(loc="best", ncol=1, prop={'size':legendsize})
            for legobj in leg.legendHandles:
                legobj.set_linewidth(5.0)
    ax[0,0].set_title(title+', Smoothed (w=10)', fontsize=titlesize)
    ax[0,0].set_xlabel("Train Steps (total {})".format(total_train_steps), fontsize=xsize)
    ax[0,1].set_title(title+', Smoothed (w=100)', fontsize=titlesize)
    ax[0,1].set_xlabel("Train Time in Minutes (total {:.2f} hours)".format(train_hours), fontsize=xsize)
    plt.tight_layout()
    figname = '{}.png'.format(title)
    plt.savefig(figname)
    print("\nJust saved: {}".format(figname))


if __name__ == "__main__":
    pp = argparse.ArgumentParser()
    pp.add_argument('--path', type=str)
    pp.add_argument('--title', type=str)
    args = pp.parse_args()
    plot(args)

To use this code, run python [script].py --path [PATH] --title [TITLE], where [PATH] is the directory that contains the 0.0.monitor.csv file (i.e., the /tmp/openai-[DATE] directory) and [TITLE] is the plot title. I did this for all five environments above and got these results:

ant00

halfcheetah00

hopper00

swimmer00

walker00
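To put numbers on curves like these without the full plotting script, a monitor file can also be summarized directly. This is a minimal sketch and not part of baselines: the function name summarize_monitor, the last_k parameter, and the synthetic file contents are made up for illustration, and it assumes the monitor layout described in the docstring above (one comment line, one header line, then one row per finished episode).

```python
import csv
import io


def summarize_monitor(lines, last_k=100):
    """Return (total_steps, mean reward over the last `last_k` episodes).

    `lines` is any iterable of text lines in the assumed monitor layout:
    a comment line, an `r,l,t` header line, then `reward,length,time` rows.
    """
    rows = list(csv.reader(lines))[2:]  # skip the comment and header lines
    rewards = [float(row[0]) for row in rows]
    lengths = [int(row[1]) for row in rows]
    if not rewards:
        return 0, float("nan")
    tail = rewards[-last_k:]
    return sum(lengths), sum(tail) / len(tail)


# Synthetic stand-in for a real 0.0.monitor.csv; the values are made up.
demo = io.StringIO(
    '#{"t_start": 0.0, "env_id": "Hopper-v2"}\n'
    "r,l,t\n"
    "10.0,20,0.5\n"
    "14.0,30,1.1\n"
    "18.0,50,2.0\n"
)
total_steps, recent_mean = summarize_monitor(demo, last_k=2)
print("total steps:", total_steps, "recent mean reward:", recent_mean)
```

For a real run, open the monitor file and pass the file object in place of demo.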

None of these curves appear to do better than random performance. Maybe Ant-v2 is doing slightly better than random, but it seems to be stuck at 0, and many papers report values far above 0. Perhaps it has something to do with the number of environments? I briefly tried increasing the number of parallel environments to 8, but that did not help:

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6 --num_env=8
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 --num_env=8

ant-08envs

halfcheetah-08envs

Incidentally, it seems like having N environments means that the actual number of steps increases by a factor of N. This is different behavior from PPO2 where increasing N does not change the number of actual time steps total at the end; increasing N for PPO2 means each individual environment can execute fewer steps.
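The step accounting described here can be written out as a small sketch. It encodes only the behavior observed in these runs, not a reading of the baselines code; the function name step_budget and the scales_with_envs flag are hypothetical.

```python
def step_budget(num_timesteps, num_env, scales_with_envs):
    """Return (total_steps, steps_per_env) for one run.

    scales_with_envs=True models the DDPG behavior observed above: each
    environment runs the full budget, so the total grows with num_env.
    scales_with_envs=False models the PPO2 behavior observed above: the
    budget is shared, so each environment executes fewer steps.
    """
    if scales_with_envs:
        return num_timesteps * num_env, num_timesteps
    return num_timesteps, num_timesteps // num_env


ddpg_total, ddpg_per_env = step_budget(1_000_000, 8, scales_with_envs=True)
ppo2_total, ppo2_per_env = step_budget(1_000_000, 8, scales_with_envs=False)
print("DDPG-like: total", ddpg_total, "per env", ddpg_per_env)
print("PPO2-like: total", ppo2_total, "per env", ppo2_per_env)
```

Under this accounting, --num_timesteps=1e6 with --num_env=8 would mean 8M environment steps for DDPG but 1M for PPO2, which matters when comparing curves on a steps axis.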

PS: for some of the above plots, I did not run to exactly 1M steps, i.e., I terminated it near the end if it was clear that the algorithm was not learning well.

Group 2: Gaussian Noise

All right, next I decided to avoid parameter space noise. In the TD3 paper which used DDPG, the authors used Gaussian noise with standard deviation 0.1. I decided to try that, keeping all other settings fixed:

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6 --noise_type=normal_0.1
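For reference, this is the kind of exploration meant by Gaussian noise with standard deviation 0.1. It is a minimal NumPy sketch, not the baselines implementation: the function name noisy_action, the [-1, 1] action bounds, and the 6-dimensional action are assumptions made for illustration.

```python
import numpy as np


def noisy_action(action, sigma, low, high, rng):
    """Add zero-mean Gaussian noise to a deterministic action, then clip.

    `low` and `high` are the assumed action bounds of the environment.
    """
    noise = rng.normal(loc=0.0, scale=sigma, size=action.shape)
    return np.clip(action + noise, low, high)


rng = np.random.default_rng(0)
policy_action = np.zeros(6)  # stand-in for the actor network's output
explored = noisy_action(policy_action, sigma=0.1, low=-1.0, high=1.0, rng=rng)
print(explored)
```

Each call draws fresh, independent noise, in contrast to OU noise, where successive samples are correlated over time.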

Here are the results:

ant01

halfcheetah01

hopper01

swimmer01

walker01

Once again, it seems like there is no learning happening. The performance appears to be similar to the parameter space noise case.

Group 3: OU Noise (along with tau=0.001)

I decided to run one last batch of commands, this time with the original OU noise. After carefully checking the TD3 paper and the DDPG directory from the July 27, 2017 commit, when DDPG was first released, I saw that the tau parameter back then was set to 0.001. Now, for some reason, it is 0.01. DeepMind used 0.001, so I decided to try OU noise with tau 0.001. This appears to be the only hyperparameter difference I can see between this code base and the values used by DeepMind.

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
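To make the two knobs in these commands concrete, here is a minimal NumPy sketch of an Ornstein-Uhlenbeck noise process and of the soft target update that tau controls. It is not the baselines code: the class and function names are hypothetical, theta=0.15 and dt=1e-2 are assumed values, and the three-element "weights" are a toy stand-in for network parameters.

```python
import numpy as np


class OUNoise:
    """Zero-mean Ornstein-Uhlenbeck process (temporally correlated noise)."""

    def __init__(self, size, sigma=0.2, theta=0.15, dt=1e-2, seed=0):
        self.sigma, self.theta, self.dt = sigma, theta, dt
        self.rng = np.random.default_rng(seed)
        self.x = np.zeros(size)

    def sample(self):
        # Euler-Maruyama step: pull toward zero plus a scaled Gaussian kick.
        drift = -self.theta * self.x * self.dt
        kick = self.sigma * np.sqrt(self.dt) * self.rng.normal(size=self.x.shape)
        self.x = self.x + drift + kick
        return self.x


def soft_update(target, online, tau):
    """Polyak averaging: move the target a fraction tau toward the online weights."""
    return (1.0 - tau) * target + tau * online


noise = OUNoise(size=3)
first_sample = noise.sample()

online = np.ones(3)
slow = np.zeros(3)  # target tracked with tau=0.001
fast = np.zeros(3)  # target tracked with tau=0.01
for _ in range(1000):
    slow = soft_update(slow, online, tau=0.001)
    fast = soft_update(fast, online, tau=0.01)
print("tau=0.001:", slow[0], "tau=0.01:", fast[0])
```

With fixed online weights, the target after n updates is 1 - (1 - tau)^n of the way there, so tau=0.001 makes the target network track the online network roughly ten times more slowly than tau=0.01.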

Results:

ant02

halfcheetah02

hopper02

swimmer02

walker02

(The swimmer curve looks like it's going up, but the reward is lower as compared to the other two plots.)

The results I am getting seem to differ from the blog post here which shows HalfCheetah rewards of at least +1500, and much larger depending on the parameter noise setting, and for 2M steps. It might be a hyperparameter issue, but I'm not sure. In particular, notice that the hyperparameters here (for the most part) match those from the DDPG or TD3 papers.

The TD3 paper reports these results:

td3_paper

The TD3 paper says it used DDPG (presumably from OpenAI baselines as of late 2017?) and then "Our DDPG" above is when the author tuned hyperparameters. Both get final rewards that are far higher than what I am seeing, and we are all using 1M training steps here. Unfortunately, from reading the TD3 code base, it is not clear which commit from baselines was used for the results.

The paper above does not report results for Swimmer, so I looked at the "Benchmarking DeepRL" paper, which says DDPG on Swimmer should get 85 +/- 1.8, and this is far higher than the Swimmer results I am getting above.

I suspect that there must have been some change to the code that caused it to either stop working well or become exorbitantly sensitive to hyperparameters. For example, maybe the process of removing MPI caused some unexpected results? Or it could be due to the change from MuJoCo v1 to v2 environments, since the TD3 paper used the v1 environments, but as this report suggests, RL performance should be similar. Notice that all the reward curves there for PPO show increasing reward, whereas I'm just seeing stagnation and noise for DDPG.

This is perhaps relevant to the following issue reports:

all of which have noticed issues with DDPG. If the fix is found, then the above can probably all be closed.

Hopefully, in the spirit of my previous report on DQN here, we can resolve this issue together. Does anyone have any general suggestions or ideas about the potential causes? At this point I am unable to confidently use the DDPG code because it does not pass standard benchmarks. My previous issue report about DQN suggests that it could be an environment processing issue. Is the code processing the MuJoCo environments in a similar way as in July 2017? Do the PPO2 results appear to be fine while the DDPG results are off? Is there a difference in how the two algorithms process observations and normalize data?

I'm happy to help investigate this if you have ideas on what might be the root cause. I only report this issue because having highly tuned algorithms and hyper-parameters ready to go "off-the-shelf" greatly helps the entire research community by accelerating research cycles and reducing the need to write our own error-prone implementations of algorithms.

Thanks!
