[rllib] Torch framework incorrectly converts numpy array observations to GPU #9437
Description
What is the problem?
The TorchPolicy seems to ignore the fact that observation tensors need to be on the GPU when the policy itself is running on the GPU. The offending calls are at:
- ray/rllib/policy/torch_policy.py, line 559 (commit b5a6c57)
- ray/rllib/policy/torch_policy.py, line 152 (commit b5a6c57)
This can be fixed by altering both of these calls to pass self.device to convert_to_torch_tensor. Alternatively, it could be fixed in the actual model call, which would preserve the data on the CPU in case the user wants to do something custom in the model.
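For concreteness, a minimal sketch of the first option, assuming convert_to_torch_tensor accepts a device keyword (the helper name make_interceptor and the commented call sites below are illustrative only; the real call sites are the two torch_policy.py lines linked above):

import functools
from ray.rllib.utils.torch_ops import convert_to_torch_tensor

def make_interceptor(device):
    # Bind the policy's device so the lazy numpy -> torch conversion produces
    # tensors on the GPU instead of defaulting to the CPU. If the function does
    # not take `device`, a thin wrapper calling .to(device) on the result would
    # achieve the same thing.
    return functools.partial(convert_to_torch_tensor, device=device)

# Inside TorchPolicy, roughly:
#     train_batch.set_get_interceptor(make_interceptor(self.device))
# instead of
#     train_batch.set_get_interceptor(convert_to_torch_tensor)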
The code fails with the error:
File "/home/david/train_rllib.py", line 16, in <module>
result = trainer.train()
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 511, in train
raise e
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 497, in train
result = Trainable.train(self)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/tune/trainable.py", line 319, in train
result = self.step()
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 110, in step
res = next(self.train_exec_impl)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 758, in __next__
return next(self.built_iterator)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 785, in apply_foreach
for item in it:
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 785, in apply_foreach
for item in it:
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 845, in apply_filter
for item in it:
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 845, in apply_filter
for item in it:
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 785, in apply_foreach
for item in it:
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 785, in apply_foreach
for item in it:
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 793, in apply_foreach
result = fn(item)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/execution/train_ops.py", line 65, in __call__
self.sgd_minibatch_size, [])
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/utils/sgd.py", line 117, in do_minibatch_sgd
}, minibatch.count)))[policy_id]
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 737, in learn_on_batch
info_out[pid] = policy.learn_on_batch(batch)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 324, in learn_on_batch
self._loss(self, self.model, self.dist_class, train_batch))
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 115, in ppo_surrogate_loss
logits, state = model.from_batch(train_batch)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/models/modelv2.py", line 241, in from_batch
return self.__call__(input_dict, states, train_batch.get("seq_lens"))
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/models/modelv2.py", line 197, in __call__
res = self.forward(restored, state or [], seq_lens)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/models/torch/fcnet.py", line 118, in forward
self._features = self._hidden_layers(self._last_flat_in)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
input = module(input)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/models/torch/misc.py", line 110, in forward
return self._model(x)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
input = module(input)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)
File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/functional.py", line 1610, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mat1' in call to _th_addmm
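For reference, the final RuntimeError is the generic torch error for feeding a CPU tensor to a CUDA module, which is what happens when the converted observation batch stays on the CPU. It can be reproduced in isolation (the snippet below is only an illustration, not rllib code):

import torch
import torch.nn as nn

if torch.cuda.is_available():
    layer = nn.Linear(4, 2).cuda()     # like the policy network living on the GPU
    cpu_obs = torch.randn(32, 4)       # like the converted batch left on the CPU
    try:
        layer(cpu_obs)                 # raises the same addmm device-mismatch error
    except RuntimeError as err:
        print(err)
    # Moving the input to the module's device is exactly what the fix needs to do:
    layer(cpu_obs.to(next(layer.parameters()).device))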
Ray version and other system information (Python version, TensorFlow version, OS):
- Ray version: 0.9.0dev0 (July 12 nightly)
- Torch version: 1.5.1
- OS: Ubuntu 18.04
Reproduction (REQUIRED)
Please provide a script that can be run to reproduce the issue. The script should have no external library dependencies (i.e., use fake or mock data / environments):
import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print
ray.init()
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 1
config["num_workers"] = 2
config["framework"] = "torch"
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
# Can optionally call trainer.restore(path) to load a checkpoint.
for i in range(1000):
    # Perform one iteration of training the policy with PPO
    result = trainer.train()
    print(pretty_print(result))

    if i % 100 == 0:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)
- I have verified my script runs in a clean environment and reproduces the issue.
- I have verified the issue also occurs with the latest wheels.