[Algorithm] SOTA discrete offline CQL #3098
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3098
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 2 Cancelled Jobs, 3 Unrelated Failures
As of commit 27a065a with merge base 3f10cb1:
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
sota-check/run_discrete_cql.sh
Added sota-check
@@ -195,6 +195,49 @@ def make_offline_replay_buffer(rb_cfg):
    return data
Replay buffer from a custom minari dataset
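For readers new to the offline setting, the idea behind this helper can be sketched without torchrl: an offline buffer is filled once from a fixed dataset and is only sampled from during training, with no environment interaction. The sketch below is pure Python. `OfflineBuffer`, `Transition`, and the toy CartPole-like data are hypothetical stand-ins for illustration, not the `MinariExperienceReplay` API or the code in this PR.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Transition:
    """One logged step of a (hypothetical) CartPole-like dataset."""
    observation: tuple
    action: int
    reward: float
    next_observation: tuple
    done: bool


class OfflineBuffer:
    """Hypothetical read-only buffer: filled once, never written to during training."""

    def __init__(self, transitions, batch_size, seed=0):
        self._storage = tuple(transitions)  # immutable copy of the dataset
        if not self._storage:
            raise ValueError("offline dataset is empty")
        self._batch_size = batch_size
        self._rng = random.Random(seed)  # seeded for reproducible sampling

    def __len__(self):
        return len(self._storage)

    def sample(self):
        # Uniform sampling with replacement from the fixed dataset.
        return [self._rng.choice(self._storage) for _ in range(self._batch_size)]


# Toy dataset: 10 transitions with 4-dimensional observations and 2 actions.
dataset = [
    Transition(
        observation=(float(i), 0.0, 0.0, 0.0),
        action=i % 2,
        reward=1.0,
        next_observation=(float(i + 1), 0.0, 0.0, 0.0),
        done=(i == 9),
    )
    for i in range(10)
]

buffer = OfflineBuffer(dataset, batch_size=4)
batch = buffer.sample()
```

In the actual script the dataset comes from Minari and the batches are tensor-based; the sketch only shows the fill-once, sample-many access pattern.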
    loss_module = DiscreteCQLLoss(
        model,
        loss_function=loss_cfg.loss_function,
        action_space="categorical",
        delay_value=True,
Updated make_discrete_loss to follow torchrl documentation
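As background for the loss being configured here, the conservative term of discrete CQL is commonly written as logsumexp over actions of Q(s, a) minus Q(s, a_data), added to a standard TD loss. The following is a minimal pure-Python sketch of that formula for a single transition. The function names and the `cql_weight` parameter are hypothetical, the TD part is assumed to be a plain squared error, and torchrl's `DiscreteCQLLoss` may weight or reduce these terms differently.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def discrete_cql_penalty(q_values, action):
    """Conservative term: logsumexp_a Q(s, a) - Q(s, a_data).

    It is non-negative and shrinks when the dataset action has the
    highest Q-value by a wide margin.
    """
    return logsumexp(q_values) - q_values[action]


def discrete_cql_loss(q_values, action, td_target, cql_weight=1.0):
    """Squared TD error plus the weighted conservative penalty (assumed form)."""
    td_error = q_values[action] - td_target
    return 0.5 * td_error ** 2 + cql_weight * discrete_cql_penalty(q_values, action)


# Two-action example: equal Q-values give a penalty of log(2).
uniform_penalty = discrete_cql_penalty([0.0, 0.0], action=0)

# When the dataset action clearly dominates, the penalty is close to zero.
confident_penalty = discrete_cql_penalty([10.0, 0.0], action=0)
```

The penalty pushes down Q-values of actions not seen in the dataset relative to the logged action, which is what makes the method conservative in the offline setting.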
I think we're good to go when the linter is fixed!
@@ -30,7 +30,7 @@ logger:
   eval_steps: 200
   mode: online
   eval_iter: 1000
-  video: False
+  video: True
Do we want this?
Description
This PR solves issue #3097 by providing a state-of-the-art (SOTA) implementation for discrete offline Conservative Q-Learning (CQL) within the repository.
The implementation uses the MinariExperienceReplay class to load experiences from a custom CartPole dataset into the offline replay buffer.
Motivation and Context
This change was required to complete the set of SOTA CQL implementations in the repository by adding support for discrete offline CQL. Previously, only online discrete CQL and online and offline continuous CQL were available. With discrete offline CQL implemented, the repository covers all major CQL benchmarks, so users can run and benchmark offline RL in discrete action spaces natively within torchrl.
The change also brings the repository in line with other top RL codebases and helps the community reproducibly benchmark and compare discrete offline RL algorithms.
This PR closes #3097.