
[rllib] Torch framework incorrectly converts numpy array observations to GPU #9437

Closed
@DavidMChan

Description

What is the problem?

TorchPolicy seems to ignore the fact that observation tensors need to be moved to the GPU when the policy itself is running on the GPU. The problem is in these two calls:

train_batch.set_get_interceptor(convert_to_torch_tensor)

and

convert_to_torch_tensor(s) for s in (state_batches or [])

This can be fixed by altering both of these calls to pass self.device to convert_to_torch_tensor. Alternatively, it could be fixed in the actual model call, which would keep the data on the CPU in case the user wants to do something custom in the model (see the model-side sketch after the reproduction script).
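
A minimal sketch of the policy-side fix, assuming convert_to_torch_tensor accepts a device keyword argument (if it does not, the device move could live in a small wrapper instead):

from functools import partial

# Inside TorchPolicy: bind the policy's device so every tensor pulled
# from the sample batch is created on (or moved to) the GPU.
train_batch.set_get_interceptor(
    partial(convert_to_torch_tensor, device=self.device))

# Same treatment for the recurrent state inputs.
state_batches = [
    convert_to_torch_tensor(s, device=self.device)
    for s in (state_batches or [])
]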

The code fails with the error:

File "/home/david/train_rllib.py", line 16, in <module>
    result = trainer.train()
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 511, in train
    raise e
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 497, in train
    result = Trainable.train(self)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/tune/trainable.py", line 319, in train
    result = self.step()
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 110, in step
    res = next(self.train_exec_impl)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 758, in __next__
    return next(self.built_iterator)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 785, in apply_foreach
    for item in it:
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 785, in apply_foreach
    for item in it:
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 845, in apply_filter
    for item in it:
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 845, in apply_filter
    for item in it:
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 785, in apply_foreach
    for item in it:
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 785, in apply_foreach
    for item in it:
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/util/iter.py", line 793, in apply_foreach
    result = fn(item)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/execution/train_ops.py", line 65, in __call__
    self.sgd_minibatch_size, [])
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/utils/sgd.py", line 117, in do_minibatch_sgd
    }, minibatch.count)))[policy_id]
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 737, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 324, in learn_on_batch
    self._loss(self, self.model, self.dist_class, train_batch))
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 115, in ppo_surrogate_loss
    logits, state = model.from_batch(train_batch)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/models/modelv2.py", line 241, in from_batch
    return self.__call__(input_dict, states, train_batch.get("seq_lens"))
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/models/modelv2.py", line 197, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/models/torch/fcnet.py", line 118, in forward
    self._features = self._hidden_layers(self._last_flat_in)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/ray/rllib/models/torch/misc.py", line 110, in forward
    return self._model(x)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/david/miniconda3/envs/ray/lib/python3.7/site-packages/torch/nn/functional.py", line 1610, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mat1' in call to _th_addmm

Ray version and other system information (Python version, TensorFlow version, OS):

Ray version: 0.9.0dev0 (July 12 nightly)
Torch version: 1.5.1
OS: Ubuntu 18.04

Reproduction (REQUIRED)

Please provide a script that can be run to reproduce the issue. The script should have no external library dependencies (i.e., use fake or mock data / environments):

import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print

ray.init()
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 1
config["num_workers"] = 2
config["framework"] = "torch"
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")

# Can optionally call trainer.restore(path) to load a checkpoint.

for i in range(1000):
    # Perform one iteration of training the policy with PPO
    result = trainer.train()
    print(pretty_print(result))

    if i % 100 == 0:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)

  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.
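
Until a policy-side fix lands, one possible user-level mitigation is a custom model that moves the incoming observations onto its own device in forward(). This is only a sketch (the model class and its registered name are illustrative, not part of RLlib), and it does not help loss terms that mix train_batch tensors with model outputs, so the device-aware interceptor above is the more complete fix:

import torch.nn as nn

from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class DeviceSafeFC(TorchModelV2, nn.Module):
    """Tiny fully connected model that moves observations to its own device."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        in_size = int(obs_space.shape[0])
        self.policy_net = nn.Sequential(
            nn.Linear(in_size, 64), nn.Tanh(), nn.Linear(64, num_outputs))
        self.value_net = nn.Sequential(
            nn.Linear(in_size, 64), nn.Tanh(), nn.Linear(64, 1))
        self._last_obs = None

    def forward(self, input_dict, state, seq_lens):
        # Move the batch onto whatever device this model's weights live on.
        device = next(self.parameters()).device
        obs = input_dict["obs_flat"].float().to(device)
        self._last_obs = obs
        return self.policy_net(obs), state

    def value_function(self):
        return self.value_net(self._last_obs).squeeze(1)


ModelCatalog.register_custom_model("device_safe_fc", DeviceSafeFC)
# Then select it via: config["model"] = {"custom_model": "device_safe_fc"}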

Labels

bug: Something that is supposed to be working, but isn't. triage: Needs triage (e.g., priority, bug/not-bug, and owning component).
