[rllib] Support torch device and distributions. #4553
Conversation
Can one of the admins verify this patch?
Test FAILed.
The test environment is using pytorch-cpu. Should I change it to pytorch?
Hm, does that make a GPU available? AFAIK, none of our tests are currently run with GPUs. What is the limitation of pytorch-cpu?
Never mind. The CPU case should be handled correctly now.
logits, _, values, _ = policy_model(
    {SampleBatch.CUR_OBS: observations}, [])
logits = logits
values = values
These two lines seem redundant?
log_probs = log_probs.sum(-1)
self.entropy = dist.entropy().mean().cpu()
self.pi_err = -advantages.dot(log_probs.reshape(-1)).cpu()
self.value_err = F.mse_loss(values.reshape(-1), value_targets).cpu()
Is it necessary to move the loss to cpu?
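A minimal sketch of one alternative, assuming the stats above are only needed as scalars for logging: `.item()` pulls a 0-dim tensor to the host as a plain Python float, so an explicit `.cpu()` is not needed for those values (tensors that still feed `backward()` must of course stay as tensors).

```python
import torch
import torch.nn.functional as F

# Hedged sketch: scalar loss stats can be fetched with .item(), which
# returns a host-side Python float regardless of the tensor's device.
values = torch.randn(4)
value_targets = torch.zeros(4)

value_err = F.mse_loss(values, value_targets)   # 0-dim tensor (mean reduction)
value_err_scalar = value_err.item()             # plain float, device-agnostic
```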
action_distribution_cls=dist_class)

@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
Could we keep this method impl in TorchPolicyGraph and have options to clip grads / return extra stats as generic functionality?
I can add an abstract method for getting extra grad info.
For grad clipping, I can either make the config a property of TorchPolicyGraph, so its compute_gradients() knows whether to clip gradients, or add an abstract method extra_grad_processing(self, grad) to TorchPolicyGraph and let subclasses process the gradients. What's your preference?
TFPolicyGraph offers the extra grad processing method, so it's probably better to do that for consistency.
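A minimal sketch of the hook pattern being discussed, with hypothetical class and method names (`TorchPolicyGraphSketch`, `extra_grad_process`) standing in for the real TorchPolicyGraph API: the base class owns `compute_gradients()`, and a subclass overrides the hook to clip gradients and report stats.

```python
import torch
import torch.nn as nn

class TorchPolicyGraphSketch:
    """Hypothetical base: owns the generic gradient computation."""

    def __init__(self, model):
        self.model = model

    def extra_grad_process(self):
        # Default hook: no extra processing, no extra stats.
        return {}

    def compute_gradients(self, loss):
        self.model.zero_grad()
        loss.backward()
        info = self.extra_grad_process()  # subclass hook runs post-backward
        grads = [p.grad for p in self.model.parameters()]
        return grads, info

class ClippedPolicyGraph(TorchPolicyGraphSketch):
    def extra_grad_process(self):
        # Subclass hook: clip gradients in place, report pre-clip norm.
        norm = nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        return {"grad_gnorm": float(norm)}

model = nn.Linear(3, 1)
policy = ClippedPolicyGraph(model)
loss = model(torch.ones(2, 3)).sum()
grads, info = policy.compute_gradients(loss)
```

This mirrors the TF-side design mentioned above: generic machinery in the base class, per-algorithm behavior behind one overridable method.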
dist = self.dist_class(logits)
log_probs = dist.logp(actions)
if len(log_probs.shape) > 1:
    log_probs = log_probs.sum(-1)
In which cases does log_probs have a nontrivial second dimension? Wondering if the reshape() is sufficient?
Same question for A3CLoss.
I haven't tried the others, but Normal's log_prob returns a vector of shape (n,), where n is the number of Gaussians. I can absorb this sum into TorchDiagGaussian.
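The shape behavior being described can be demonstrated with a small sketch: for a diagonal Gaussian over an n-dimensional action, `torch.distributions.Normal.log_prob` returns one value per action dimension, so the per-action log-probability is the sum over the last axis.

```python
import torch
from torch.distributions import Normal

# Batch of 4 samples from a 3-dimensional diagonal Gaussian.
dist = Normal(torch.zeros(4, 3), torch.ones(4, 3))
actions = dist.sample()

per_dim = dist.log_prob(actions)  # shape (4, 3): one entry per dimension
per_action = per_dim.sum(-1)      # shape (4,): joint log-prob per action
```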
""" | ||
self.observation_space = observation_space | ||
self.action_space = action_space | ||
self.lock = Lock() | ||
self._model = model | ||
cuda_devices = os.environ['CUDA_VISIBLE_DEVICES'].split(',') |
Could simply check `bool(os.environ.get("CUDA_VISIBLE_DEVICES"))` here.
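A sketch of the suggested check, with a hypothetical `pick_device` helper: `os.environ.get` avoids the KeyError when the variable is unset, and an empty or missing value is falsy, so no GPU is selected.

```python
import os
import torch

def pick_device():
    # Hypothetical helper: unset or empty CUDA_VISIBLE_DEVICES means
    # no GPU should be used, without risking a KeyError.
    if bool(os.environ.get("CUDA_VISIBLE_DEVICES")) and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

os.environ["CUDA_VISIBLE_DEVICES"] = ""
device = pick_device()
```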
@@ -285,3 +286,40 @@ def kl(self, other):
    @override(ActionDistribution)
    def _build_sample_op(self):
        return self.dist.sample()


class TorchDistributionWrapper(ActionDistribution):
Could the torch classes for action dist be placed in a separate file?
@@ -5,6 +5,7 @@
from collections import namedtuple
import distutils.version
import tensorflow as tf
import torch
Let's make sure to not import torch unless we hit a torch=true code path, to avoid acquiring a hard dependency on torch.
@@ -120,7 +121,8 @@ def get_action_dist(action_space, config, dist_type=None):
    elif dist_type == "deterministic":
        return Deterministic, action_space.shape[0]
    elif isinstance(action_space, gym.spaces.Discrete):
        return Categorical, action_space.n
        dist = TorchCategorical if torch else Categorical
Could we add `if torch: raise NotImplementedError` for the other dist types?
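A hypothetical sketch of that branching, using stand-in space classes rather than the real gym.spaces and string names rather than the real distribution classes: torch-unsupported dist types fail loudly instead of silently falling back to the TF classes.

```python
class Discrete:
    """Stand-in for gym.spaces.Discrete."""
    def __init__(self, n):
        self.n = n

class Box:
    """Stand-in for gym.spaces.Box (no torch dist wired up yet)."""

def get_action_dist_sketch(action_space, use_torch=False):
    if isinstance(action_space, Discrete):
        dist = "TorchCategorical" if use_torch else "Categorical"
        return dist, action_space.n
    if use_torch:
        # Fail loudly: no torch distribution implemented for this space.
        raise NotImplementedError(
            "torch does not yet support this action space")
    return "DiagGaussian", None
```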
Thanks for opening this! Overall looks solid; I have some comments.
model_out = self._model({"obs": ob}, state_batches)
logits, _, vf, state = model_out
actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
return (actions.numpy(), [h.numpy() for h in state],
action_dist = self._action_dist_cls(logits)
Now that A2C/PG presumably work with continuous action spaces, you can add two entries to run_rllib_tests.sh to check that they work on Pendulum-v0, similar to the CartPole-v0 entries:
https://github.com/ray-project/ray/blob/master/ci/jenkins_tests/run_rllib_tests.sh#L407
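A hypothetical sketch of what such an entry might look like, modeled loosely on the existing CartPole-v0 lines; the docker flags, paths, and config keys shown here are assumptions, not copied from the actual file:

```shell
# Hypothetical Jenkins entry sketch: train A2C on Pendulum-v0 with the
# torch backend for a couple of iterations as a smoke test. Flags and
# paths are illustrative only -- mirror the real CartPole-v0 entries.
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
    python /ray/python/ray/rllib/train.py \
    --env Pendulum-v0 \
    --run A2C \
    --stop '{"training_iteration": 2}' \
    --config '{"use_pytorch": true}'
```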
LGTM. One thing I'm wondering is if it's possible to test that GPU mode works properly, without a real GPU. It seems easy to forget a cpu().
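One cheap partial guard for this, sketched with a hypothetical `fetch_stats` stand-in for the policy's stats interface: assert at the tensor-to-numpy boundary that every stat is host-side. On a CPU-only CI run the assertion is trivially satisfied, but any run that does touch a GPU (e.g. a nightly) would fail fast on a forgotten `.cpu()` instead of erroring deep inside `.numpy()`.

```python
import torch

def fetch_stats(loss, entropy):
    # Hypothetical boundary: everything returned here is consumed as
    # numpy, so enforce that the tensors were moved off the device.
    stats = {"loss": loss, "entropy": entropy}
    for name, t in stats.items():
        assert t.device.type == "cpu", "stat %s left on %s" % (name, t.device)
    return {name: t.numpy() for name, t in stats.items()}

stats = fetch_stats(torch.tensor(0.5), torch.tensor(1.2))
```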
Is it possible to spin up a GPU instance every night and run a nightly test build on it to catch some errors?
Hm, potentially. I'm not sure if Travis supports GPU instances, though. @FlyClover tests look good, but you have a couple of lint changes: https://travis-ci.com/ray-project/ray/jobs/192067752
Merged, thanks!
What do these changes do?

Related issue number
Closes #4333

Linter
- [ ] I've run scripts/format.sh to lint the changes in this PR.