Skip to content

Commit

Permalink
[rllib] Fix APEX priorities returning zero all the time (ray-project#…
Browse files Browse the repository at this point in the history
…5980)

* fix

* move example tests to end

* level err

* guard against none

* no trace test

* ignore thumbs

* np

* fix multi node

* fix
  • Loading branch information
ericl authored Oct 26, 2019
1 parent 0bb922c commit a0dcb45
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ scripts/nodes.txt

# Generated documentation files
/doc/_build
/doc/source/_static/thumbs

# User-specific stuff:
.idea/**/workspace.xml
Expand Down
5 changes: 5 additions & 0 deletions doc/source/rllib-toc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,8 @@ If you encounter out-of-memory errors, consider setting ``redis_max_memory`` and
For debugging unexpected hangs or performance problems, you can run ``ray stack`` to dump
the stack traces of all Ray workers on the current node, and ``ray timeline`` to dump
a timeline visualization of tasks to a file.

TensorFlow 2.0
~~~~~~~~~~~~~~

RLlib currently runs in ``tf.compat.v1`` mode. This means eager execution is disabled by default, and RLlib imports TF with ``import tensorflow.compat.v1 as tf; tf.disable_v2_behaviour()``. Eager execution can be enabled manually by calling ``tf.enable_eager_execution()`` or setting the ``"eager": True`` trainer config.
3 changes: 0 additions & 3 deletions rllib/agents/dqn/dqn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,6 @@ def __init__(self):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
if not self.loss_initialized():
return tf.zeros_like(rew_t)

# Do forward pass on loss to update td error attribute
build_q_losses(
self, self.model, None, {
Expand Down
3 changes: 0 additions & 3 deletions rllib/agents/sac/sac_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,6 @@ def __init__(self):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
if not self.loss_initialized():
return tf.zeros_like(rew_t)

# Do forward pass on loss to update td error attribute
actor_critic_loss(
self, self.model, None, {
Expand Down
3 changes: 2 additions & 1 deletion rllib/policy/eager_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def _convert_to_tf(x):
return x

if x is not None:
x = tf.nest.map_structure(tf.convert_to_tensor, x)
x = tf.nest.map_structure(
lambda f: tf.convert_to_tensor(f) if f is not None else None, x)
return x


Expand Down
11 changes: 7 additions & 4 deletions rllib/tests/test_eager_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
from ray.rllib.agents.registry import get_agent_class


def check_support(alg, config):
def check_support(alg, config, test_trace=True):
config["eager"] = True
if alg in ["APEX_DDPG", "TD3", "DDPG", "SAC"]:
config["env"] = "Pendulum-v0"
else:
config["env"] = "CartPole-v0"
a = get_agent_class(alg)
config["log_level"] = "ERROR"

config["eager_tracing"] = False
tune.run(a, config=config, stop={"training_iteration": 0})

config["eager_tracing"] = True
tune.run(a, config=config, stop={"training_iteration": 0})
if test_trace:
config["eager_tracing"] = True
tune.run(a, config=config, stop={"training_iteration": 0})


class TestEagerSupport(unittest.TestCase):
Expand All @@ -37,7 +39,8 @@ def testA2C(self):
check_support("A2C", {"num_workers": 0})

def testA3C(self):
check_support("A3C", {"num_workers": 1})
# TODO(ekl) trace on is flaky
check_support("A3C", {"num_workers": 1}, test_trace=False)

def testPG(self):
check_support("PG", {"num_workers": 0})
Expand Down

0 comments on commit a0dcb45

Please sign in to comment.