Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prototype jax with ddpg #187

Merged
merged 32 commits into from
Jul 12, 2022
Merged

prototype jax with ddpg #187

merged 32 commits into from
Jul 12, 2022

Conversation

vwxyzjn
Copy link
Owner

@vwxyzjn vwxyzjn commented May 29, 2022

Description

Types of changes

  • New algorithm

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
  • I have updated the tests accordingly (if applicable).

@gitpod-io
Copy link

gitpod-io bot commented May 29, 2022

@vercel
Copy link

vercel bot commented May 29, 2022

The latest updates on your projects. Learn more about Vercel for Git ↗︎

Name Status Preview Updated
cleanrl ✅ Ready (Inspect) Visit Preview Jul 12, 2022 at 9:17PM (UTC)

@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Jun 24, 2022

@dosssman @huxiao09 I seem to have gotten CleanRL's DDPG + Jax working: about 5x speed up for free.

image

@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Jun 26, 2022

@dosssman and @ikostrikov could you help review this, please? I am unfamiliar with JAX so might be coding up things wrong or have really bad format...

@ikostrikov
Copy link

Looks good to me! The only thing I would add is TrainState:
https://flax.readthedocs.io/en/latest/flax.training.html#flax.training.train_state.TrainState

This was referenced Jun 26, 2022
target_params are initialized with the same RNG key
@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Jun 29, 2022

@dosssman @yooceii could you give a review, please? The changes have been finalized.

Copy link
Collaborator

@dosssman dosssman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not very familiar with Jax, so can't really suggest any quality improvement.
Beside that, its relatively easy to understand, and the algorithm logic looks good to me.
Great work as always.

cleanrl/ddpg_continuous_action_jax.py Show resolved Hide resolved
@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Jun 30, 2022

@dosssman, @yooceii, and @joaogui1 this is ready for review with docs (https://cleanrl-git-jax-ddpg-vwxyzjn.vercel.app/rl-algorithms/ddpg/#ddpg_continuous_action_jaxpy, note some of the links don't work until this PR is merged).

x = nn.relu(x)
x = nn.Dense(self.action_dim)(x)
x = nn.tanh(x)
x * self.action_scale + self.action_bias
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be x = x * self.action_scale + self.action_bias, no?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this great catch! I am fixing this and will merge after CI passes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants