Allow replay ops to stop if they are unhealthy (#36)
* Allow the replay ops to stop if they are unhealthy

* Allow configuring the DQN execution plan consistently
Edilmo authored Sep 10, 2020
1 parent 130b053 commit ee92c5c
Showing 4 changed files with 59 additions and 19 deletions.
27 changes: 24 additions & 3 deletions python/ray/util/iter.py
@@ -666,6 +666,13 @@ class LocalIterator(Generic[T]):
     # used to measure the underlying wait latency for measurement purposes.
     ON_FETCH_START_HOOK_NAME = "_on_fetch_start"
 
+    # If a function passed to LocalIterator.for_each() has this method,
+    # we will call it for each not-ready value condition. This can be
+    # used to stop the iterator early by stopping the iteration, or to
+    # substitute a default value by returning something. This method
+    # should expect the same arguments as __call__.
+    HANDLE_NEXT_VALUE_NOT_READY_HOOK_NAME = "_handle_next_value_not_ready"
+
     thread_local = threading.local()
 
     def __init__(self,
@@ -743,7 +750,12 @@ def for_each(self, fn: Callable[[T], U], max_concurrency=1,
         def apply_foreach(it):
             for item in it:
                 if isinstance(item, _NextValueNotReady):
-                    yield item
+                    if hasattr(fn, LocalIterator.HANDLE_NEXT_VALUE_NOT_READY_HOOK_NAME):
+                        with self._metrics_context():
+                            result = fn._handle_next_value_not_ready(item)
+                            yield result
+                    else:
+                        yield item
                 else:
                     # Keep retrying the function until it returns a valid
                     # value. This allows for non-blocking functions.
@@ -1059,15 +1071,17 @@ def build_union(timeout=None):
                     else:
                         yield_counts[i] += 1
                         yield item
-                except StopIteration:
+                except StopIteration as ex:
                     fix_weights = [
                         w != "*" for w in round_robin_weights
                     ]
                     expected_yield_counts = weight if weight != "*" else MAX_PULL
                     if (any(fix_weights) and
                             yield_counts[i] < expected_yield_counts and
                             pull_counts[i] >= MAX_PULL):
-                        raise
+                        raise ex
+                    elif isinstance(ex, ForceIteratorStopIteration):
+                        raise ex
                     else:
                         removed_iter_indices.append(i)
                         active.remove((weight, it))
@@ -1231,3 +1245,10 @@ def init_actors(self):
 
     def with_transform(self, fn):
         return _ActorSet(self.actors, self.transforms + [fn])
+
+
+class ForceIteratorStopIteration(StopIteration):
+    """
+    Indicates that the iterator should stop yielding regardless of the point
+    at which it is currently located.
+    """
8 changes: 6 additions & 2 deletions rllib/agents/dqn/apex.py
@@ -135,9 +135,13 @@ def update_prio_and_stats(item: ("ActorHandle", dict, int)):
 
     # (2) Read experiences from the replay buffer actors and send to the
     # learner thread via its in-queue.
-    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
+    if config.get("before_learn_on_batch"):
+        before_learn_on_batch = config["before_learn_on_batch"]
+        before_learn_on_batch = before_learn_on_batch(workers, config)
+    else:
+        before_learn_on_batch = lambda b: b
     replay_op = Replay(actors=replay_actors, num_async=4) \
-        .for_each(lambda x: post_fn(x, workers, config)) \
+        .for_each(before_learn_on_batch) \
         .zip_with_source_actor() \
         .for_each(Enqueue(learner_thread.inqueue))

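With this change, before_learn_on_batch in the config is treated as a factory rather than a per-batch function: the execution plan calls it once with (workers, config) and then applies the returned callable to every replayed batch. A sketch of a user-supplied hook under that convention (all names are illustrative):

def make_before_learn_on_batch(workers, config):
    # Factory step: invoked once while the execution plan is being built.
    def on_batch(samples):
        # Per-batch step: must return the (possibly modified) batch; workers
        # and config are captured by the closure instead of being passed in.
        return samples

    return on_batch


config = {"before_learn_on_batch": make_before_learn_on_batch}
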
24 changes: 17 additions & 7 deletions rllib/agents/dqn/dqn.py
@@ -291,21 +291,31 @@ def update_prio(item):
     # (2) Read and train on experiences from the replay buffer. Every batch
     # returned from the LocalReplay() iterator is passed to TrainOneStep to
     # take a SGD step, and then we decide whether to update the target network.
-    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
+    if config.get("before_learn_on_batch"):
+        before_learn_on_batch = config["before_learn_on_batch"]
+        before_learn_on_batch = before_learn_on_batch(workers, config)
+    else:
+        before_learn_on_batch = lambda b: b
     replay_op = Replay(local_buffer=local_replay_buffer) \
-        .for_each(lambda x: post_fn(x, workers, config)) \
+        .for_each(before_learn_on_batch) \
         .for_each(TrainOneStep(workers)) \
         .for_each(update_prio) \
         .for_each(UpdateTargetNetwork(
             workers, config["target_network_update_freq"]))
 
     # Alternate deterministically between (1) and (2). Only return the output
     # of (2) since training metrics are not available until (2) runs.
-    train_op = Concurrently(
-        [store_op, replay_op],
-        mode="round_robin",
-        output_indexes=[1],
-        round_robin_weights=calculate_rr_weights(config))
+    if parallel_rollouts_mode == "bulk_sync":
+        train_op = Concurrently(
+            [store_op, replay_op],
+            mode="round_robin",
+            output_indexes=[1],
+            round_robin_weights=calculate_rr_weights(config))
+    else:
+        train_op = Concurrently(
+            [store_op, replay_op],
+            mode="async",
+            output_indexes=[1])
 
     return StandardMetricsReporting(train_op, workers, config)

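Here, parallel_rollouts_mode is presumably derived from the trainer config earlier in the execution plan (outside this hunk): "bulk_sync" keeps the original deterministic round-robin between the store op and the replay op, while any other value lets the two ops run asynchronously. A hypothetical override sketch, assuming the config key mirrors the variable name (the exact key may differ in this fork):

# Hypothetical sketch; the exact key name is an assumption.
dqn_config_overrides = {
    "parallel_rollouts_mode": "bulk_sync",  # deterministic round-robin (previous behavior)
    # "parallel_rollouts_mode": "async",    # store and replay ops interleave asynchronously
}
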
19 changes: 12 additions & 7 deletions rllib/contrib/maddpg/maddpg.py
@@ -157,13 +157,18 @@ def add_maddpg_postprocessing(config):
     setups for DQN and APEX.
     """
 
-    def f(batch, workers, config):
-        policies = dict(workers.local_worker()
-                        .foreach_trainable_policy(lambda p, i: (i, p)))
-        return before_learn_on_batch(batch, policies,
-                                     config["train_batch_size"])
-
-    config["before_learn_on_batch"] = f
+    class _CustomBeforeLearnOnBatch:
+        def __init__(self, workers, config):
+            self.workers = workers
+            self.config = config
+
+        def __call__(self, batch):
+            policies = dict(self.workers.local_worker()
+                            .foreach_trainable_policy(lambda p, i: (i, p)))
+            return before_learn_on_batch(batch, policies,
+                                         self.config["train_batch_size"])
+
+    config["before_learn_on_batch"] = _CustomBeforeLearnOnBatch
     return config


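Turning the MADDPG hook into a class keeps it compatible with the factory convention above: the class itself is stored in the config, the execution plan instantiates it with (workers, config), and the instance is then called once per replayed batch. Because the hook is now an object, it could also define the _handle_next_value_not_ready method from iter.py to stop the replay op when it becomes unhealthy. A self-contained toy version of the handshake (all names are illustrative stand-ins, not RLlib APIs):

class ToyBeforeLearnOnBatch:
    # Mirrors the shape of _CustomBeforeLearnOnBatch above.
    def __init__(self, workers, config):
        self.workers = workers
        self.config = config

    def __call__(self, batch):
        # A real hook would postprocess the batch using self.workers/self.config.
        return batch


def resolve_hook(config, workers):
    # Mirrors the factory handshake in the DQN/APEX execution plans.
    hook = config.get("before_learn_on_batch")
    return hook(workers, config) if hook else (lambda b: b)


config = {"before_learn_on_batch": ToyBeforeLearnOnBatch}
hook = resolve_hook(config, workers=None)
assert hook({"obs": []}) == {"obs": []}
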
