[rllib] Support torch device and distributions. #4553
ericl merged 8 commits into ray-project:master
Conversation
Can one of the admins verify this patch?

Test FAILed.

The test environment is using pytorch-cpu. Should I change it to pytorch?

Test FAILed.

Hm, does that make a GPU available? AFAIK, none of our tests are currently run with GPUs. What is the limitation of pytorch-cpu?

Test FAILed.

Never mind. The CPU case should be handled correctly now.
    logits, _, values, _ = policy_model(
        {SampleBatch.CUR_OBS: observations}, [])
    logits = logits
    values = values

These two lines seem redundant?
    log_probs = log_probs.sum(-1)
    self.entropy = dist.entropy().mean().cpu()
    self.pi_err = -advantages.dot(log_probs.reshape(-1)).cpu()
    self.value_err = F.mse_loss(values.reshape(-1), value_targets).cpu()

Is it necessary to move the loss to the CPU?
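For context on the question above: an explicit .cpu() is only needed when a whole tensor is consumed on the host (e.g. handed to numpy for stats), because extracting a scalar with .item() already copies it to host memory. A minimal sketch, with illustrative stand-ins for the tensors in the diff:

```python
import torch
import torch.nn.functional as F

# Illustrative stand-ins for the tensors in the diff; on a GPU build
# these would live on a CUDA device.
values = torch.randn(8)
value_targets = torch.randn(8)

value_err = F.mse_loss(values.reshape(-1), value_targets)

# .item() copies the scalar to host memory on its own, so an explicit
# .cpu() is only required when a full tensor is passed to numpy.
stat = value_err.item()
print(type(stat).__name__)  # float
```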
    action_distribution_cls=dist_class)

    @override(PolicyGraph)
    def compute_gradients(self, postprocessed_batch):

Could we keep this method impl in TorchPolicyGraph and have options to clip grads / return extra stats as generic functionality?

I can make an abstract method for getting extra grad info.
For grad clipping, I can either make config a property of TorchPolicyGraph, so that compute_gradients() in TorchPolicyGraph knows whether to clip gradients, or add an abstract method extra_grad_processing(self, grad) to TorchPolicyGraph and let the subclass process the gradients. What's your preference?

TFPolicyGraph offers the extra grad processing method, so it's probably better to do that for consistency.
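The consensus above (a generic compute_gradients in the base class, with an overridable grad-processing hook) is a classic template-method pattern. A minimal sketch under assumed names; the class and hook names here are illustrative, not RLlib's final API:

```python
import torch
import torch.nn as nn


class TorchPolicyGraphSketch:
    """Base class owns the generic compute_gradients loop."""

    def __init__(self, model):
        self._model = model
        self._optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    def extra_grad_process(self):
        """Hook: subclasses clip gradients / return extra stats here."""
        return {}

    def compute_gradients(self, loss):
        self._optimizer.zero_grad()
        loss.backward()
        grad_info = self.extra_grad_process()
        grads = [p.grad.clone() for p in self._model.parameters()]
        return grads, grad_info


class ClippingPolicySketch(TorchPolicyGraphSketch):
    def extra_grad_process(self):
        # Clip in place and report the total norm as a stat.
        total_norm = nn.utils.clip_grad_norm_(self._model.parameters(), 0.5)
        return {"grad_gnorm": float(total_norm)}


model = nn.Linear(4, 2)
policy = ClippingPolicySketch(model)
loss = model(torch.randn(3, 4)).pow(2).mean()
grads, info = policy.compute_gradients(loss)
print(sorted(info))  # ['grad_gnorm']
```

The base class never needs to know whether a subclass clips; it just merges whatever stats the hook returns, which is the consistency with TFPolicyGraph being discussed.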
    dist = self.dist_class(logits)
    log_probs = dist.logp(actions)
    if len(log_probs.shape) > 1:
        log_probs = log_probs.sum(-1)

In which cases does log_probs have a nontrivial second dimension? Wondering if the reshape() is sufficient?
Same question for A3CLoss.

I haven't tried the others, but the Normal distribution's log_prob returns a vector of shape (n,), where n is the number of Gaussians. I can absorb this into TorchDiagGaussian.
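On the shape question: torch.distributions.Normal.log_prob is elementwise, returning one log-probability per action dimension, so a diagonal Gaussian needs the sum(-1) to obtain the joint log-prob. A quick check:

```python
import torch
from torch.distributions import Normal

# Batch of 3 actions from a 2-dimensional diagonal Gaussian.
dist = Normal(torch.zeros(3, 2), torch.ones(3, 2))
actions = dist.sample()

log_probs = dist.log_prob(actions)
print(log_probs.shape)  # torch.Size([3, 2]) -- per-dimension log-probs

# Summing over the last axis yields the joint log-prob of each action;
# absorbing this into a TorchDiagGaussian.logp() would hide the
# special case from the loss code, as suggested above.
joint = log_probs.sum(-1)
print(joint.shape)  # torch.Size([3])
```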
    self.action_space = action_space
    self.lock = Lock()
    self._model = model
    cuda_devices = os.environ['CUDA_VISIBLE_DEVICES'].split(',')

Could simply check bool(os.environ.get("CUDA_VISIBLE_DEVICES")).
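The suggested check avoids the KeyError that indexing os.environ raises when the variable is unset. A self-contained sketch; the helper name is made up for illustration:

```python
import os

def gpu_requested():
    # os.environ.get returns None when the variable is unset, and
    # bool("") is False, so unset, empty, and populated values are all
    # handled without a KeyError.
    return bool(os.environ.get("CUDA_VISIBLE_DEVICES"))

os.environ.pop("CUDA_VISIBLE_DEVICES", None)
print(gpu_requested())  # False

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(gpu_requested())  # True
```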
    return self.dist.sample()


    class TorchDistributionWrapper(ActionDistribution):

Could the torch classes for action dist be placed in a separate file?
    from collections import namedtuple
    import distutils.version
    import tensorflow as tf
    import torch

Let's make sure not to import torch unless we hit a torch=True code path, to avoid acquiring a hard dependency on torch.
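One common way to keep the dependency soft is to defer the import until a torch code path is actually taken. A sketch; the helper name is illustrative, not necessarily what the patch uses:

```python
def try_import_torch():
    """Return the torch module if available, else None.

    Calling this only on torch=True code paths means that merely
    importing this file never requires torch to be installed.
    """
    try:
        import torch
        return torch
    except ImportError:
        return None

torch = try_import_torch()
if torch is None:
    print("torch not installed; torch code paths disabled")
```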
    return Deterministic, action_space.shape[0]
    elif isinstance(action_space, gym.spaces.Discrete):
        return Categorical, action_space.n
        dist = TorchCategorical if torch else Categorical

Could we add "if torch: raise NotImplementedError" for the other dist types?
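A sketch of that suggestion, with stand-in space and distribution classes (gym is stubbed with namedtuples to keep the example self-contained):

```python
from collections import namedtuple

# Stand-ins for gym.spaces and the RLlib distribution classes.
Discrete = namedtuple("Discrete", "n")
Box = namedtuple("Box", "shape")

class Categorical: pass
class TorchCategorical: pass
class DiagGaussian: pass

def get_action_dist(action_space, torch=False):
    if isinstance(action_space, Discrete):
        dist = TorchCategorical if torch else Categorical
        return dist, action_space.n
    # The reviewer's suggestion: fail loudly rather than silently
    # returning a TF-only class on a torch code path.
    if torch:
        raise NotImplementedError(
            "torch distributions are not implemented for this space yet")
    return DiagGaussian, action_space.shape[0]

print(get_action_dist(Discrete(4), torch=True)[0].__name__)  # TorchCategorical
```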
Thanks for opening this! Overall looks solid; I have some comments.

    logits, _, vf, state = model_out
    actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
    return (actions.numpy(), [h.numpy() for h in state],
    action_dist = self._action_dist_cls(logits)

Now that A2C/PG presumably work with continuous action spaces, you can add two entries to run_rllib_tests.sh to check that they work on Pendulum-v0, similar to the CartPole-v0 entries:
https://github.com/ray-project/ray/blob/master/ci/jenkins_tests/run_rllib_tests.sh#L407
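The sampling line in the diff above draws discrete actions from softmaxed logits. A self-contained sketch of that pattern (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 3)  # batch of 5 observations, 3 discrete actions

probs = F.softmax(logits, dim=1)  # each row sums to 1
actions = probs.multinomial(1)    # one sampled action index per row
print(actions.shape)  # torch.Size([5, 1])

flat = actions.squeeze(1)         # drop the sample dimension
print(flat.shape)  # torch.Size([5])
```

Note that squeeze(1) removes the sample dimension for any batch size, whereas squeeze(0), as written in the diff, only collapses the leading dimension when the batch size is 1.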
Test FAILed.

Test FAILed.

Test FAILed.

ericl left a comment:

LGTM. One thing I'm wondering is whether it's possible to test that GPU mode works properly without a real GPU. It seems easy to forget a .cpu().
Is it possible to spin up a GPU instance every night and run a nightly test build on it to catch some errors?

Hm, potentially. I'm not sure if Travis supports GPU instances, though. @FlyClover, tests look good, but you have a couple of lint changes: https://travis-ci.com/ray-project/ray/jobs/192067752

Test FAILed.

Test FAILed.

Merged, thanks!
What do these changes do?

Related issue number

Closes #4333

Linter

I've run scripts/format.sh to lint the changes in this PR.