[RLlib] Issue #9437 (PyTorch converts to CPU tensor, even if on GPU). #9497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion rllib/agents/dqn/dqn_tf_policy.py

@@ -12,6 +12,7 @@
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.exploration import ParameterNoise
+from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import huber_loss, reduce_mean_ignore_inf, \
     minimize_and_clip

@@ -378,7 +379,8 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
             batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
             batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
         new_priorities = (
-            np.abs(td_errors) + policy.config["prioritized_replay_eps"])
+            np.abs(convert_to_numpy(td_errors)) +
+            policy.config["prioritized_replay_eps"])
         batch.data[PRIO_WEIGHTS] = new_priorities

     return batch
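For context: this postprocessing function is presumably also reused by the torch DQN policy, where compute_td_error can return a torch tensor that may live on a GPU, and np.abs cannot consume that directly. A minimal sketch of the conversion the convert_to_numpy helper relies on (the to_numpy name and values below are illustrative, not RLlib code):

import numpy as np
import torch

def to_numpy(x):
    # Detach from the autograd graph, copy to host memory, then hand to NumPy.
    return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
td_errors = torch.tensor([0.5, -1.2], device=device)
new_priorities = np.abs(to_numpy(td_errors)) + 1e-6  # 1e-6 stands in for prioritized_replay_eps
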
2 changes: 1 addition & 1 deletion rllib/agents/es/es_torch_policy.py

@@ -65,7 +65,7 @@ def _compute_actions(policy,
     observation = policy.observation_filter(
         observation[None], update=update)

-    observation = convert_to_torch_tensor(observation)
+    observation = convert_to_torch_tensor(observation, policy.device)
     dist_inputs, _ = policy.model({
         SampleBatch.CUR_OBS: observation
     }, [], None)
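The same device argument matters wherever a NumPy batch is fed to a torch model: the input tensor has to end up on the device holding the model's parameters, otherwise the forward pass fails on GPU. A minimal sketch of what passing policy.device accomplishes (shapes and names are illustrative):

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
observation = np.zeros((1, 4), dtype=np.float32)       # filtered obs with batch dim
obs_tensor = torch.from_numpy(observation).to(device)  # lands next to the model weights
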
13 changes: 7 additions & 6 deletions rllib/agents/ppo/ppo_torch_policy.py

@@ -188,15 +188,16 @@ def __init__(self, obs_space, action_space, config):
             def value(ob, prev_action, prev_reward, *state):
                 model_out, _ = self.model({
                     SampleBatch.CUR_OBS: convert_to_torch_tensor(
-                        np.asarray([ob])),
+                        np.asarray([ob]), self.device),
                     SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
-                        np.asarray([prev_action])),
+                        np.asarray([prev_action]), self.device),
                     SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
-                        np.asarray([prev_reward])),
+                        np.asarray([prev_reward]), self.device),
                     "is_training": False,
-                }, [convert_to_torch_tensor(np.asarray([s])) for s in state],
-                    convert_to_torch_tensor(
-                        np.asarray([1])))
+                }, [convert_to_torch_tensor(np.asarray([s]), self.device) for
+                    s in state],
+                    convert_to_torch_tensor(
+                        np.asarray([1]), self.device))
                 return self.model.value_function()[0]

         else:
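One detail worth noting in the value() helper above: each scalar or single observation is first wrapped in np.asarray([...]) so the model always sees a batch dimension of 1, and only then converted onto self.device. A small sketch of that wrapping (values are illustrative):

import numpy as np
import torch

device = torch.device("cpu")            # stands in for self.device
ob, prev_reward = np.zeros(4, dtype=np.float32), 0.0
obs_batch = torch.from_numpy(np.asarray([ob])).to(device)           # shape (1, 4)
rew_batch = torch.from_numpy(np.asarray([prev_reward])).to(device)  # shape (1,)
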
11 changes: 7 additions & 4 deletions rllib/examples/centralized_critic.py

@@ -83,10 +83,13 @@ def centralized_critic_postprocessing(policy,
         # overwrite default VF prediction with the central VF
         if args.torch:
             sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
-                convert_to_torch_tensor(sample_batch[SampleBatch.CUR_OBS]),
-                convert_to_torch_tensor(sample_batch[OPPONENT_OBS]),
-                convert_to_torch_tensor(sample_batch[OPPONENT_ACTION])). \
-                detach().numpy()
+                convert_to_torch_tensor(
+                    sample_batch[SampleBatch.CUR_OBS], policy.device),
+                convert_to_torch_tensor(
+                    sample_batch[OPPONENT_OBS], policy.device),
+                convert_to_torch_tensor(
+                    sample_batch[OPPONENT_ACTION], policy.device)) \
+                .detach().numpy()
         else:
             sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                 sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
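The overall round trip in this postprocessing step: NumPy arrays from the sample batch are converted to torch tensors on the policy's device, pushed through the central value model, and the result is detached and brought back to NumPy so it can be stored in the batch. A self-contained sketch of that pattern, where central_vf is only a stand-in model, not the example's actual network:

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
central_vf = torch.nn.Linear(6, 1).to(device)   # stand-in for the central value network

obs = np.random.randn(32, 6).astype(np.float32)
vf_preds = central_vf(torch.from_numpy(obs).to(device))
vf_preds = vf_preds.squeeze(-1).detach().cpu().numpy()  # back to NumPy for the SampleBatch
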
3 changes: 2 additions & 1 deletion rllib/models/tf/tf_action_dist.py

@@ -178,12 +178,13 @@ def __init__(self, inputs, model=None, temperature=1.0):
         assert temperature >= 0.0
         self.dist = tfp.distributions.RelaxedOneHotCategorical(
             temperature=temperature, logits=inputs)
+        self.probs = tf.nn.softmax(self.dist._distribution.logits)
         super().__init__(inputs, model)

     @override(ActionDistribution)
     def deterministic_sample(self):
         # Return the dist object's prob values.
-        return self.dist._distribution.probs
+        return self.probs

     @override(ActionDistribution)
     def logp(self, x):
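The new self.probs attribute is simply the softmax of the wrapped distribution's logits, computed once at construction time; deterministic_sample then returns those cached probabilities instead of reaching into the private _distribution object, whose probs attribute may not be populated when the distribution was built from logits. A tiny sketch of the quantity being cached:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
probs = tf.nn.softmax(logits)  # what self.probs holds; each row sums to 1
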
6 changes: 4 additions & 2 deletions rllib/policy/torch_policy.py

@@ -149,7 +149,8 @@ def compute_actions(
                 input_dict[SampleBatch.PREV_REWARDS] = \
                     np.asarray(prev_reward_batch)
             state_batches = [
-                convert_to_torch_tensor(s) for s in (state_batches or [])
+                convert_to_torch_tensor(s, self.device)
+                for s in (state_batches or [])
             ]
             actions, state_out, extra_fetches, logp = \
                 self._compute_action_helper(

@@ -556,7 +557,8 @@ def import_model_from_h5(self, import_file: str) -> None:

     def _lazy_tensor_dict(self, postprocessed_batch):
         train_batch = UsageTrackingDict(postprocessed_batch)
-        train_batch.set_get_interceptor(convert_to_torch_tensor)
+        train_batch.set_get_interceptor(functools.partial(
+            convert_to_torch_tensor, device=self.device))
         return train_batch
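functools.partial is used here because the get-interceptor is called with a single value; binding device= up front turns the two-argument convert_to_torch_tensor into a one-argument callable that always lands tensors on the policy's device. A minimal sketch of that pattern, with convert standing in for the real helper:

import functools
import numpy as np
import torch

def convert(x, device=None):
    # Stand-in for convert_to_torch_tensor.
    t = torch.from_numpy(np.asarray(x, dtype=np.float32))
    return t if device is None else t.to(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
interceptor = functools.partial(convert, device=device)
print(interceptor([1.0, 2.0]).device)  # every intercepted value now lands on `device`
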
4 changes: 3 additions & 1 deletion rllib/utils/exploration/per_worker_epsilon_greedy.py

@@ -28,8 +28,10 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
             if worker_index > 0:
                 # From page 5 of https://arxiv.org/pdf/1803.00933.pdf
                 alpha, eps, i = 7, 0.4, worker_index - 1
+                num_workers_minus_1 = float(num_workers - 1) \
+                    if num_workers > 1 else 1.0
                 epsilon_schedule = ConstantSchedule(
-                    eps**(1 + i / float(num_workers - 1) * alpha),
+                    eps**(1 + (i / num_workers_minus_1) * alpha),
                     framework=framework)
             # Local worker should have zero exploration so that eval
             # rollouts run properly.
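Plugging in the Ape-X constants (alpha = 7, eps = 0.4) gives each remote worker its own fixed epsilon, spread geometrically between 0.4 and 0.4**8; the num_workers_minus_1 guard simply keeps the single-worker case from dividing by zero (the same guard is applied in the noise-based exploration classes below). A short worked example:

alpha, eps = 7, 0.4
num_workers = 8
for worker_index in range(1, num_workers + 1):
    i = worker_index - 1
    denom = float(num_workers - 1) if num_workers > 1 else 1.0
    epsilon = eps ** (1 + (i / denom) * alpha)
    print(worker_index, round(epsilon, 5))  # 0.4 for worker 1 down to ~0.00066 for worker 8
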
4 changes: 3 additions & 1 deletion rllib/utils/exploration/per_worker_gaussian_noise.py

@@ -24,7 +24,9 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
         # Use a fixed, different epsilon per worker. See: Ape-X paper.
         if num_workers > 0:
             if worker_index > 0:
-                exponent = (1 + worker_index / float(num_workers - 1) * 7)
+                num_workers_minus_1 = float(num_workers - 1) \
+                    if num_workers > 1 else 1.0
+                exponent = (1 + (worker_index / num_workers_minus_1) * 7)
                 scale_schedule = ConstantSchedule(
                     0.4**exponent, framework=framework)
             # Local worker should have zero exploration so that eval
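The guard matters because num_workers == 1 is a legal configuration: the single remote worker then has worker_index == 1, and the original expression divides by float(num_workers - 1) == 0.0. A short illustration:

num_workers, worker_index = 1, 1
# Original: exponent = (1 + worker_index / float(num_workers - 1) * 7)  -> ZeroDivisionError
denom = float(num_workers - 1) if num_workers > 1 else 1.0
exponent = 1 + (worker_index / denom) * 7   # 8.0, so the scale becomes 0.4 ** 8
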
@@ -25,7 +25,9 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
         # Use a fixed, different epsilon per worker. See: Ape-X paper.
         if num_workers > 0:
             if worker_index > 0:
-                exponent = (1 + worker_index / float(num_workers - 1) * 7)
+                num_workers_minus_1 = float(num_workers - 1) \
+                    if num_workers > 1 else 1.0
+                exponent = (1 + (worker_index / num_workers_minus_1) * 7)
                 scale_schedule = ConstantSchedule(
                     0.4**exponent, framework=framework)
             # Local worker should have zero exploration so that eval
6 changes: 3 additions & 3 deletions rllib/utils/torch_ops.py

@@ -110,10 +110,10 @@ def mapping(item):
     return tree.map_structure(mapping, stats)


-def convert_to_torch_tensor(stats, device=None):
+def convert_to_torch_tensor(x, device=None):
     """Converts any struct to torch.Tensors.

-    stats (any): Any (possibly nested) struct, the values in which will be
+    x (any): Any (possibly nested) struct, the values in which will be
         converted and returned as a new struct with all leaves converted
         to torch tensors.

@@ -137,7 +137,7 @@ def mapping(item):
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)

-    return tree.map_structure(mapping, stats)
+    return tree.map_structure(mapping, x)


def atanh(x):
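Finally, the renamed argument reflects that the input need not be stats: any (possibly nested) structure is walked with tree.map_structure and every leaf becomes a torch tensor, optionally moved to device. A simplified sketch under those assumptions (the actual helper covers more cases, e.g. leaves that are already tensors):

import numpy as np
import torch
import tree  # dm-tree, the structure-walking library RLlib uses

def convert_sketch(x, device=None):
    def mapping(item):
        tensor = torch.from_numpy(np.asarray(item))
        if tensor.dtype == torch.double:
            tensor = tensor.float()        # keep leaves in float32 by default
        return tensor if device is None else tensor.to(device)
    return tree.map_structure(mapping, x)

batch = {"obs": np.zeros((2, 4)), "state": [np.ones(3)]}
out = convert_sketch(batch, device="cpu")  # same nesting, torch.float32 leaves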