Commit
[RLlib] Issue #9437 (PyTorch converts to CPU tensor, even if on GPU).
sven1977 authored Jul 16, 2020
1 parent 2f67472 commit 935d830
Showing 10 changed files with 36 additions and 21 deletions.
4 changes: 3 additions & 1 deletion rllib/agents/dqn/dqn_tf_policy.py
@@ -12,6 +12,7 @@
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.exploration import ParameterNoise
+from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import huber_loss, reduce_mean_ignore_inf, \
     minimize_and_clip
@@ -378,7 +379,8 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
             batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
             batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
         new_priorities = (
-            np.abs(td_errors) + policy.config["prioritized_replay_eps"])
+            np.abs(convert_to_numpy(td_errors)) +
+            policy.config["prioritized_replay_eps"])
         batch.data[PRIO_WEIGHTS] = new_priorities

     return batch
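The worker-side prioritization above is where the original code could break on a GPU: once td_errors comes back as a CUDA tensor, np.abs() cannot consume it directly. Below is a rough, standalone stand-in for convert_to_numpy (the real helper is imported from ray.rllib.utils.numpy and also handles nested structs); the eps value is illustrative, not necessarily the config default.

import numpy as np
import torch

# Stand-in sketch, not the actual RLlib helper: detach a (possibly GPU) torch
# tensor and move it to the CPU so NumPy ops can work on it.
def convert_to_numpy(x):
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)

td_errors = torch.tensor([0.5, -1.25])  # would live on "cuda" in the GPU case
prioritized_replay_eps = 1e-6  # illustrative value
new_priorities = np.abs(convert_to_numpy(td_errors)) + prioritized_replay_eps
print(new_priorities)  # ~[0.500001 1.250001]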
2 changes: 1 addition & 1 deletion rllib/agents/es/es_torch_policy.py
@@ -65,7 +65,7 @@ def _compute_actions(policy,
     observation = policy.observation_filter(
         observation[None], update=update)

-    observation = convert_to_torch_tensor(observation)
+    observation = convert_to_torch_tensor(observation, policy.device)
     dist_inputs, _ = policy.model({
         SampleBatch.CUR_OBS: observation
     }, [], None)
13 changes: 7 additions & 6 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -188,15 +188,16 @@ def __init__(self, obs_space, action_space, config):
             def value(ob, prev_action, prev_reward, *state):
                 model_out, _ = self.model({
                     SampleBatch.CUR_OBS: convert_to_torch_tensor(
-                        np.asarray([ob])),
+                        np.asarray([ob]), self.device),
                     SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
-                        np.asarray([prev_action])),
+                        np.asarray([prev_action]), self.device),
                     SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
-                        np.asarray([prev_reward])),
+                        np.asarray([prev_reward]), self.device),
                     "is_training": False,
-                }, [convert_to_torch_tensor(np.asarray([s])) for s in state],
-                    convert_to_torch_tensor(
-                        np.asarray([1])))
+                }, [convert_to_torch_tensor(np.asarray([s]), self.device) for
+                    s in state],
+                   convert_to_torch_tensor(
+                       np.asarray([1]), self.device))
                 return self.model.value_function()[0]

         else:
11 changes: 7 additions & 4 deletions rllib/examples/centralized_critic.py
@@ -83,10 +83,13 @@ def centralized_critic_postprocessing(policy,
         # overwrite default VF prediction with the central VF
         if args.torch:
             sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
-                convert_to_torch_tensor(sample_batch[SampleBatch.CUR_OBS]),
-                convert_to_torch_tensor(sample_batch[OPPONENT_OBS]),
-                convert_to_torch_tensor(sample_batch[OPPONENT_ACTION])). \
-                detach().numpy()
+                convert_to_torch_tensor(
+                    sample_batch[SampleBatch.CUR_OBS], policy.device),
+                convert_to_torch_tensor(
+                    sample_batch[OPPONENT_OBS], policy.device),
+                convert_to_torch_tensor(
+                    sample_batch[OPPONENT_ACTION], policy.device)) \
+                .detach().numpy()
         else:
             sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                 sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
3 changes: 2 additions & 1 deletion rllib/models/tf/tf_action_dist.py
@@ -178,12 +178,13 @@ def __init__(self, inputs, model=None, temperature=1.0):
         assert temperature >= 0.0
         self.dist = tfp.distributions.RelaxedOneHotCategorical(
             temperature=temperature, logits=inputs)
+        self.probs = tf.nn.softmax(self.dist._distribution.logits)
         super().__init__(inputs, model)

     @override(ActionDistribution)
     def deterministic_sample(self):
         # Return the dist object's prob values.
-        return self.dist._distribution.probs
+        return self.probs

     @override(ActionDistribution)
     def logp(self, x):
6 changes: 4 additions & 2 deletions rllib/policy/torch_policy.py
@@ -149,7 +149,8 @@ def compute_actions(
                 input_dict[SampleBatch.PREV_REWARDS] = \
                     np.asarray(prev_reward_batch)
             state_batches = [
-                convert_to_torch_tensor(s) for s in (state_batches or [])
+                convert_to_torch_tensor(s, self.device)
+                for s in (state_batches or [])
             ]
             actions, state_out, extra_fetches, logp = \
                 self._compute_action_helper(
@@ -556,7 +557,8 @@ def import_model_from_h5(self, import_file: str) -> None:

    def _lazy_tensor_dict(self, postprocessed_batch):
        train_batch = UsageTrackingDict(postprocessed_batch)
-        train_batch.set_get_interceptor(convert_to_torch_tensor)
+        train_batch.set_get_interceptor(functools.partial(
+            convert_to_torch_tensor, device=self.device))
        return train_batch
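The interceptor registered in _lazy_tensor_dict above is presumably called with just the looked-up value, which is why the target device has to be bound up front via functools.partial. A minimal, hypothetical stand-in for the lazy-dict behavior (not RLlib's UsageTrackingDict) showing that pattern:

import functools

import numpy as np
import torch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor

# Hypothetical lazy dict: values are converted (and moved to the bound device)
# only when they are actually read.
class LazyTensorDict(dict):
    def set_get_interceptor(self, interceptor):
        self._interceptor = interceptor

    def __getitem__(self, key):
        return self._interceptor(super().__getitem__(key))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = LazyTensorDict(obs=np.zeros((2, 4), dtype=np.float32))
batch.set_get_interceptor(
    functools.partial(convert_to_torch_tensor, device=device))
print(batch["obs"].device)  # cuda:0 when a GPU is available, otherwise cpu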
4 changes: 3 additions & 1 deletion rllib/utils/exploration/per_worker_epsilon_greedy.py
@@ -28,8 +28,10 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
         if worker_index > 0:
             # From page 5 of https://arxiv.org/pdf/1803.00933.pdf
             alpha, eps, i = 7, 0.4, worker_index - 1
+            num_workers_minus_1 = float(num_workers - 1) \
+                if num_workers > 1 else 1.0
             epsilon_schedule = ConstantSchedule(
-                eps**(1 + i / float(num_workers - 1) * alpha),
+                eps**(1 + (i / num_workers_minus_1) * alpha),
                 framework=framework)
         # Local worker should have zero exploration so that eval
         # rollouts run properly.
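The guard introduced here avoids a division by zero when only one worker is configured; the same fix is mirrored in the two noise-schedule hunks that follow. For reference, a short standalone computation of the per-worker epsilons this schedule yields, using the constants from the diff (eps = 0.4, alpha = 7):

# Standalone check of the Ape-X per-worker epsilon formula from the diff,
# including the new num_workers == 1 guard.
def per_worker_epsilon(worker_index, num_workers, eps=0.4, alpha=7):
    i = worker_index - 1
    num_workers_minus_1 = float(num_workers - 1) if num_workers > 1 else 1.0
    return eps**(1 + (i / num_workers_minus_1) * alpha)

print(per_worker_epsilon(1, 8))  # 0.4 (most exploratory worker)
print(per_worker_epsilon(8, 8))  # 0.4 ** 8 ~= 0.00066 (least exploratory)
print(per_worker_epsilon(1, 1))  # 0.4, no ZeroDivisionError for a single worker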
4 changes: 3 additions & 1 deletion rllib/utils/exploration/per_worker_gaussian_noise.py
@@ -24,7 +24,9 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
         # Use a fixed, different epsilon per worker. See: Ape-X paper.
         if num_workers > 0:
             if worker_index > 0:
-                exponent = (1 + worker_index / float(num_workers - 1) * 7)
+                num_workers_minus_1 = float(num_workers - 1) \
+                    if num_workers > 1 else 1.0
+                exponent = (1 + (worker_index / num_workers_minus_1) * 7)
                 scale_schedule = ConstantSchedule(
                     0.4**exponent, framework=framework)
             # Local worker should have zero exploration so that eval
@@ -25,7 +25,9 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
         # Use a fixed, different epsilon per worker. See: Ape-X paper.
         if num_workers > 0:
             if worker_index > 0:
-                exponent = (1 + worker_index / float(num_workers - 1) * 7)
+                num_workers_minus_1 = float(num_workers - 1) \
+                    if num_workers > 1 else 1.0
+                exponent = (1 + (worker_index / num_workers_minus_1) * 7)
                 scale_schedule = ConstantSchedule(
                     0.4**exponent, framework=framework)
             # Local worker should have zero exploration so that eval
6 changes: 3 additions & 3 deletions rllib/utils/torch_ops.py
@@ -110,10 +110,10 @@ def mapping(item):
     return tree.map_structure(mapping, stats)


-def convert_to_torch_tensor(stats, device=None):
+def convert_to_torch_tensor(x, device=None):
     """Converts any struct to torch.Tensors.

-    stats (any): Any (possibly nested) struct, the values in which will be
+    x (any): Any (possibly nested) struct, the values in which will be
         converted and returned as a new struct with all leaves converted
         to torch tensors.
@@ -137,7 +137,7 @@ def mapping(item):
             tensor = tensor.float()
         return tensor if device is None else tensor.to(device)

-    return tree.map_structure(mapping, stats)
+    return tree.map_structure(mapping, x)


def atanh(x):
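To round things off, a small usage sketch of the updated helper (assuming a Ray installation that includes this change): nested structs are converted leaf by leaf, float64 leaves come back as float32, and passing a device keeps the result off the CPU when a GPU is present.

import numpy as np
import torch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {
    "obs": np.zeros((2, 4), dtype=np.float32),
    "rewards": np.array([1.0, 0.5]),  # float64 on the NumPy side
}
tensors = convert_to_torch_tensor(batch, device)
print(tensors["obs"].device)     # cuda:0 when a GPU is available, else cpu
print(tensors["rewards"].dtype)  # torch.float32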
